├── .gitignore ├── FreeDrag--diffusion--version ├── __pycache__ │ └── drag_pipeline.cpython-38.pyc ├── drag_pipeline.py ├── drag_ui.py ├── environment.yaml ├── local_pretrained_models │ └── dummy.txt ├── lora │ ├── lora_ckpt │ │ └── dummy.txt │ ├── train_dreambooth_lora.py │ └── train_lora.sh ├── lora_tmp │ └── pytorch_lora_weights.bin ├── readme.txt └── utils │ ├── __pycache__ │ ├── attn_utils.cpython-38.pyc │ ├── drag_utils.cpython-38.pyc │ ├── freedrag_utils.cpython-38.pyc │ ├── freeu_utils.cpython-38.pyc │ ├── lora_utils.cpython-310.pyc │ ├── lora_utils.cpython-38.pyc │ └── ui_utils.cpython-38.pyc │ ├── attn_utils.py │ ├── freedrag_utils.py │ ├── freeu_utils.py │ ├── lora_utils.py │ └── ui_utils.py ├── FreeDrag_gradio.py ├── README.md ├── arial.ttf ├── dnnlib ├── __init__.py └── util.py ├── download_models.sh ├── functions.py ├── legacy.py ├── requirements.txt ├── resources ├── Teaser.png ├── comparison_diffusion_1.png ├── comparison_diffusion_2.png ├── comparison_gan.png ├── fig1.png ├── logo2.png └── style.css ├── torch_utils ├── __init__.py ├── custom_ops.py ├── misc.py ├── ops │ ├── __init__.py │ ├── bias_act.cpp │ ├── bias_act.cu │ ├── bias_act.h │ ├── bias_act.py │ ├── conv2d_gradfix.py │ ├── conv2d_resample.py │ ├── fma.py │ ├── grid_sample_gradfix.py │ ├── upfirdn2d.cpp │ ├── upfirdn2d.cu │ ├── upfirdn2d.h │ └── upfirdn2d.py ├── persistence.py └── training_stats.py └── training ├── __init__.py ├── augment.py ├── dataset.py ├── loss.py └── networks.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /FreeDrag--diffusion--version/__pycache__/drag_pipeline.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/FreeDrag--diffusion--version/__pycache__/drag_pipeline.cpython-38.pyc -------------------------------------------------------------------------------- /FreeDrag--diffusion--version/drag_ui.py: -------------------------------------------------------------------------------- 1 | # The diffusion-based implementation of FreeDrag is based on DragDiffusion(https://arxiv.org/abs/2306.14435) 2 | 3 | import os 4 | import gradio as gr 5 | 6 | from utils.ui_utils import get_points, undo_points 7 | from utils.ui_utils import clear_all, store_img, train_lora_interface, run_drag 8 | from utils.ui_utils import clear_all_gen, store_img_gen, gen_img, run_drag_gen 9 | 10 | LENGTH=480 # length of the square area displaying/editing images 11 | 12 | with gr.Blocks() as demo: 13 | # layout definition 14 | with gr.Row(): 15 | gr.Markdown(""" 16 | # Official Implementation of [FreeDrag--diffusion](https://arxiv.org/abs/2307.04684) 17 | """) 18 | 19 | # UI components for editing real images 20 | with gr.Tab(label="Editing Real Image"): 21 | mask = gr.State(value=None) # store mask 22 | selected_points = gr.State([]) # store points 23 | original_image = gr.State(value=None) # store original input image 24 | with gr.Row(): 25 | with gr.Column(): 26 | gr.Markdown("""

Draw Mask

""") 27 | canvas = gr.Image(type="numpy", tool="sketch", label="Draw Mask", 28 | show_label=True, height=LENGTH, width=LENGTH) # for mask painting 29 | train_lora_button = gr.Button("Train LoRA") 30 | with gr.Column(): 31 | gr.Markdown("""

Click Points

""") 32 | input_image = gr.Image(type="numpy", label="Click Points", 33 | show_label=True, height=LENGTH, width=LENGTH) # for points clicking 34 | undo_button = gr.Button("Undo point") 35 | with gr.Column(): 36 | gr.Markdown("""

Editing Results

""") 37 | output_image = gr.Image(type="numpy", label="Editing Results", 38 | show_label=True, height=LENGTH, width=LENGTH) 39 | with gr.Row(): 40 | run_button = gr.Button("Run") 41 | clear_all_button = gr.Button("Clear All") 42 | 43 | # general parameters 44 | with gr.Row(): 45 | prompt = gr.Textbox(label="Prompt") 46 | lora_path = gr.Textbox(value="./lora_tmp", label="LoRA path") 47 | lora_status_bar = gr.Textbox(label="display LoRA training status") 48 | 49 | # algorithm specific parameters 50 | with gr.Tab("Drag Config"): 51 | with gr.Row(): 52 | max_step = gr.Number( 53 | value=300, 54 | label="max steps", 55 | info="Number of maximum dragging steps.", 56 | precision=0) 57 | lam = gr.Number(value=10, label="lam", info="regularization strength on unmasked areas") 58 | l_expected = gr.Number(value=1, label="l_expected", info="the expected initial loss for each dragging") 59 | d_max = gr.Number(value=5, label="d_max", info="the max motion distance for each dragging") 60 | # n_actual_inference_step = gr.Number(value=40, label="optimize latent step", precision=0) 61 | inversion_strength = gr.Slider(0, 1.0, 62 | value=0.7, 63 | label="inversion strength", 64 | info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.") 65 | latent_lr = gr.Number(value=0.01, label="latent lr") 66 | start_step = gr.Number(value=0, label="start_step", precision=0, visible=False) 67 | start_layer = gr.Number(value=10, label="start_layer", precision=0, visible=False) 68 | 69 | with gr.Tab("Base Model Config"): 70 | with gr.Row(): 71 | local_models_dir = '/mnt/petrelfs/lingpengyang/DragDiffusion/local_pretrained_models' 72 | local_models_choice = \ 73 | [os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))] 74 | model_path = gr.Dropdown(value="runwayml/stable-diffusion-v1-5", 75 | label="Diffusion Model Path", 76 | choices=[ 77 | "runwayml/stable-diffusion-v1-5", 78 | ] + local_models_choice 79 | ) 80 | vae_path = gr.Dropdown(value="default", 81 | label="VAE choice", 82 | choices=["default", 83 | "stabilityai/sd-vae-ft-mse"] + local_models_choice 84 | ) 85 | 86 | with gr.Tab("LoRA Parameters"): 87 | with gr.Row(): 88 | lora_step = gr.Number(value=200, label="LoRA training steps", precision=0) 89 | lora_lr = gr.Number(value=0.0002, label="LoRA learning rate") 90 | lora_batch_size = gr.Number(value=4, label="LoRA batch size", precision=0) 91 | lora_rank = gr.Number(value=16, label="LoRA rank", precision=0) 92 | 93 | # UI components for editing generated images 94 | with gr.Tab(label="Editing Generated Image"): 95 | mask_gen = gr.State(value=None) # store mask 96 | selected_points_gen = gr.State([]) # store points 97 | original_image_gen = gr.State(value=None) # store the diffusion-generated image 98 | intermediate_latents_gen = gr.State(value=None) # store the intermediate diffusion latent during generation 99 | with gr.Row(): 100 | with gr.Column(): 101 | gr.Markdown("""

Draw Mask

""") 102 | canvas_gen = gr.Image(type="numpy", tool="sketch", label="Draw Mask", 103 | show_label=True, height=LENGTH, width=LENGTH) # for mask painting 104 | gen_img_button = gr.Button("Generate Image") 105 | with gr.Column(): 106 | gr.Markdown("""

Click Points

""") 107 | input_image_gen = gr.Image(type="numpy", label="Click Points", 108 | show_label=True, height=LENGTH, width=LENGTH) # for points clicking 109 | undo_button_gen = gr.Button("Undo point") 110 | with gr.Column(): 111 | gr.Markdown("""

Editing Results

""") 112 | output_image_gen = gr.Image(type="numpy", label="Editing Results", 113 | show_label=True, height=LENGTH, width=LENGTH) 114 | with gr.Row(): 115 | run_button_gen = gr.Button("Run") 116 | clear_all_button_gen = gr.Button("Clear All") 117 | 118 | # general parameters 119 | with gr.Row(): 120 | pos_prompt_gen = gr.Textbox(label="Positive Prompt") 121 | neg_prompt_gen = gr.Textbox(label="Negative Prompt") 122 | 123 | with gr.Tab("Generation Config"): 124 | with gr.Row(): 125 | local_models_dir = '/mnt/petrelfs/lingpengyang/DragDiffusion/local_pretrained_models' 126 | local_models_choice = \ 127 | [os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))] 128 | model_path_gen = gr.Dropdown(value="runwayml/stable-diffusion-v1-5", 129 | label="Diffusion Model Path", 130 | choices=[ 131 | "runwayml/stable-diffusion-v1-5", 132 | "gsdf/Counterfeit-V2.5", 133 | "emilianJR/majicMIX_realistic", 134 | "SG161222/Realistic_Vision_V2.0", 135 | "stablediffusionapi/interiordesignsuperm", 136 | "stablediffusionapi/dvarch", 137 | ] + local_models_choice 138 | ) 139 | vae_path_gen = gr.Dropdown(value="default", 140 | label="VAE choice", 141 | choices=["default", 142 | "stabilityai/sd-vae-ft-mse"] + local_models_choice 143 | ) 144 | lora_path_gen = gr.Textbox(value="", label="LoRA path") 145 | gen_seed = gr.Number(value=65536, label="Generation Seed", precision=0) 146 | height = gr.Number(value=512, label="Height", precision=0) 147 | width = gr.Number(value=512, label="Width", precision=0) 148 | guidance_scale = gr.Number(value=7.5, label="CFG Scale") 149 | scheduler_name_gen = gr.Dropdown( 150 | value="DDIM", 151 | label="Scheduler", 152 | choices=[ 153 | "DDIM", 154 | "DPM++2M", 155 | "DPM++2M_karras" 156 | ] 157 | ) 158 | n_inference_step_gen = gr.Number(value=50, label="Total Sampling Steps", precision=0) 159 | 160 | with gr.Tab("FreeU Parameters"): 161 | with gr.Row(): 162 | b1_gen = gr.Slider(label='b1', 163 | info='1st stage backbone factor', 164 | minimum=1, 165 | maximum=1.6, 166 | step=0.05, 167 | value=1.1) 168 | b2_gen = gr.Slider(label='b2', 169 | info='2nd stage backbone factor', 170 | minimum=1, 171 | maximum=1.6, 172 | step=0.05, 173 | value=1.1) 174 | s1_gen = gr.Slider(label='s1', 175 | info='1st stage skip factor', 176 | minimum=0, 177 | maximum=1, 178 | step=0.05, 179 | value=0.8) 180 | s2_gen = gr.Slider(label='s2', 181 | info='2nd stage skip factor', 182 | minimum=0, 183 | maximum=1, 184 | step=0.05, 185 | value=0.8) 186 | 187 | with gr.Tab(label="Drag Config"): 188 | with gr.Row(): 189 | max_step_gen = gr.Number( 190 | value=300, 191 | label="Number of Max Dragging step", 192 | info="Number of Max Dragging step.", 193 | precision=0) 194 | lam_gen = gr.Number(value=10, label="lam", info="regularization strength on unmasked areas") 195 | l_expected_gen = gr.Number(value=1, label="l_expected", info="the expected initial loss for each dragging") 196 | d_max_gen = gr.Number(value=5, label="d_max", info="the max motion distance for each dragging") 197 | 198 | # n_actual_inference_step_gen = gr.Number(value=40, label="optimize latent step", precision=0) 199 | inversion_strength_gen = gr.Slider(0, 1.0, 200 | value=0.7, 201 | label="Inversion Strength", 202 | info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.") 203 | latent_lr_gen = gr.Number(value=0.01, label="latent lr") 204 | start_step_gen = gr.Number(value=0, label="start_step", precision=0, visible=False) 205 | start_layer_gen = 
gr.Number(value=10, label="start_layer", precision=0, visible=False) 206 | 207 | # event definition 208 | # event for dragging user-input real image 209 | canvas.edit( 210 | store_img, 211 | [canvas], 212 | [original_image, selected_points, input_image, mask] 213 | ) 214 | input_image.select( 215 | get_points, 216 | [input_image, selected_points], 217 | [input_image], 218 | ) 219 | undo_button.click( 220 | undo_points, 221 | [original_image, mask], 222 | [input_image, selected_points] 223 | ) 224 | train_lora_button.click( 225 | train_lora_interface, 226 | [original_image, 227 | prompt, 228 | model_path, 229 | vae_path, 230 | lora_path, 231 | lora_step, 232 | lora_lr, 233 | lora_batch_size, 234 | lora_rank], 235 | [lora_status_bar] 236 | ) 237 | run_button.click( 238 | run_drag, 239 | [original_image, 240 | input_image, 241 | mask, 242 | prompt, 243 | selected_points, 244 | inversion_strength, 245 | lam, 246 | l_expected, 247 | d_max, 248 | latent_lr, 249 | max_step, 250 | model_path, 251 | vae_path, 252 | lora_path, 253 | start_step, 254 | start_layer, 255 | ], 256 | [output_image] 257 | ) 258 | clear_all_button.click( 259 | clear_all, 260 | [gr.Number(value=LENGTH, visible=False, precision=0)], 261 | [canvas, 262 | input_image, 263 | output_image, 264 | selected_points, 265 | original_image, 266 | mask] 267 | ) 268 | 269 | # event for dragging generated image 270 | canvas_gen.edit( 271 | store_img_gen, 272 | [canvas_gen], 273 | [original_image_gen, selected_points_gen, input_image_gen, mask_gen] 274 | ) 275 | input_image_gen.select( 276 | get_points, 277 | [input_image_gen, selected_points_gen], 278 | [input_image_gen], 279 | ) 280 | gen_img_button.click( 281 | gen_img, 282 | [ 283 | gr.Number(value=LENGTH, visible=False, precision=0), 284 | height, 285 | width, 286 | n_inference_step_gen, 287 | scheduler_name_gen, 288 | gen_seed, 289 | guidance_scale, 290 | pos_prompt_gen, 291 | neg_prompt_gen, 292 | model_path_gen, 293 | vae_path_gen, 294 | lora_path_gen, 295 | b1_gen, 296 | b2_gen, 297 | s1_gen, 298 | s2_gen, 299 | ], 300 | [canvas_gen, input_image_gen, output_image_gen, mask_gen, intermediate_latents_gen] 301 | ) 302 | undo_button_gen.click( 303 | undo_points, 304 | [original_image_gen, mask_gen], 305 | [input_image_gen, selected_points_gen] 306 | ) 307 | run_button_gen.click( 308 | run_drag_gen, 309 | [ 310 | n_inference_step_gen, 311 | scheduler_name_gen, 312 | original_image_gen, # the original image generated by the diffusion model 313 | input_image_gen, # image with clicking, masking, etc. 
314 | intermediate_latents_gen, 315 | guidance_scale, 316 | mask_gen, 317 | pos_prompt_gen, 318 | neg_prompt_gen, 319 | selected_points_gen, 320 | inversion_strength_gen, 321 | lam_gen, 322 | latent_lr_gen, 323 | max_step_gen, 324 | l_expected_gen, 325 | d_max_gen, 326 | model_path_gen, 327 | vae_path_gen, 328 | lora_path_gen, 329 | start_step_gen, 330 | start_layer_gen, 331 | b1_gen, 332 | b2_gen, 333 | s1_gen, 334 | s2_gen, 335 | ], 336 | [output_image_gen] 337 | ) 338 | clear_all_button_gen.click( 339 | clear_all_gen, 340 | [gr.Number(value=LENGTH, visible=False, precision=0)], 341 | [canvas_gen, 342 | input_image_gen, 343 | output_image_gen, 344 | selected_points_gen, 345 | original_image_gen, 346 | mask_gen, 347 | intermediate_latents_gen, 348 | ] 349 | ) 350 | 351 | demo.queue().launch(share=True, debug=True,server_name="0.0.0.0",server_port=15322) -------------------------------------------------------------------------------- /FreeDrag--diffusion--version/environment.yaml: -------------------------------------------------------------------------------- 1 | name: freedragdif 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=22.3.1 8 | - cudatoolkit=11.7 9 | - pip: 10 | - torch==2.0.0 11 | - torchvision==0.15.1 12 | - gradio==3.41.1 13 | - pydantic==2.0.2 14 | - albumentations==1.3.0 15 | - opencv-contrib-python==4.3.0.36 16 | - imageio==2.9.0 17 | - imageio-ffmpeg==0.4.2 18 | - pytorch-lightning==1.5.0 19 | - omegaconf==2.3.0 20 | - test-tube>=0.7.5 21 | - streamlit==1.12.1 22 | - einops==0.6.0 23 | - transformers==4.27.0 24 | - webdataset==0.2.5 25 | - kornia==0.6 26 | - open_clip_torch==2.16.0 27 | - invisible-watermark>=0.1.5 28 | - streamlit-drawable-canvas==0.8.0 29 | - torchmetrics==0.6.0 30 | - timm==0.6.12 31 | - addict==2.4.0 32 | - yapf==0.32.0 33 | - prettytable==3.6.0 34 | - safetensors==0.2.7 35 | - basicsr==1.4.2 36 | - accelerate==0.17.0 37 | - decord==0.6.0 38 | - diffusers==0.17.1 39 | - moviepy==1.0.3 40 | - opencv_python==4.7.0.68 41 | - Pillow==9.4.0 42 | - scikit_image==0.19.3 43 | - scipy==1.10.1 44 | - tensorboardX==2.6 45 | - tqdm==4.64.1 46 | - numpy==1.24.1 47 | -------------------------------------------------------------------------------- /FreeDrag--diffusion--version/local_pretrained_models/dummy.txt: -------------------------------------------------------------------------------- 1 | You may put your pretrained model here. 
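
The environment above can be created with `conda env create -f environment.yaml` (it is named `freedragdif`). The dummy.txt marks where offline model copies may be placed; as a hedged illustration (not part of the repository), the sketch below saves a diffusers-format copy of Stable Diffusion v1.5 into this folder so it can show up in the UI's "Diffusion Model Path" dropdown. The `target_dir` name is an assumption, and the `local_models_dir` path hard-coded in drag_ui.py must point at the same folder for the dropdown to list it.

```python
# Hedged sketch: cache a local, diffusers-format copy of the base model.
# Assumptions: network access to the Hugging Face Hub, and a target_dir that
# matches the local_models_dir configured in drag_ui.py.
import os
from diffusers import StableDiffusionPipeline

target_dir = "./local_pretrained_models/stable-diffusion-v1-5"  # hypothetical location
os.makedirs(target_dir, exist_ok=True)

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.save_pretrained(target_dir)  # writes unet/, vae/, text_encoder/, tokenizer/, scheduler/
```
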
-------------------------------------------------------------------------------- /FreeDrag--diffusion--version/lora/lora_ckpt/dummy.txt: -------------------------------------------------------------------------------- 1 | lora checkpoints will be saved in this folder 2 | -------------------------------------------------------------------------------- /FreeDrag--diffusion--version/lora/train_lora.sh: -------------------------------------------------------------------------------- 1 | export SAMPLE_DIR="lora/samples/sculpture" 2 | export OUTPUT_DIR="lora/lora_ckpt/sculpture_lora" 3 | 4 | export MODEL_NAME="runwayml/stable-diffusion-v1-5" 5 | export LORA_RANK=16 6 | 7 | accelerate launch lora/train_dreambooth_lora.py \ 8 | --pretrained_model_name_or_path=$MODEL_NAME \ 9 | --instance_data_dir=$SAMPLE_DIR \ 10 | --output_dir=$OUTPUT_DIR \ 11 | --instance_prompt="a photo of a sculpture" \ 12 | --resolution=512 \ 13 | --train_batch_size=1 \ 14 | --gradient_accumulation_steps=1 \ 15 | --checkpointing_steps=100 \ 16 | --learning_rate=2e-4 \ 17 | --lr_scheduler="constant" \ 18 | --lr_warmup_steps=0 \ 19 | --max_train_steps=200 \ 20 | --lora_rank=$LORA_RANK \ 21 | --seed="0" 22 | -------------------------------------------------------------------------------- /FreeDrag--diffusion--version/lora_tmp/pytorch_lora_weights.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/FreeDrag--diffusion--version/lora_tmp/pytorch_lora_weights.bin -------------------------------------------------------------------------------- /FreeDrag--diffusion--version/readme.txt: -------------------------------------------------------------------------------- 1 | The diffusion-based FreeDrag is based on DragDiffusion(https://arxiv.org/abs/2306.14435) with permission. 
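
The train_lora.sh script above trains a LoRA offline into lora/lora_ckpt/, while the UI's "Train LoRA" button writes pytorch_lora_weights.bin into lora_tmp/ (the default "LoRA path" in drag_ui.py). As a hedged sanity check (not part of the repository, and not how the FreeDrag drag pipeline itself consumes the weights), the sketch below loads such a checkpoint into a plain diffusers pipeline; the prompt and paths are assumptions taken from the script above.

```python
# Hedged sketch: load a LoRA checkpoint saved by train_lora.sh or by the UI.
# Assumptions: a CUDA device and a directory containing pytorch_lora_weights.bin.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

lora_dir = "./lora_tmp"  # or "lora/lora_ckpt/sculpture_lora" from train_lora.sh
pipe.unet.load_attn_procs(lora_dir)  # attach the trained LoRA attention processors

image = pipe("a photo of a sculpture", num_inference_steps=50).images[0]
image.save("lora_check.png")
```
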
2 | -------------------------------------------------------------------------------- /FreeDrag--diffusion--version/utils/__pycache__/attn_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/FreeDrag--diffusion--version/utils/__pycache__/attn_utils.cpython-38.pyc -------------------------------------------------------------------------------- /FreeDrag--diffusion--version/utils/__pycache__/drag_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/FreeDrag--diffusion--version/utils/__pycache__/drag_utils.cpython-38.pyc -------------------------------------------------------------------------------- /FreeDrag--diffusion--version/utils/__pycache__/freedrag_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/FreeDrag--diffusion--version/utils/__pycache__/freedrag_utils.cpython-38.pyc -------------------------------------------------------------------------------- /FreeDrag--diffusion--version/utils/__pycache__/freeu_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/FreeDrag--diffusion--version/utils/__pycache__/freeu_utils.cpython-38.pyc -------------------------------------------------------------------------------- /FreeDrag--diffusion--version/utils/__pycache__/lora_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/FreeDrag--diffusion--version/utils/__pycache__/lora_utils.cpython-310.pyc -------------------------------------------------------------------------------- /FreeDrag--diffusion--version/utils/__pycache__/lora_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/FreeDrag--diffusion--version/utils/__pycache__/lora_utils.cpython-38.pyc -------------------------------------------------------------------------------- /FreeDrag--diffusion--version/utils/__pycache__/ui_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/FreeDrag--diffusion--version/utils/__pycache__/ui_utils.cpython-38.pyc -------------------------------------------------------------------------------- /FreeDrag--diffusion--version/utils/attn_utils.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- 3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- 4 | # ytedance Inc.. 
5 | # ************************************************************************* 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from einops import rearrange, repeat 12 | 13 | 14 | class AttentionBase: 15 | 16 | def __init__(self): 17 | self.cur_step = 0 18 | self.num_att_layers = -1 19 | self.cur_att_layer = 0 20 | 21 | def after_step(self): 22 | pass 23 | 24 | def __call__(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs): 25 | out = self.forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs) 26 | self.cur_att_layer += 1 27 | if self.cur_att_layer == self.num_att_layers: 28 | self.cur_att_layer = 0 29 | self.cur_step += 1 30 | # after step 31 | self.after_step() 32 | return out 33 | 34 | def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs): 35 | out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) 36 | out = rearrange(out, 'b h n d -> b n (h d)') 37 | return out 38 | 39 | def reset(self): 40 | self.cur_step = 0 41 | self.cur_att_layer = 0 42 | 43 | 44 | class MutualSelfAttentionControl(AttentionBase): 45 | 46 | def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, guidance_scale=7.5): 47 | """ 48 | Mutual self-attention control for Stable-Diffusion model 49 | Args: 50 | start_step: the step to start mutual self-attention control 51 | start_layer: the layer to start mutual self-attention control 52 | layer_idx: list of the layers to apply mutual self-attention control 53 | step_idx: list the steps to apply mutual self-attention control 54 | total_steps: the total number of steps 55 | """ 56 | super().__init__() 57 | self.total_steps = total_steps 58 | self.start_step = start_step 59 | self.start_layer = start_layer 60 | self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, 16)) 61 | self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps)) 62 | # store the guidance scale to decide whether there are unconditional branch 63 | self.guidance_scale = guidance_scale 64 | print("step_idx: ", self.step_idx) 65 | print("layer_idx: ", self.layer_idx) 66 | 67 | def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs): 68 | """ 69 | Attention forward function 70 | """ 71 | if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx: 72 | return super().forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs) 73 | 74 | if self.guidance_scale > 1.0: 75 | qu, qc = q[0:2], q[2:4] 76 | ku, kc = k[0:2], k[2:4] 77 | vu, vc = v[0:2], v[2:4] 78 | 79 | # merge queries of source and target branch into one so we can use torch API 80 | qu = torch.cat([qu[0:1], qu[1:2]], dim=2) 81 | qc = torch.cat([qc[0:1], qc[1:2]], dim=2) 82 | 83 | out_u = F.scaled_dot_product_attention(qu, ku[0:1], vu[0:1], attn_mask=None, dropout_p=0.0, is_causal=False) 84 | out_u = torch.cat(out_u.chunk(2, dim=2), dim=0) # split the queries into source and target batch 85 | out_u = rearrange(out_u, 'b h n d -> b n (h d)') 86 | 87 | out_c = F.scaled_dot_product_attention(qc, kc[0:1], vc[0:1], attn_mask=None, dropout_p=0.0, is_causal=False) 88 | out_c = torch.cat(out_c.chunk(2, dim=2), dim=0) # split the queries into source and target batch 89 | out_c = rearrange(out_c, 'b h n d -> b n (h d)') 90 | 91 | out = torch.cat([out_u, out_c], dim=0) 92 | else: 93 | q = torch.cat([q[0:1], q[1:2]], dim=2) 94 | out = F.scaled_dot_product_attention(q, k[0:1], v[0:1], 
attn_mask=None, dropout_p=0.0, is_causal=False) 95 | out = torch.cat(out.chunk(2, dim=2), dim=0) # split the queries into source and target batch 96 | out = rearrange(out, 'b h n d -> b n (h d)') 97 | return out 98 | 99 | # forward function for default attention processor 100 | # modified from __call__ function of AttnProcessor in diffusers 101 | def override_attn_proc_forward(attn, editor, place_in_unet): 102 | def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None): 103 | """ 104 | The attention is similar to the original implementation of LDM CrossAttention class 105 | except adding some modifications on the attention 106 | """ 107 | if encoder_hidden_states is not None: 108 | context = encoder_hidden_states 109 | if attention_mask is not None: 110 | mask = attention_mask 111 | 112 | to_out = attn.to_out 113 | if isinstance(to_out, nn.modules.container.ModuleList): 114 | to_out = attn.to_out[0] 115 | else: 116 | to_out = attn.to_out 117 | 118 | h = attn.heads 119 | q = attn.to_q(x) 120 | is_cross = context is not None 121 | context = context if is_cross else x 122 | k = attn.to_k(context) 123 | v = attn.to_v(context) 124 | 125 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) 126 | 127 | # the only difference 128 | out = editor( 129 | q, k, v, is_cross, place_in_unet, 130 | attn.heads, scale=attn.scale) 131 | 132 | return to_out(out) 133 | 134 | return forward 135 | 136 | # forward function for lora attention processor 137 | # modified from __call__ function of LoRAAttnProcessor2_0 in diffusers v0.17.1 138 | def override_lora_attn_proc_forward(attn, editor, place_in_unet): 139 | def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, lora_scale=1.0): 140 | residual = hidden_states 141 | input_ndim = hidden_states.ndim 142 | is_cross = encoder_hidden_states is not None 143 | 144 | if input_ndim == 4: 145 | batch_size, channel, height, width = hidden_states.shape 146 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 147 | 148 | batch_size, sequence_length, _ = ( 149 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 150 | ) 151 | 152 | if attention_mask is not None: 153 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 154 | # scaled_dot_product_attention expects attention_mask shape to be 155 | # (batch, heads, source_length, target_length) 156 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 157 | 158 | if attn.group_norm is not None: 159 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 160 | 161 | query = attn.to_q(hidden_states) + lora_scale * attn.processor.to_q_lora(hidden_states) 162 | 163 | if encoder_hidden_states is None: 164 | encoder_hidden_states = hidden_states 165 | elif attn.norm_cross: 166 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 167 | 168 | key = attn.to_k(encoder_hidden_states) + lora_scale * attn.processor.to_k_lora(encoder_hidden_states) 169 | value = attn.to_v(encoder_hidden_states) + lora_scale * attn.processor.to_v_lora(encoder_hidden_states) 170 | 171 | query, key, value = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=attn.heads), (query, key, value)) 172 | 173 | # the only difference 174 | hidden_states = editor( 175 | query, key, value, is_cross, place_in_unet, 176 | attn.heads, scale=attn.scale) 177 | 178 | # linear proj 179 | hidden_states = 
attn.to_out[0](hidden_states) + lora_scale * attn.processor.to_out_lora(hidden_states) 180 | # dropout 181 | hidden_states = attn.to_out[1](hidden_states) 182 | 183 | if input_ndim == 4: 184 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 185 | 186 | if attn.residual_connection: 187 | hidden_states = hidden_states + residual 188 | 189 | hidden_states = hidden_states / attn.rescale_output_factor 190 | 191 | return hidden_states 192 | 193 | return forward 194 | 195 | def register_attention_editor_diffusers(model, editor: AttentionBase, attn_processor='attn_proc'): 196 | """ 197 | Register a attention editor to Diffuser Pipeline, refer from [Prompt-to-Prompt] 198 | """ 199 | def register_editor(net, count, place_in_unet): 200 | for name, subnet in net.named_children(): 201 | if net.__class__.__name__ == 'Attention': # spatial Transformer layer 202 | if attn_processor == 'attn_proc': 203 | net.forward = override_attn_proc_forward(net, editor, place_in_unet) 204 | elif attn_processor == 'lora_attn_proc': 205 | net.forward = override_lora_attn_proc_forward(net, editor, place_in_unet) 206 | else: 207 | raise NotImplementedError("not implemented") 208 | return count + 1 209 | elif hasattr(net, 'children'): 210 | count = register_editor(subnet, count, place_in_unet) 211 | return count 212 | 213 | cross_att_count = 0 214 | for net_name, net in model.unet.named_children(): 215 | if "down" in net_name: 216 | cross_att_count += register_editor(net, 0, "down") 217 | elif "mid" in net_name: 218 | cross_att_count += register_editor(net, 0, "mid") 219 | elif "up" in net_name: 220 | cross_att_count += register_editor(net, 0, "up") 221 | editor.num_att_layers = cross_att_count 222 | -------------------------------------------------------------------------------- /FreeDrag--diffusion--version/utils/freeu_utils.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- 3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- 4 | # ytedance Inc.. 5 | # ************************************************************************* 6 | 7 | import torch 8 | import torch.fft as fft 9 | from diffusers.models.unet_2d_condition import logger 10 | from diffusers.utils import is_torch_version 11 | from typing import Any, Dict, List, Optional, Tuple, Union 12 | 13 | 14 | def isinstance_str(x: object, cls_name: str): 15 | """ 16 | Checks whether x has any class *named* cls_name in its ancestry. 17 | Doesn't require access to the class's implementation. 18 | 19 | Useful for patching! 
20 | """ 21 | 22 | for _cls in x.__class__.__mro__: 23 | if _cls.__name__ == cls_name: 24 | return True 25 | 26 | return False 27 | 28 | 29 | def Fourier_filter(x, threshold, scale): 30 | dtype = x.dtype 31 | x = x.type(torch.float32) 32 | # FFT 33 | x_freq = fft.fftn(x, dim=(-2, -1)) 34 | x_freq = fft.fftshift(x_freq, dim=(-2, -1)) 35 | 36 | B, C, H, W = x_freq.shape 37 | mask = torch.ones((B, C, H, W)).cuda() 38 | 39 | crow, ccol = H // 2, W //2 40 | mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale 41 | x_freq = x_freq * mask 42 | 43 | # IFFT 44 | x_freq = fft.ifftshift(x_freq, dim=(-2, -1)) 45 | x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real 46 | 47 | x_filtered = x_filtered.type(dtype) 48 | return x_filtered 49 | 50 | 51 | def register_upblock2d(model): 52 | def up_forward(self): 53 | def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): 54 | for resnet in self.resnets: 55 | # pop res hidden states 56 | res_hidden_states = res_hidden_states_tuple[-1] 57 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 58 | #print(f"in upblock2d, hidden states shape: {hidden_states.shape}") 59 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 60 | 61 | if self.training and self.gradient_checkpointing: 62 | 63 | def create_custom_forward(module): 64 | def custom_forward(*inputs): 65 | return module(*inputs) 66 | 67 | return custom_forward 68 | 69 | if is_torch_version(">=", "1.11.0"): 70 | hidden_states = torch.utils.checkpoint.checkpoint( 71 | create_custom_forward(resnet), hidden_states, temb, use_reentrant=False 72 | ) 73 | else: 74 | hidden_states = torch.utils.checkpoint.checkpoint( 75 | create_custom_forward(resnet), hidden_states, temb 76 | ) 77 | else: 78 | hidden_states = resnet(hidden_states, temb) 79 | 80 | if self.upsamplers is not None: 81 | for upsampler in self.upsamplers: 82 | hidden_states = upsampler(hidden_states, upsample_size) 83 | 84 | return hidden_states 85 | 86 | return forward 87 | 88 | for i, upsample_block in enumerate(model.unet.up_blocks): 89 | if isinstance_str(upsample_block, "UpBlock2D"): 90 | upsample_block.forward = up_forward(upsample_block) 91 | 92 | 93 | def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2): 94 | def up_forward(self): 95 | def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): 96 | for resnet in self.resnets: 97 | # pop res hidden states 98 | res_hidden_states = res_hidden_states_tuple[-1] 99 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 100 | #print(f"in free upblock2d, hidden states shape: {hidden_states.shape}") 101 | 102 | # --------------- FreeU code ----------------------- 103 | # Only operate on the first two stages 104 | if hidden_states.shape[1] == 1280: 105 | hidden_states[:,:640] = hidden_states[:,:640] * self.b1 106 | res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1) 107 | if hidden_states.shape[1] == 640: 108 | hidden_states[:,:320] = hidden_states[:,:320] * self.b2 109 | res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2) 110 | # --------------------------------------------------------- 111 | 112 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 113 | 114 | if self.training and self.gradient_checkpointing: 115 | 116 | def create_custom_forward(module): 117 | def custom_forward(*inputs): 118 | return module(*inputs) 119 | 120 | return custom_forward 121 | 122 | if is_torch_version(">=", "1.11.0"): 123 | 
hidden_states = torch.utils.checkpoint.checkpoint( 124 | create_custom_forward(resnet), hidden_states, temb, use_reentrant=False 125 | ) 126 | else: 127 | hidden_states = torch.utils.checkpoint.checkpoint( 128 | create_custom_forward(resnet), hidden_states, temb 129 | ) 130 | else: 131 | hidden_states = resnet(hidden_states, temb) 132 | 133 | if self.upsamplers is not None: 134 | for upsampler in self.upsamplers: 135 | hidden_states = upsampler(hidden_states, upsample_size) 136 | 137 | return hidden_states 138 | 139 | return forward 140 | 141 | for i, upsample_block in enumerate(model.unet.up_blocks): 142 | if isinstance_str(upsample_block, "UpBlock2D"): 143 | upsample_block.forward = up_forward(upsample_block) 144 | setattr(upsample_block, 'b1', b1) 145 | setattr(upsample_block, 'b2', b2) 146 | setattr(upsample_block, 's1', s1) 147 | setattr(upsample_block, 's2', s2) 148 | 149 | 150 | def register_crossattn_upblock2d(model): 151 | def up_forward(self): 152 | def forward( 153 | hidden_states: torch.FloatTensor, 154 | res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], 155 | temb: Optional[torch.FloatTensor] = None, 156 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 157 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 158 | upsample_size: Optional[int] = None, 159 | attention_mask: Optional[torch.FloatTensor] = None, 160 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 161 | ): 162 | for resnet, attn in zip(self.resnets, self.attentions): 163 | # pop res hidden states 164 | #print(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}") 165 | res_hidden_states = res_hidden_states_tuple[-1] 166 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 167 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 168 | 169 | if self.training and self.gradient_checkpointing: 170 | 171 | def create_custom_forward(module, return_dict=None): 172 | def custom_forward(*inputs): 173 | if return_dict is not None: 174 | return module(*inputs, return_dict=return_dict) 175 | else: 176 | return module(*inputs) 177 | 178 | return custom_forward 179 | 180 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 181 | hidden_states = torch.utils.checkpoint.checkpoint( 182 | create_custom_forward(resnet), 183 | hidden_states, 184 | temb, 185 | **ckpt_kwargs, 186 | ) 187 | hidden_states = torch.utils.checkpoint.checkpoint( 188 | create_custom_forward(attn, return_dict=False), 189 | hidden_states, 190 | encoder_hidden_states, 191 | None, # timestep 192 | None, # class_labels 193 | cross_attention_kwargs, 194 | attention_mask, 195 | encoder_attention_mask, 196 | **ckpt_kwargs, 197 | )[0] 198 | else: 199 | hidden_states = resnet(hidden_states, temb) 200 | hidden_states = attn( 201 | hidden_states, 202 | encoder_hidden_states=encoder_hidden_states, 203 | cross_attention_kwargs=cross_attention_kwargs, 204 | attention_mask=attention_mask, 205 | encoder_attention_mask=encoder_attention_mask, 206 | return_dict=False, 207 | )[0] 208 | 209 | if self.upsamplers is not None: 210 | for upsampler in self.upsamplers: 211 | hidden_states = upsampler(hidden_states, upsample_size) 212 | 213 | return hidden_states 214 | 215 | return forward 216 | 217 | for i, upsample_block in enumerate(model.unet.up_blocks): 218 | if isinstance_str(upsample_block, "CrossAttnUpBlock2D"): 219 | upsample_block.forward = up_forward(upsample_block) 220 | 221 | 222 | def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, 
s2=0.2): 223 | def up_forward(self): 224 | def forward( 225 | hidden_states: torch.FloatTensor, 226 | res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], 227 | temb: Optional[torch.FloatTensor] = None, 228 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 229 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 230 | upsample_size: Optional[int] = None, 231 | attention_mask: Optional[torch.FloatTensor] = None, 232 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 233 | ): 234 | for resnet, attn in zip(self.resnets, self.attentions): 235 | # pop res hidden states 236 | #print(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}") 237 | res_hidden_states = res_hidden_states_tuple[-1] 238 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 239 | 240 | # --------------- FreeU code ----------------------- 241 | # Only operate on the first two stages 242 | if hidden_states.shape[1] == 1280: 243 | hidden_states[:,:640] = hidden_states[:,:640] * self.b1 244 | res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1) 245 | if hidden_states.shape[1] == 640: 246 | hidden_states[:,:320] = hidden_states[:,:320] * self.b2 247 | res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2) 248 | # --------------------------------------------------------- 249 | 250 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 251 | 252 | if self.training and self.gradient_checkpointing: 253 | 254 | def create_custom_forward(module, return_dict=None): 255 | def custom_forward(*inputs): 256 | if return_dict is not None: 257 | return module(*inputs, return_dict=return_dict) 258 | else: 259 | return module(*inputs) 260 | 261 | return custom_forward 262 | 263 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 264 | hidden_states = torch.utils.checkpoint.checkpoint( 265 | create_custom_forward(resnet), 266 | hidden_states, 267 | temb, 268 | **ckpt_kwargs, 269 | ) 270 | hidden_states = torch.utils.checkpoint.checkpoint( 271 | create_custom_forward(attn, return_dict=False), 272 | hidden_states, 273 | encoder_hidden_states, 274 | None, # timestep 275 | None, # class_labels 276 | cross_attention_kwargs, 277 | attention_mask, 278 | encoder_attention_mask, 279 | **ckpt_kwargs, 280 | )[0] 281 | else: 282 | hidden_states = resnet(hidden_states, temb) 283 | # hidden_states = attn( 284 | # hidden_states, 285 | # encoder_hidden_states=encoder_hidden_states, 286 | # cross_attention_kwargs=cross_attention_kwargs, 287 | # encoder_attention_mask=encoder_attention_mask, 288 | # return_dict=False, 289 | # )[0] 290 | hidden_states = attn( 291 | hidden_states, 292 | encoder_hidden_states=encoder_hidden_states, 293 | cross_attention_kwargs=cross_attention_kwargs, 294 | )[0] 295 | 296 | if self.upsamplers is not None: 297 | for upsampler in self.upsamplers: 298 | hidden_states = upsampler(hidden_states, upsample_size) 299 | 300 | return hidden_states 301 | 302 | return forward 303 | 304 | for i, upsample_block in enumerate(model.unet.up_blocks): 305 | if isinstance_str(upsample_block, "CrossAttnUpBlock2D"): 306 | upsample_block.forward = up_forward(upsample_block) 307 | setattr(upsample_block, 'b1', b1) 308 | setattr(upsample_block, 'b2', b2) 309 | setattr(upsample_block, 's1', s1) 310 | setattr(upsample_block, 's2', s2) 311 | -------------------------------------------------------------------------------- /FreeDrag--diffusion--version/utils/lora_utils.py: 
-------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- 3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- 4 | # ytedance Inc.. 5 | # ************************************************************************* 6 | 7 | from PIL import Image 8 | import os 9 | import numpy as np 10 | from einops import rearrange 11 | import torch 12 | import torch.nn.functional as F 13 | from torchvision import transforms 14 | from accelerate import Accelerator 15 | from accelerate.utils import set_seed 16 | from PIL import Image 17 | 18 | from transformers import AutoTokenizer, PretrainedConfig 19 | 20 | import diffusers 21 | from diffusers import ( 22 | AutoencoderKL, 23 | DDPMScheduler, 24 | DiffusionPipeline, 25 | DPMSolverMultistepScheduler, 26 | StableDiffusionPipeline, 27 | UNet2DConditionModel, 28 | ) 29 | from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin 30 | from diffusers.models.attention_processor import ( 31 | AttnAddedKVProcessor, 32 | AttnAddedKVProcessor2_0, 33 | LoRAAttnAddedKVProcessor, 34 | LoRAAttnProcessor, 35 | LoRAAttnProcessor2_0, 36 | SlicedAttnAddedKVProcessor, 37 | ) 38 | from diffusers.optimization import get_scheduler 39 | from diffusers.utils import check_min_version 40 | from diffusers.utils.import_utils import is_xformers_available 41 | 42 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 43 | check_min_version("0.17.0") 44 | 45 | 46 | def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): 47 | text_encoder_config = PretrainedConfig.from_pretrained( 48 | pretrained_model_name_or_path, 49 | subfolder="text_encoder", 50 | revision=revision, 51 | ) 52 | model_class = text_encoder_config.architectures[0] 53 | 54 | if model_class == "CLIPTextModel": 55 | from transformers import CLIPTextModel 56 | 57 | return CLIPTextModel 58 | elif model_class == "RobertaSeriesModelWithTransformation": 59 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation 60 | 61 | return RobertaSeriesModelWithTransformation 62 | elif model_class == "T5EncoderModel": 63 | from transformers import T5EncoderModel 64 | 65 | return T5EncoderModel 66 | else: 67 | raise ValueError(f"{model_class} is not supported.") 68 | 69 | def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): 70 | if tokenizer_max_length is not None: 71 | max_length = tokenizer_max_length 72 | else: 73 | max_length = tokenizer.model_max_length 74 | 75 | text_inputs = tokenizer( 76 | prompt, 77 | truncation=True, 78 | padding="max_length", 79 | max_length=max_length, 80 | return_tensors="pt", 81 | ) 82 | 83 | return text_inputs 84 | 85 | def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=False): 86 | text_input_ids = input_ids.to(text_encoder.device) 87 | 88 | if text_encoder_use_attention_mask: 89 | attention_mask = attention_mask.to(text_encoder.device) 90 | else: 91 | attention_mask = None 92 | 93 | prompt_embeds = text_encoder( 94 | text_input_ids, 95 | attention_mask=attention_mask, 96 | ) 97 | prompt_embeds = prompt_embeds[0] 98 | 99 | return prompt_embeds 100 | 101 | # model_path: path of the model 102 | # image: input image, have not been pre-processed 103 | # save_lora_path: the path to save the lora 104 | # prompt: the user 
input prompt 105 | # lora_step: number of lora training step 106 | # lora_lr: learning rate of lora training 107 | # lora_rank: the rank of lora 108 | # save_interval: the frequency of saving lora checkpoints 109 | def train_lora(image, 110 | prompt, 111 | model_path, 112 | vae_path, 113 | save_lora_path, 114 | lora_step, 115 | lora_lr, 116 | lora_batch_size, 117 | lora_rank, 118 | progress, 119 | save_interval=-1): 120 | # initialize accelerator 121 | accelerator = Accelerator( 122 | gradient_accumulation_steps=1, 123 | mixed_precision='fp16' 124 | ) 125 | set_seed(0) 126 | 127 | # Load the tokenizer 128 | tokenizer = AutoTokenizer.from_pretrained( 129 | model_path, 130 | subfolder="tokenizer", 131 | revision=None, 132 | use_fast=False, 133 | ) 134 | # initialize the model 135 | noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler") 136 | text_encoder_cls = import_model_class_from_model_name_or_path(model_path, revision=None) 137 | text_encoder = text_encoder_cls.from_pretrained( 138 | model_path, subfolder="text_encoder", revision=None 139 | ) 140 | if vae_path == "default": 141 | vae = AutoencoderKL.from_pretrained( 142 | model_path, subfolder="vae", revision=None 143 | ) 144 | else: 145 | vae = AutoencoderKL.from_pretrained(vae_path) 146 | unet = UNet2DConditionModel.from_pretrained( 147 | model_path, subfolder="unet", revision=None 148 | ) 149 | 150 | # set device and dtype 151 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 152 | 153 | vae.requires_grad_(False) 154 | text_encoder.requires_grad_(False) 155 | unet.requires_grad_(False) 156 | 157 | unet.to(device, dtype=torch.float16) 158 | vae.to(device, dtype=torch.float16) 159 | text_encoder.to(device, dtype=torch.float16) 160 | 161 | # initialize UNet LoRA 162 | unet_lora_attn_procs = {} 163 | for name, attn_processor in unet.attn_processors.items(): 164 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 165 | if name.startswith("mid_block"): 166 | hidden_size = unet.config.block_out_channels[-1] 167 | elif name.startswith("up_blocks"): 168 | block_id = int(name[len("up_blocks.")]) 169 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 170 | elif name.startswith("down_blocks"): 171 | block_id = int(name[len("down_blocks.")]) 172 | hidden_size = unet.config.block_out_channels[block_id] 173 | else: 174 | raise NotImplementedError("name must start with up_blocks, mid_blocks, or down_blocks") 175 | 176 | if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)): 177 | lora_attn_processor_class = LoRAAttnAddedKVProcessor 178 | else: 179 | lora_attn_processor_class = ( 180 | LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor 181 | ) 182 | unet_lora_attn_procs[name] = lora_attn_processor_class( 183 | hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank 184 | ) 185 | 186 | unet.set_attn_processor(unet_lora_attn_procs) 187 | unet_lora_layers = AttnProcsLayers(unet.attn_processors) 188 | 189 | # Optimizer creation 190 | params_to_optimize = (unet_lora_layers.parameters()) 191 | optimizer = torch.optim.AdamW( 192 | params_to_optimize, 193 | lr=lora_lr, 194 | betas=(0.9, 0.999), 195 | weight_decay=1e-2, 196 | eps=1e-08, 197 | ) 198 | 199 | lr_scheduler = get_scheduler( 200 | "constant", 201 | optimizer=optimizer, 202 | num_warmup_steps=0, 203 | num_training_steps=lora_step, 204 | num_cycles=1, 
205 | power=1.0, 206 | ) 207 | 208 | # prepare accelerator 209 | unet_lora_layers = accelerator.prepare_model(unet_lora_layers) 210 | optimizer = accelerator.prepare_optimizer(optimizer) 211 | lr_scheduler = accelerator.prepare_scheduler(lr_scheduler) 212 | 213 | # initialize text embeddings 214 | with torch.no_grad(): 215 | text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None) 216 | text_embedding = encode_prompt( 217 | text_encoder, 218 | text_inputs.input_ids, 219 | text_inputs.attention_mask, 220 | text_encoder_use_attention_mask=False 221 | ) 222 | text_embedding = text_embedding.repeat(lora_batch_size, 1, 1) 223 | 224 | # initialize latent distribution 225 | image_transforms = transforms.Compose( 226 | [ 227 | transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR), 228 | transforms.RandomCrop(512), 229 | transforms.ToTensor(), 230 | transforms.Normalize([0.5], [0.5]), 231 | ] 232 | ) 233 | 234 | for step in progress.tqdm(range(lora_step), desc="training LoRA"): 235 | unet.train() 236 | image_batch = [] 237 | for _ in range(lora_batch_size): 238 | image_transformed = image_transforms(Image.fromarray(image)).to(device, dtype=torch.float16) 239 | image_transformed = image_transformed.unsqueeze(dim=0) 240 | image_batch.append(image_transformed) 241 | 242 | # repeat the image_transformed to enable multi-batch training 243 | image_batch = torch.cat(image_batch, dim=0) 244 | 245 | latents_dist = vae.encode(image_batch).latent_dist 246 | model_input = latents_dist.sample() * vae.config.scaling_factor 247 | # Sample noise that we'll add to the latents 248 | noise = torch.randn_like(model_input) 249 | bsz, channels, height, width = model_input.shape 250 | # Sample a random timestep for each image 251 | timesteps = torch.randint( 252 | 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device 253 | ) 254 | timesteps = timesteps.long() 255 | 256 | # Add noise to the model input according to the noise magnitude at each timestep 257 | # (this is the forward diffusion process) 258 | noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) 259 | 260 | # Predict the noise residual 261 | model_pred = unet(noisy_model_input, timesteps, text_embedding).sample 262 | 263 | # Get the target for loss depending on the prediction type 264 | if noise_scheduler.config.prediction_type == "epsilon": 265 | target = noise 266 | elif noise_scheduler.config.prediction_type == "v_prediction": 267 | target = noise_scheduler.get_velocity(model_input, noise, timesteps) 268 | else: 269 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 270 | 271 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 272 | accelerator.backward(loss) 273 | optimizer.step() 274 | lr_scheduler.step() 275 | optimizer.zero_grad() 276 | 277 | if save_interval > 0 and (step + 1) % save_interval == 0: 278 | save_lora_path_intermediate = os.path.join(save_lora_path, str(step+1)) 279 | if not os.path.isdir(save_lora_path_intermediate): 280 | os.mkdir(save_lora_path_intermediate) 281 | # unet = unet.to(torch.float32) 282 | # unwrap_model is used to remove all special modules added when doing distributed training 283 | # so here, there is no need to call unwrap_model 284 | # unet_lora_layers = accelerator.unwrap_model(unet_lora_layers) 285 | LoraLoaderMixin.save_lora_weights( 286 | save_directory=save_lora_path_intermediate, 287 | unet_lora_layers=unet_lora_layers, 288 | text_encoder_lora_layers=None, 289 | ) 
290 | # unet = unet.to(torch.float16) 291 | 292 | # save the trained lora 293 | # unet = unet.to(torch.float32) 294 | # unwrap_model is used to remove all special modules added when doing distributed training 295 | # so here, there is no need to call unwrap_model 296 | # unet_lora_layers = accelerator.unwrap_model(unet_lora_layers) 297 | LoraLoaderMixin.save_lora_weights( 298 | save_directory=save_lora_path, 299 | unet_lora_layers=unet_lora_layers, 300 | text_encoder_lora_layers=None, 301 | ) 302 | 303 | return 304 | -------------------------------------------------------------------------------- /FreeDrag_gradio.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import torch 3 | import numpy as np 4 | from functions import to_image, draw_handle_target_points, free_drag, image_inversion, add_watermark_np 5 | import dnnlib 6 | from training import networks 7 | import legacy 8 | import cv2 9 | 10 | # export CUDA_LAUNCH_BLOCKING=1 11 | def load_model(model_name, device): 12 | 13 | path = './checkpoints/' + str(model_name) 14 | with dnnlib.util.open_url(path) as f: 15 | G = legacy.load_network_pkl(f)['G_ema'].to(device) 16 | G_copy = networks.Generator(z_dim=G.z_dim, c_dim= G.c_dim, w_dim =G.w_dim, 17 | img_resolution = G.img_resolution, 18 | img_channels = G.img_channels, 19 | mapping_kwargs = G.init_kwargs['mapping_kwargs']) 20 | 21 | G_copy.load_state_dict(G.state_dict()) 22 | G_copy.to(device) 23 | del(G) 24 | for param in G_copy.parameters(): 25 | param.requires_grad = False 26 | return G_copy, model_name 27 | 28 | def draw_mask(image,mask): 29 | 30 | image_mask = image*(1-mask) +mask*(0.7*image+0.3*255.0) 31 | 32 | return image_mask 33 | 34 | 35 | class ModelWrapper: 36 | def __init__(self, model,model_name): 37 | self.g = model 38 | self.name = model_name 39 | self.res = CKPT_SIZE[model_name][0] 40 | self.l = CKPT_SIZE[model_name][1] 41 | self.d = CKPT_SIZE[model_name][2] 42 | 43 | 44 | # model, points, mask, feature_size, train_layer_index,max_step, device,seed=2023,max_distance=3, d=0.5 45 | # img_show, current_target, step_number 46 | def on_drag(model, points, mask, max_iters,latent,sample_interval,l_expected,d_max,save_video,add_watermark): 47 | 48 | if len(points['handle']) == 0: 49 | raise gr.Error('You must select at least one handle point and target point.') 50 | if len(points['handle']) != len(points['target']): 51 | raise gr.Error('You have uncompleted handle points, try to selct a target point or undo the handle point.') 52 | max_iters = int(max_iters) 53 | 54 | handle_size = 128 55 | train_layer_index=6 56 | l_expected = torch.tensor(l_expected,device=latent.device) 57 | d_max = torch.tensor(d_max,device=latent.device) 58 | mask[mask>0] = 1 59 | global stop_flag 60 | stop_flag = False 61 | images_total = [] 62 | for img_show, current_target, step_number,full_res, latent_optimized in free_drag(model.g,points,mask[:,:,0],handle_size, \ 63 | train_layer_index,latent,max_iters,l_expected,d_max,sample_interval,device=latent.device): 64 | image = to_image(img_show) 65 | 66 | points['handle'] = [current_target[p,:].cpu().numpy().astype('int') for p in range(len(current_target[:,0]))] 67 | image_show = add_points_to_image(image, points, size=RES_TO_CLICK_SIZE[full_res],color="yellow") 68 | 69 | if np.any(mask[:,:,0]>0): 70 | image_show = draw_mask(image_show,mask) 71 | image_show = np.uint8(image_show) 72 | 73 | if add_watermark: 74 | image_show = add_watermark_np(np.array(image_show))[:,:,[0,1,2]] 75 | image_clear = 
add_watermark_np(np.array(image))[:,:,[0,1,2]] 76 | else: 77 | image_clear = image 78 | 79 | if save_video: 80 | images_total.append(image_show) 81 | yield (image_show, step_number, latent_optimized,image_clear,images_total,gr.Button.update(interactive=True)) 82 | 83 | if stop_flag: 84 | break 85 | 86 | def add_points_to_image(image, points, size=5,color="red"): 87 | image = draw_handle_target_points(image, points['handle'], points['target'], size, color) 88 | return image 89 | 90 | def on_show_save(): 91 | return gr.update(visible=True) 92 | 93 | def on_click(image, target_point, points, res, evt: gr.SelectData): 94 | if target_point: 95 | points['target'].append([evt.index[1], evt.index[0]]) 96 | image = add_points_to_image(image, points, size=RES_TO_CLICK_SIZE[res]) 97 | return image, not target_point 98 | points['handle'].append([evt.index[1], evt.index[0]]) 99 | image = add_points_to_image(image, points, size=RES_TO_CLICK_SIZE[res]) 100 | return image, not target_point 101 | 102 | def new_image(model,seed=-1): 103 | if seed == -1: 104 | seed = np.random.randint(1,1e6) 105 | z1 = torch.from_numpy(np.random.RandomState(int(seed)).randn(1, model.g.z_dim)).to(device) 106 | label = torch.zeros([1, model.g.c_dim], device=device) 107 | ws_original= model.g.get_ws(z1,label,truncation_psi=0.7) 108 | _, img_show_original = model.g.synthesis(ws=ws_original,noise_mode='const') 109 | 110 | return to_image(img_show_original), to_image(img_show_original), ws_original, seed 111 | 112 | def new_model(model_name): 113 | model_load, _ = load_model(model_name, device) 114 | model = ModelWrapper(model_load,model_name) 115 | 116 | return model, model.res, model.l, model.d 117 | 118 | def reset_all(image,mask,add_watermark=0): 119 | points = {'target': [], 'handle': []} 120 | target_point = False 121 | mask = np.zeros_like(mask,dtype=np.uint8) 122 | 123 | return points, target_point, image, None,mask, add_watermark 124 | 125 | def add_mask(image_show,mask): 126 | image_show = draw_mask(image_show,mask) 127 | return image_show 128 | 129 | def update_mask(image,mask_show): 130 | mask = np.zeros_like(image) 131 | if mask_show != None and np.any(mask_show['mask'][:,:,0]>1): 132 | mask[mask_show['mask'][:,:,:3]>0] =1 133 | image_mask = add_mask(image,mask) 134 | return np.uint8(image_mask), mask 135 | else: 136 | return image, mask 137 | 138 | def on_select_mask_tab(image): 139 | return image 140 | 141 | def change_stop_state(): 142 | global stop_flag 143 | stop_flag = True 144 | 145 | def save_video(imgs_show_list,frame): 146 | if len(imgs_show_list)>0: 147 | video_name = './process.mp4' 148 | fource = cv2.VideoWriter_fourcc(*'mp4v') 149 | full_res = imgs_show_list[0].shape[0] 150 | video_output = cv2.VideoWriter(video_name,fourcc=fource,fps=frame,frameSize = (full_res,full_res)) 151 | for k in range(len(imgs_show_list)): 152 | video_output.write(imgs_show_list[k][:,:,::-1]) 153 | video_output.release() 154 | return [] 155 | 156 | CKPT_SIZE = { 157 | 'faces.pkl':[512, 0.3, 3], 158 | 'horses.pkl': [256, 0.3, 3], 159 | 'elephants.pkl': [512, 0.4, 4], 160 | 'lions.pkl':[512, 0.4, 4], 161 | 'dogs.pkl':[1024, 0.4, 4], 162 | 'bicycles.pkl':[256, 0.3, 3], 163 | 'giraffes.pkl':[512, 0.4, 4], 164 | 'cats.pkl':[512, 0.3, 3], 165 | 'cars.pkl':[512, 0.3, 3], 166 | 'churches.pkl':[256, 0.3, 3], 167 | 'metfaces.pkl':[1024, 0.3, 3], 168 | } 169 | RES_TO_CLICK_SIZE = { 170 | 1024: 10, 171 | 512: 5, 172 | 256: 3, 173 | } 174 | 175 | if torch.cuda.is_available(): 176 | device = 'cuda' 177 | else: 178 | device = 'cpu' 179 | 180 | 
demo = gr.Blocks() 181 | 182 | with demo: 183 | 184 | points = gr.State({'target': [], 'handle': []}) 185 | target_point = gr.State(False) 186 | state = gr.State({}) 187 | 188 | gr.Markdown( 189 | """ 190 | # **FreeDrag** 191 | 192 | Official implementation of [FreeDrag: Point Tracking is Not What You Need for Interactive Point-based Image Editing](https://github.com/LPengYang/FreeDrag) 193 | 194 | 195 | ## Parameter Description 196 | **max_step**: the maximum number of optimization steps 197 | 198 | **sample_interval**: the interval between sampled optimization steps. 199 | This parameter only affects the visualization of intermediate results and does not have any impact on the final outcome. 200 | For high-resolution images (such as the dog model), a larger sample_interval can significantly accelerate the dragging process. 201 | 202 | **Expected initial loss and Max distance**: In the current version, both of these values are empirically set for each model. 203 | Generally, for precise editing needs (e.g., merging eyes), smaller values are recommended, which may cause longer processing times. 204 | Users can set these values according to practical editing requirements. We are currently seeking an automated solution. 205 | 206 | **frame_rate**: the frame rate of the saved video. 207 | 208 | ## Hints 209 | - Handle points (Blue): the points you want to drag. 210 | - Target points (Red): the destinations you want to drag towards. 211 | - **Localized points (Yellow)**: the localized points in each sub-motion. 212 | """, 213 | ) 214 | 215 | with gr.Row(): 216 | with gr.Column(scale=0.4): 217 | with gr.Accordion("Model"): 218 | with gr.Row(): 219 | with gr.Column(min_width=100): 220 | seed = gr.Number(label='Seed',value=0) 221 | with gr.Column(min_width=100): 222 | button_new = gr.Button('Image Generate', variant='primary') 223 | button_rand = gr.Button('Rand Generate') 224 | model_name = gr.Dropdown(label="Model name",choices=list(CKPT_SIZE.keys()),value = list(CKPT_SIZE.keys())[0]) 225 | 226 | with gr.Accordion('Optional Parameters'): 227 | with gr.Row(): 228 | with gr.Column(min_width=100): 229 | max_step = gr.Number(label='Max step',value=2000) 230 | with gr.Column(min_width=100): 231 | sample_interval = gr.Number(label='Interval',value=5,info="Sampling interval") 232 | 233 | model_load, _ = load_model(model_name.value, device) 234 | model = gr.State(ModelWrapper(model_load,model_name.value)) 235 | l_expected = gr.Slider(0.1,0.5,label='Expected initial loss for each sub-motion',value = model.value.l,step=0.05) 236 | d_max= gr.Slider(1.0,6.0,label='Max distance for each sub-motion (in the feature map)',value = model.value.d,step=0.5) 237 | 238 | res = gr.State(model.value.res) 239 | z1 = torch.from_numpy(np.random.RandomState(int(seed.value)).randn(1, model.value.g.z_dim)).to(device) 240 | label = torch.zeros([1, model.value.g.c_dim], device=device) 241 | ws_original= model.value.g.get_ws(z1,label,truncation_psi=0.7) 242 | latent = gr.State(ws_original) 243 | add_watermark = gr.State(torch.zeros(1,device=device)) 244 | 245 | _, img_show_original = model.value.g.synthesis(ws=ws_original,noise_mode='const') 246 | 247 | with gr.Accordion('Video'): 248 | images_total = gr.State([]) 249 | with gr.Row(): 250 | with gr.Column(min_width=100): 251 | if_save_video = gr.Radio(["True","False"],value="False",label="if save video") 252 | with gr.Column(min_width=100): 253 | frame_rate = gr.Number(label="Frame rate",value=5) 254 | with gr.Row(): 255 | with gr.Column(min_width=100): 256 | button_video = gr.Button('Save 
video', variant='primary') 257 | 258 | with gr.Accordion('Drag'): 259 | 260 | with gr.Row(): 261 | with gr.Column(min_width=200): 262 | reset_btn = gr.Button('Reset points and mask') 263 | with gr.Row(): 264 | with gr.Column(min_width=100): 265 | button_drag = gr.Button('Drag it', variant='primary') 266 | with gr.Column(min_width=100): 267 | button_stop = gr.Button('Stop') 268 | 269 | progress = gr.Number(value=0, label='Steps', interactive=False) 270 | 271 | with gr.Column(scale=0.53): 272 | with gr.Tabs() as Tabs: 273 | image_show = to_image(img_show_original) 274 | image_clear = gr.State(image_show) 275 | mask = gr.State(np.zeros_like(image_clear.value)) 276 | with gr.Tab('Setup Handle Points', id='input') as imagetab: 277 | image = gr.Image(image_show).style(height=768, width=768) 278 | with gr.Tab('Draw a Mask', id='mask') as masktab: 279 | mask_show = gr.ImageMask(image_show).style(height=768, width=768) 280 | 281 | image.select(on_click, [image, target_point, points, res], [image, target_point]).then(on_show_save) 282 | 283 | image.upload(image_inversion,[image,res,model],[latent,image,image_clear,add_watermark]).then(reset_all, 284 | inputs=[image_clear,mask,add_watermark],outputs=[points,target_point,image,mask_show,mask,add_watermark]) 285 | 286 | button_drag.click(on_drag, inputs=[model, points, mask, max_step,latent,sample_interval,l_expected,d_max,if_save_video,add_watermark], \ 287 | outputs=[image, progress, latent, image_clear, images_total, button_stop]) 288 | button_stop.click(change_stop_state) 289 | 290 | button_video.click(save_video,inputs=[images_total,frame_rate],outputs=[images_total]) 291 | reset_btn.click(reset_all,inputs=[image_clear,mask,add_watermark],outputs= [points,target_point,image,mask_show,mask,add_watermark]).then(on_show_save) 292 | 293 | button_new.click(new_image, inputs = [model,seed],outputs = [image, image_clear, latent,seed]).then(reset_all, 294 | inputs=[image_clear,mask],outputs=[points,target_point,image,mask_show,mask,add_watermark]) 295 | 296 | button_rand.click(new_image, inputs = [model],outputs = [image, image_clear, latent,seed]).then(reset_all, 297 | inputs=[image_clear,mask],outputs=[points,target_point,image,mask_show,mask,add_watermark]) 298 | 299 | model_name.change(new_model,inputs=[model_name],outputs=[model,res,l_expected,d_max]).then \ 300 | (new_image, inputs = [model,seed],outputs = [image, image_clear, latent,seed]).then \ 301 | (reset_all,inputs=[image_clear,mask],outputs=[points,target_point,image,mask_show,mask,add_watermark]) 302 | 303 | imagetab.select(update_mask,[image,mask_show],[image,mask]) 304 | masktab.select(on_select_mask_tab, inputs=[image], outputs=[mask_show]) 305 | 306 | 307 | if __name__ == "__main__": 308 | 309 | demo.queue(concurrency_count=3,max_size=20).launch(share=True) 310 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 | # FreeDrag: Feature Dragging for Reliable Point-based Image Editing 6 |

7 | 8 | 9 |

10 | 11 | ## Visualization 12 | [![]](https://user-images.githubusercontent.com/58554846/253733958-c97629a0-5928-476b-99f2-79d5f92762e7.mp4) 13 | 14 | 15 | ## Web Demo (online drag editing with 11 different StyleGAN2 models) 16 | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/LPengYang/FreeDrag) 17 | 18 | Official implementation of **FreeDrag: Feature Dragging for Reliable Point-based Image Editing**. 19 | - *Authors*: Pengyang Ling*, [Lin Chen*](https://lin-chen.site), [Pan Zhang](https://panzhang0212.github.io/), Huaian Chen, Yi Jin, Jinjin Zheng 20 | - *Institutes*: University of Science and Technology of China; Shanghai AI Laboratory 21 | - [[Paper]](https://arxiv.org/abs/2307.04684) [[Project Page]](https://lin-chen.site/projects/freedrag) [[Web Demo]](https://openxlab.org.cn/apps/detail/LPengYang/FreeDrag) 22 | 23 | This repo proposes FreeDrag, a novel interactive point-based image editing framework that is free of the laborious and unstable point tracking process🔥🔥🔥. 24 | 25 | 26 | ## Abstract 27 | To serve the intricate and varied demands of image editing, precise and flexible manipulation of image content is indispensable. Drag-based editing methods have recently achieved impressive performance. However, these methods predominantly center on point dragging, which leads to two noteworthy drawbacks: "miss tracking", where it is difficult to accurately track the predetermined handle points, and "ambiguous tracking", where the tracked points may end up in wrong regions that closely resemble the handle points. To address these issues, we propose **FreeDrag**, a feature dragging methodology designed to free editing from the burden of point tracking. **FreeDrag** incorporates two key designs, i.e., template features via adaptive updating and line search with backtracking. The former improves stability against drastic content changes by elaborately controlling the feature-updating scale after each dragging step, while the latter alleviates misguidance from similar points by actively restricting the search area to a line. Together, these two designs contribute to more stable semantic dragging with higher efficiency. Comprehensive experimental results substantiate that our approach significantly outperforms pre-existing methods, offering reliable point-based editing even in various complex scenarios. 28 | 29 | 
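For intuition, the sketch below illustrates one simplified sub-motion combining the two designs described above. It is only a schematic: the helper `feat_of`, the weighting coefficient `lam`, and the step rule are illustrative placeholders, and the backtracking criterion of the full method is omitted.

```python
import numpy as np

def sub_motion_sketch(feat_of, template_feat, handle, target, lam, d_max):
    """One simplified FreeDrag-style sub-motion (illustrative placeholders, not the exact formulation).

    feat_of: callable mapping a 2-D point to its feature vector (np.ndarray).
    """
    # Line search: restrict the next localized point to the straight line from the
    # current handle point towards the target, moving at most d_max per sub-motion.
    handle, target = np.asarray(handle, float), np.asarray(target, float)
    offset = target - handle
    dist = float(np.linalg.norm(offset))
    localized = handle if dist == 0 else handle + offset / dist * min(d_max, dist)

    # Feature dragging: measure how far the feature at the localized point is from
    # the template feature, instead of re-detecting where the handle point moved to.
    loss = float(np.linalg.norm(feat_of(localized) - template_feat))

    # Adaptive template update: blend the newly reached feature into the template,
    # with the scale lam controlling how fast the template follows content changes.
    new_template = lam * feat_of(localized) + (1.0 - lam) * template_feat
    return localized, new_template, loss
```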

30 | 31 | 32 |

33 | 34 | ## 📜 News 35 | [2024/03/06] [FreeDrag](https://arxiv.org/abs/2307.04684) is accepted by CVPR 2024. 36 | 37 | [2023/12/11] The updated [FreeDrag](https://arxiv.org/abs/2307.04684), containing implementations for both StyleGAN and diffusion models, is available now. 38 | 39 | [2023/12/8] The diffusion-based FreeDrag is available now, which supports drag editing on both real and generated images. 40 | 41 | [2023/7/31] The web demo (StyleGAN) on [OpenXLab](https://openxlab.org.cn/apps/detail/LPengYang/FreeDrag) is available now. 42 | 43 | [2023/7/28] Real-image editing is available now. 44 | 45 | [2023/7/15] Code of the local demo is available now!💥 46 | 47 | [2023/7/11] The [paper](https://arxiv.org/abs/2307.04684) and [project page](https://lin-chen.site/projects/freedrag) are released! 48 | 49 | ## 💡 Highlights 50 | - [x] Local demo of FreeDrag 51 | - [x] Web demo of FreeDrag 52 | - [x] Diffusion-based FreeDrag 53 | 54 | ## 🛠️Usage 55 | 56 | First, clone our repository: 57 | ``` 58 | git clone --depth=1 https://github.com/LPengYang/FreeDrag 59 | ``` 60 | To create a new environment, please follow the requirements of [NVlabs/stylegan2-ada](https://github.com/NVlabs/stylegan2-ada-pytorch#requirements). 61 | 62 | **Notice:** Errors such as 'Setting up PyTorch plugin "bias_act_plugin"... Failed' or '"upfirdn2d_plugin"... Failed' may appear on some devices. These potential solutions ([1](https://blog.csdn.net/qq_15969343/article/details/129190607), [2](https://github.com/NVlabs/stylegan2-ada-pytorch/issues/155), [3](https://github.com/NVlabs/stylegan3/issues/124), [4](https://github.com/XingangPan/DragGAN/issues/106)) may be helpful in that case. 63 | 64 | Then install the additional requirements: 65 | 66 | ``` 67 | pip install -r requirements.txt 68 | ``` 69 | 70 | Then download the pre-trained StyleGAN2 models: 71 | ``` 72 | bash download_models.sh 73 | ``` 74 | **Notice:** The first model (the face model) may download very slowly in some cases. If this happens, restart the download (it sometimes works) or download the model directly from this [link](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/research/models/stylegan2/files). Make sure to pick the correct model (ffhq-512×512), rename it to "faces.pkl", and manually place it in the created checkpoints directory (after all the other models have been downloaded). 75 | 76 | Finally, launch the Gradio interface for interactive point-based manipulation: 77 | 78 | ``` 79 | CUDA_LAUNCH_BLOCKING=1 python FreeDrag_gradio.py 80 | ``` 81 | You can also upload your own images and edit them. For high-quality image inversion, make sure the resolution and style (e.g., layout) of the uploaded images are consistent with the images generated by the corresponding model (see the resizing sketch below). The resolution of each model is listed as follows: 82 | 83 | |Model|face|horse|elephant|lion|dog|bicycle|giraffe|cat|car|church|metface| 84 | |:----:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| 85 | |Resolution|512|256|512|512|1024|256|512|512|512|256|1024| 86 | 87 | The proposed **FreeDragBench Dataset** is available on the [website](https://drive.google.com/file/d/1p2muR6aW6fqEGW8yTcHl86DCuUkwgtNY/view?usp=sharing). 
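If an uploaded image does not already match the corresponding resolution in the table above, a minimal preprocessing sketch is shown below (it assumes Pillow is installed; the helper name and file paths are illustrative and not part of this repo):

```python
from PIL import Image

def center_crop_resize(path, resolution, out_path):
    # Center-crop to a square, then resize to the model resolution (e.g., 512 for the face model).
    img = Image.open(path).convert("RGB")
    side = min(img.size)
    left = (img.width - side) // 2
    top = (img.height - side) // 2
    img = img.crop((left, top, left + side, top + side))
    img.resize((resolution, resolution), Image.LANCZOS).save(out_path)

# Example: prepare an upload for the 512x512 face model.
center_crop_resize("my_photo.jpg", 512, "my_photo_512.png")
```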
88 | 89 | ## ❤️Acknowledgments 90 | - [DragGAN](https://github.com/XingangPan/DragGAN/) 91 | - [DragDiffusion](https://yujun-shi.github.io/projects/dragdiffusion.html) 92 | - [StyleGAN2](https://github.com/NVlabs/stylegan2-ada-pytorch) 93 | 94 | ## License 95 | All codes used or modified from [StyleGAN2](https://github.com/NVlabs/stylegan2-ada-pytorch) are under the [Nvidia Source Code License](https://github.com/NVlabs/stylegan3/blob/main/LICENSE.txt). 96 | The code related to the FreeDrag algorithm is only allowed for personal activity. The diffusion-based FreeDrag is implemented based on [DragDiffusion](https://yujun-shi.github.io/projects/dragdiffusion.html). 97 | 98 | ## ✒️ Citation 99 | If you find our work helpful for your research, please consider citing the following BibTeX entry. 100 | ```bibtex 101 | @inproceedings{ling2024freedrag, 102 | title={Freedrag: Feature dragging for reliable point-based image editing}, 103 | author={Ling, Pengyang and Chen, Lin and Zhang, Pan and Chen, Huaian and Jin, Yi and Zheng, Jinjin}, 104 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 105 | pages={6860--6870}, 106 | year={2024} 107 | } 108 | ``` 109 | -------------------------------------------------------------------------------- /arial.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/arial.ttf -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /download_models.sh: -------------------------------------------------------------------------------- 1 | mkdir checkpoints 2 | cd checkpoints 3 | rm * 4 | wget 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-512x512.pkl' 5 | mv stylegan2-ffhq-512x512.pkl faces.pkl 6 | curl -o cats.pkl https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqcat.pkl 7 | curl -o lions.pkl https://storage.googleapis.com/self-distilled-stylegan/lions_512_pytorch.pkl 8 | curl -o dogs.pkl https://storage.googleapis.com/self-distilled-stylegan/dogs_1024_pytorch.pkl 9 | curl -o horses.pkl https://storage.googleapis.com/self-distilled-stylegan/horses_256_pytorch.pkl 10 | curl -o elephants.pkl https://storage.googleapis.com/self-distilled-stylegan/elephants_512_pytorch.pkl 11 | curl -o bicycles.pkl https://storage.googleapis.com/self-distilled-stylegan/bicycles_256_pytorch.pkl 12 | curl -o giraffes.pkl https://storage.googleapis.com/self-distilled-stylegan/giraffes_512_pytorch.pkl 13 | curl -o cars.pkl http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-f.pkl 14 | curl -o churches.pkl http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-church-config-f.pkl 15 | curl -o metfaces.pkl https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl 16 | -------------------------------------------------------------------------------- /legacy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import click 10 | import pickle 11 | import re 12 | import copy 13 | import numpy as np 14 | import torch 15 | import dnnlib 16 | from torch_utils import misc 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def load_network_pkl(f, force_fp16=False): 21 | data = _LegacyUnpickler(f).load() 22 | 23 | # Legacy TensorFlow pickle => convert. 24 | if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data): 25 | tf_G, tf_D, tf_Gs = data 26 | G = convert_tf_generator(tf_G) 27 | D = convert_tf_discriminator(tf_D) 28 | G_ema = convert_tf_generator(tf_Gs) 29 | data = dict(G=G, D=D, G_ema=G_ema) 30 | 31 | # Add missing fields. 32 | if 'training_set_kwargs' not in data: 33 | data['training_set_kwargs'] = None 34 | if 'augment_pipe' not in data: 35 | data['augment_pipe'] = None 36 | 37 | # Validate contents. 38 | assert isinstance(data['G'], torch.nn.Module) 39 | assert isinstance(data['D'], torch.nn.Module) 40 | assert isinstance(data['G_ema'], torch.nn.Module) 41 | assert isinstance(data['training_set_kwargs'], (dict, type(None))) 42 | assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None))) 43 | 44 | # Force FP16. 
45 | if force_fp16: 46 | for key in ['G', 'D', 'G_ema']: 47 | old = data[key] 48 | kwargs = copy.deepcopy(old.init_kwargs) 49 | if key.startswith('G'): 50 | kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {})) 51 | kwargs.synthesis_kwargs.num_fp16_res = 4 52 | kwargs.synthesis_kwargs.conv_clamp = 256 53 | if key.startswith('D'): 54 | kwargs.num_fp16_res = 4 55 | kwargs.conv_clamp = 256 56 | if kwargs != old.init_kwargs: 57 | new = type(old)(**kwargs).eval().requires_grad_(False) 58 | misc.copy_params_and_buffers(old, new, require_all=True) 59 | data[key] = new 60 | return data 61 | 62 | #---------------------------------------------------------------------------- 63 | 64 | class _TFNetworkStub(dnnlib.EasyDict): 65 | pass 66 | 67 | class _LegacyUnpickler(pickle.Unpickler): 68 | def find_class(self, module, name): 69 | if module == 'dnnlib.tflib.network' and name == 'Network': 70 | return _TFNetworkStub 71 | return super().find_class(module, name) 72 | 73 | #---------------------------------------------------------------------------- 74 | 75 | def _collect_tf_params(tf_net): 76 | # pylint: disable=protected-access 77 | tf_params = dict() 78 | def recurse(prefix, tf_net): 79 | for name, value in tf_net.variables: 80 | tf_params[prefix + name] = value 81 | for name, comp in tf_net.components.items(): 82 | recurse(prefix + name + '/', comp) 83 | recurse('', tf_net) 84 | return tf_params 85 | 86 | #---------------------------------------------------------------------------- 87 | 88 | def _populate_module_params(module, *patterns): 89 | for name, tensor in misc.named_params_and_buffers(module): 90 | found = False 91 | value = None 92 | for pattern, value_fn in zip(patterns[0::2], patterns[1::2]): 93 | match = re.fullmatch(pattern, name) 94 | if match: 95 | found = True 96 | if value_fn is not None: 97 | value = value_fn(*match.groups()) 98 | break 99 | try: 100 | assert found 101 | if value is not None: 102 | tensor.copy_(torch.from_numpy(np.array(value))) 103 | except: 104 | print(name, list(tensor.shape)) 105 | raise 106 | 107 | #---------------------------------------------------------------------------- 108 | 109 | def convert_tf_generator(tf_G): 110 | if tf_G.version < 4: 111 | raise ValueError('TensorFlow pickle version too low') 112 | 113 | # Collect kwargs. 114 | tf_kwargs = tf_G.static_kwargs 115 | known_kwargs = set() 116 | def kwarg(tf_name, default=None, none=None): 117 | known_kwargs.add(tf_name) 118 | val = tf_kwargs.get(tf_name, default) 119 | return val if val is not None else none 120 | 121 | # Convert kwargs. 
122 | kwargs = dnnlib.EasyDict( 123 | z_dim = kwarg('latent_size', 512), 124 | c_dim = kwarg('label_size', 0), 125 | w_dim = kwarg('dlatent_size', 512), 126 | img_resolution = kwarg('resolution', 1024), 127 | img_channels = kwarg('num_channels', 3), 128 | mapping_kwargs = dnnlib.EasyDict( 129 | num_layers = kwarg('mapping_layers', 8), 130 | embed_features = kwarg('label_fmaps', None), 131 | layer_features = kwarg('mapping_fmaps', None), 132 | activation = kwarg('mapping_nonlinearity', 'lrelu'), 133 | lr_multiplier = kwarg('mapping_lrmul', 0.01), 134 | w_avg_beta = kwarg('w_avg_beta', 0.995, none=1), 135 | ), 136 | synthesis_kwargs = dnnlib.EasyDict( 137 | channel_base = kwarg('fmap_base', 16384) * 2, 138 | channel_max = kwarg('fmap_max', 512), 139 | num_fp16_res = kwarg('num_fp16_res', 0), 140 | conv_clamp = kwarg('conv_clamp', None), 141 | architecture = kwarg('architecture', 'skip'), 142 | resample_filter = kwarg('resample_kernel', [1,3,3,1]), 143 | use_noise = kwarg('use_noise', True), 144 | activation = kwarg('nonlinearity', 'lrelu'), 145 | ), 146 | ) 147 | 148 | # Check for unknown kwargs. 149 | kwarg('truncation_psi') 150 | kwarg('truncation_cutoff') 151 | kwarg('style_mixing_prob') 152 | kwarg('structure') 153 | unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs) 154 | if len(unknown_kwargs) > 0: 155 | raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0]) 156 | 157 | # Collect params. 158 | tf_params = _collect_tf_params(tf_G) 159 | for name, value in list(tf_params.items()): 160 | match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name) 161 | if match: 162 | r = kwargs.img_resolution // (2 ** int(match.group(1))) 163 | tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value 164 | kwargs.synthesis.kwargs.architecture = 'orig' 165 | #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}') 166 | 167 | # Convert params. 
168 | from training import networks 169 | G = networks.Generator(**kwargs).eval().requires_grad_(False) 170 | # pylint: disable=unnecessary-lambda 171 | _populate_module_params(G, 172 | r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'], 173 | r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(), 174 | r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'], 175 | r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(), 176 | r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'], 177 | r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0], 178 | r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1), 179 | r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'], 180 | r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0], 181 | r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'], 182 | r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(), 183 | r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1, 184 | r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1), 185 | r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'], 186 | r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0], 187 | r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'], 188 | r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(), 189 | r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1, 190 | r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1), 191 | r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'], 192 | r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0], 193 | r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'], 194 | r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(), 195 | r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1, 196 | r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1), 197 | r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'], 198 | r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(), 199 | r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1, 200 | r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1), 201 | r'.*\.resample_filter', None, 202 | ) 203 | return G 204 | 205 | #---------------------------------------------------------------------------- 206 | 207 | def convert_tf_discriminator(tf_D): 208 | if tf_D.version < 4: 209 | raise ValueError('TensorFlow pickle version too low') 210 | 211 | # 
Collect kwargs. 212 | tf_kwargs = tf_D.static_kwargs 213 | known_kwargs = set() 214 | def kwarg(tf_name, default=None): 215 | known_kwargs.add(tf_name) 216 | return tf_kwargs.get(tf_name, default) 217 | 218 | # Convert kwargs. 219 | kwargs = dnnlib.EasyDict( 220 | c_dim = kwarg('label_size', 0), 221 | img_resolution = kwarg('resolution', 1024), 222 | img_channels = kwarg('num_channels', 3), 223 | architecture = kwarg('architecture', 'resnet'), 224 | channel_base = kwarg('fmap_base', 16384) * 2, 225 | channel_max = kwarg('fmap_max', 512), 226 | num_fp16_res = kwarg('num_fp16_res', 0), 227 | conv_clamp = kwarg('conv_clamp', None), 228 | cmap_dim = kwarg('mapping_fmaps', None), 229 | block_kwargs = dnnlib.EasyDict( 230 | activation = kwarg('nonlinearity', 'lrelu'), 231 | resample_filter = kwarg('resample_kernel', [1,3,3,1]), 232 | freeze_layers = kwarg('freeze_layers', 0), 233 | ), 234 | mapping_kwargs = dnnlib.EasyDict( 235 | num_layers = kwarg('mapping_layers', 0), 236 | embed_features = kwarg('mapping_fmaps', None), 237 | layer_features = kwarg('mapping_fmaps', None), 238 | activation = kwarg('nonlinearity', 'lrelu'), 239 | lr_multiplier = kwarg('mapping_lrmul', 0.1), 240 | ), 241 | epilogue_kwargs = dnnlib.EasyDict( 242 | mbstd_group_size = kwarg('mbstd_group_size', None), 243 | mbstd_num_channels = kwarg('mbstd_num_features', 1), 244 | activation = kwarg('nonlinearity', 'lrelu'), 245 | ), 246 | ) 247 | 248 | # Check for unknown kwargs. 249 | kwarg('structure') 250 | unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs) 251 | if len(unknown_kwargs) > 0: 252 | raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0]) 253 | 254 | # Collect params. 255 | tf_params = _collect_tf_params(tf_D) 256 | for name, value in list(tf_params.items()): 257 | match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name) 258 | if match: 259 | r = kwargs.img_resolution // (2 ** int(match.group(1))) 260 | tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value 261 | kwargs.architecture = 'orig' 262 | #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}') 263 | 264 | # Convert params. 
265 | from training import networks 266 | D = networks.Discriminator(**kwargs).eval().requires_grad_(False) 267 | # pylint: disable=unnecessary-lambda 268 | _populate_module_params(D, 269 | r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1), 270 | r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'], 271 | r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1), 272 | r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'], 273 | r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1), 274 | r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(), 275 | r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'], 276 | r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(), 277 | r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'], 278 | r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1), 279 | r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'], 280 | r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(), 281 | r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'], 282 | r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(), 283 | r'b4\.out\.bias', lambda: tf_params[f'Output/bias'], 284 | r'.*\.resample_filter', None, 285 | ) 286 | return D 287 | 288 | #---------------------------------------------------------------------------- 289 | 290 | @click.command() 291 | @click.option('--source', help='Input pickle', required=True, metavar='PATH') 292 | @click.option('--dest', help='Output pickle', required=True, metavar='PATH') 293 | @click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True) 294 | def convert_network_pickle(source, dest, force_fp16): 295 | """Convert legacy network pickle into the native PyTorch format. 296 | 297 | The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA. 298 | It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks. 
299 | 300 | Example: 301 | 302 | \b 303 | python legacy.py \\ 304 | --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\ 305 | --dest=stylegan2-cat-config-f.pkl 306 | """ 307 | print(f'Loading "{source}"...') 308 | with dnnlib.util.open_url(source) as f: 309 | data = load_network_pkl(f, force_fp16=force_fp16) 310 | print(f'Saving "{dest}"...') 311 | with open(dest, 'wb') as f: 312 | pickle.dump(data, f) 313 | print('Done.') 314 | 315 | #---------------------------------------------------------------------------- 316 | 317 | if __name__ == "__main__": 318 | convert_network_pickle() # pylint: disable=no-value-for-parameter 319 | 320 | #---------------------------------------------------------------------------- 321 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gradio>=3.28.1 2 | Pydantic 3 | opencv-python 4 | scipy>=1.7.3 5 | ninja==1.11.1 6 | lpips 7 | 8 | -------------------------------------------------------------------------------- /resources/Teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/resources/Teaser.png -------------------------------------------------------------------------------- /resources/comparison_diffusion_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/resources/comparison_diffusion_1.png -------------------------------------------------------------------------------- /resources/comparison_diffusion_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/resources/comparison_diffusion_2.png -------------------------------------------------------------------------------- /resources/comparison_gan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/resources/comparison_gan.png -------------------------------------------------------------------------------- /resources/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/resources/fig1.png -------------------------------------------------------------------------------- /resources/logo2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/resources/logo2.png -------------------------------------------------------------------------------- /resources/style.css: -------------------------------------------------------------------------------- 1 | body { 2 | margin: 0 auto; 3 | font-family: "HelveticaNeue-Light", "Helvetica Neue Light", "Helvetica Neue", Helvetica, Arial, "Lucida Grande", sans-serif; 4 | font-size: 18px; 5 | } 6 | 7 | .container { 8 | margin: 0 auto; 9 | width: 100%; 10 | max-width: 1100px; 11 | text-align: center; 12 | display: block; 13 | } 14 | 15 | .title { 16 | font-size: 36px; 17 | margin-top: 20px; 18 | width: 90%; 19 | } 20 | 21 | .venue { 22 | font-size: 
22px; 23 | margin-top: 20px; 24 | width: 90%; 25 | } 26 | 27 | .author { 28 | width: 95%; 29 | max-width: 300px; 30 | font-size: 20px; 31 | } 32 | 33 | .affiliation { 34 | font-size: 20px; 35 | width: 95%; 36 | max-width: 450px; 37 | } 38 | 39 | .links { 40 | font-size: 22px; 41 | width: 95%; 42 | max-width: 150px; 43 | } 44 | 45 | .video-container { 46 | position: relative; 47 | overflow: hidden; 48 | width: 80%; 49 | padding-top: 45%; 50 | /* Formula for 16:9 Aspect Ratio: width * 9 / 16 */ 51 | } 52 | 53 | .video-container iframe { 54 | position: absolute; 55 | top: 0; 56 | left: 0; 57 | bottom: 0; 58 | right: 0; 59 | width: 100%; 60 | height: 100%; 61 | } 62 | 63 | .paper-thumbnail { 64 | margin: 0 auto; 65 | width: 40%; 66 | max-width: 250px; 67 | display: inline-block; 68 | vertical-align: top; 69 | padding: 2% 10% 4% 0; 70 | } 71 | 72 | .paper-info { 73 | width: 45%; 74 | display: inline-block; 75 | vertical-align: top; 76 | } 77 | 78 | @media (max-width: 999px) { 79 | .paper-thumbnail { 80 | width: 60%; 81 | } 82 | 83 | .paper-info { 84 | width: 80%; 85 | } 86 | } 87 | 88 | 89 | p { 90 | text-align: left; 91 | margin: 0 auto; 92 | margin-bottom: 10px; 93 | } 94 | 95 | h1 { 96 | font-weight: 300; 97 | text-align: center; 98 | } 99 | 100 | h2 { 101 | text-align: center; 102 | } 103 | 104 | h3 { 105 | text-align: left; 106 | } 107 | 108 | h4 { 109 | text-align: left; 110 | } 111 | 112 | h5 { 113 | text-align: left; 114 | } 115 | 116 | div { 117 | display: inline-block; 118 | } 119 | 120 | hr { 121 | border: 0; 122 | height: 1px; 123 | background-image: linear-gradient(to right, rgba(0, 0, 0, 0), rgba(0, 0, 0, 0.75), rgba(0, 0, 0, 0)); 124 | width: 90%; 125 | } 126 | 127 | pre { 128 | overflow-x: auto; 129 | text-align: left; 130 | border: 1px solid grey; 131 | border-radius: 3px; 132 | background: #eeeeee; 133 | padding: 5px 5px 5px 10px; 134 | line-height: 1.2; 135 | white-space: pre-wrap; 136 | } 137 | 138 | pre code { 139 | text-align: left; 140 | word-wrap: normal; 141 | white-space: pre-wrap; 142 | font-size: 14px; 143 | } 144 | 145 | a:link, 146 | a:visited { 147 | color: #1367a7; 148 | text-decoration: none; 149 | } 150 | 151 | a:hover { 152 | color: #208799; 153 | } 154 | 155 | .layered-paper-big { 156 | /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */ 157 | box-shadow: 158 | 0px 0px 1px 1px rgba(0, 0, 0, 0.35), 159 | /* The top layer shadow */ 160 | 5px 5px 0 0px #fff, 161 | /* The second layer */ 162 | 5px 5px 1px 1px rgba(0, 0, 0, 0.35), 163 | /* The second layer shadow */ 164 | 10px 10px 0 0px #fff, 165 | /* The third layer */ 166 | 10px 10px 1px 1px rgba(0, 0, 0, 0.35), 167 | /* The third layer shadow */ 168 | 15px 15px 0 0px #fff, 169 | /* The fourth layer */ 170 | 15px 15px 1px 1px rgba(0, 0, 0, 0.35), 171 | /* The fourth layer shadow */ 172 | 20px 20px 0 0px #fff, 173 | /* The fifth layer */ 174 | 20px 20px 1px 1px rgba(0, 0, 0, 0.35), 175 | /* The fifth layer shadow */ 176 | 25px 25px 0 0px #fff, 177 | /* The fifth layer */ 178 | 25px 25px 1px 1px rgba(0, 0, 0, 0.35); 179 | /* The fifth layer shadow */ 180 | margin-left: 10px; 181 | margin-right: 45px; 182 | } -------------------------------------------------------------------------------- /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import glob 11 | import torch 12 | import torch.utils.cpp_extension 13 | import importlib 14 | import hashlib 15 | import shutil 16 | from pathlib import Path 17 | 18 | from torch.utils.file_baton import FileBaton 19 | 20 | #---------------------------------------------------------------------------- 21 | # Global options. 22 | 23 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 24 | 25 | #---------------------------------------------------------------------------- 26 | # Internal helper funcs. 27 | 28 | def _find_compiler_bindir(): 29 | patterns = [ 30 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 33 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', 34 | ] 35 | for pattern in patterns: 36 | matches = sorted(glob.glob(pattern)) 37 | if len(matches): 38 | return matches[-1] 39 | return None 40 | 41 | #---------------------------------------------------------------------------- 42 | # Main entry point for compiling and loading C++/CUDA plugins. 43 | 44 | _cached_plugins = dict() 45 | 46 | def get_plugin(module_name, sources, **build_kwargs): 47 | assert verbosity in ['none', 'brief', 'full'] 48 | 49 | # Already cached? 50 | if module_name in _cached_plugins: 51 | return _cached_plugins[module_name] 52 | 53 | # Print status. 54 | if verbosity == 'full': 55 | print(f'Setting up PyTorch plugin "{module_name}"...') 56 | elif verbosity == 'brief': 57 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) 58 | 59 | try: # pylint: disable=too-many-nested-blocks 60 | # Make sure we can find the necessary compiler binaries. 61 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 62 | compiler_bindir = _find_compiler_bindir() 63 | if compiler_bindir is None: 64 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') 65 | os.environ['PATH'] += ';' + compiler_bindir 66 | 67 | # Compile and load. 68 | verbose_build = (verbosity == 'full') 69 | 70 | # Incremental build md5sum trickery. Copies all the input source files 71 | # into a cached build directory under a combined md5 digest of the input 72 | # source files. 
Copying is done only if the combined digest has changed. 73 | # This keeps input file timestamps and filenames the same as in previous 74 | # extension builds, allowing for fast incremental rebuilds. 75 | # 76 | # This optimization is done only in case all the source files reside in 77 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 78 | # environment variable is set (we take this as a signal that the user 79 | # actually cares about this.) 80 | source_dirs_set = set(os.path.dirname(source) for source in sources) 81 | if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): 82 | all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) 83 | 84 | # Compute a combined hash digest for all source files in the same 85 | # custom op directory (usually .cu, .cpp, .py and .h files). 86 | hash_md5 = hashlib.md5() 87 | for src in all_source_files: 88 | with open(src, 'rb') as f: 89 | hash_md5.update(f.read()) 90 | build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access 91 | digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) 92 | 93 | if not os.path.isdir(digest_build_dir): 94 | os.makedirs(digest_build_dir, exist_ok=True) 95 | baton = FileBaton(os.path.join(digest_build_dir, 'lock')) 96 | if baton.try_acquire(): 97 | try: 98 | for src in all_source_files: 99 | shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) 100 | finally: 101 | baton.release() 102 | else: 103 | # Someone else is copying source files under the digest dir, 104 | # wait until done and continue. 105 | baton.wait() 106 | digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] 107 | torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, 108 | verbose=verbose_build, sources=digest_sources, **build_kwargs) 109 | else: 110 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 111 | module = importlib.import_module(module_name) 112 | 113 | except: 114 | if verbosity == 'brief': 115 | print('Failed!') 116 | raise 117 | 118 | # Print status and add to cache. 119 | if verbosity == 'full': 120 | print(f'Done setting up PyTorch plugin "{module_name}".') 121 | elif verbosity == 'brief': 122 | print('Done.') 123 | _cached_plugins[module_name] = module 124 | return module 125 | 126 | #---------------------------------------------------------------------------- 127 | -------------------------------------------------------------------------------- /torch_utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import re 10 | import contextlib 11 | import numpy as np 12 | import torch 13 | import warnings 14 | import dnnlib 15 | 16 | #---------------------------------------------------------------------------- 17 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the 18 | # same constant is used multiple times. 
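# Illustrative example: repeated calls with the same value/shape/dtype/device return the
# cached tensor, e.g. constant([1, 3, 3, 1], device=torch.device('cuda')) only pays the
# CPU=>GPU copy on the first call.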
19 | 20 | _constant_cache = dict() 21 | 22 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 23 | value = np.asarray(value) 24 | if shape is not None: 25 | shape = tuple(shape) 26 | if dtype is None: 27 | dtype = torch.get_default_dtype() 28 | if device is None: 29 | device = torch.device('cpu') 30 | if memory_format is None: 31 | memory_format = torch.contiguous_format 32 | 33 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) 34 | tensor = _constant_cache.get(key, None) 35 | if tensor is None: 36 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 37 | if shape is not None: 38 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 39 | tensor = tensor.contiguous(memory_format=memory_format) 40 | _constant_cache[key] = tensor 41 | return tensor 42 | 43 | #---------------------------------------------------------------------------- 44 | # Replace NaN/Inf with specified numerical values. 45 | 46 | try: 47 | nan_to_num = torch.nan_to_num # 1.8.0a0 48 | except AttributeError: 49 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin 50 | assert isinstance(input, torch.Tensor) 51 | if posinf is None: 52 | posinf = torch.finfo(input.dtype).max 53 | if neginf is None: 54 | neginf = torch.finfo(input.dtype).min 55 | assert nan == 0 56 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 57 | 58 | #---------------------------------------------------------------------------- 59 | # Symbolic assert. 60 | 61 | try: 62 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access 63 | except AttributeError: 64 | symbolic_assert = torch.Assert # 1.7.0 65 | 66 | #---------------------------------------------------------------------------- 67 | # Context manager to suppress known warnings in torch.jit.trace(). 68 | 69 | class suppress_tracer_warnings(warnings.catch_warnings): 70 | def __enter__(self): 71 | super().__enter__() 72 | warnings.simplefilter('ignore', category=torch.jit.TracerWarning) 73 | return self 74 | 75 | #---------------------------------------------------------------------------- 76 | # Assert that the shape of a tensor matches the given list of integers. 77 | # None indicates that the size of a dimension is allowed to vary. 78 | # Performs symbolic assertion when used in torch.jit.trace(). 79 | 80 | def assert_shape(tensor, ref_shape): 81 | if tensor.ndim != len(ref_shape): 82 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') 83 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): 84 | if ref_size is None: 85 | pass 86 | elif isinstance(ref_size, torch.Tensor): 87 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 88 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') 89 | elif isinstance(size, torch.Tensor): 90 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 91 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') 92 | elif size != ref_size: 93 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') 94 | 95 | #---------------------------------------------------------------------------- 96 | # Function decorator that calls torch.autograd.profiler.record_function(). 
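# Illustrative example (the function name below is a placeholder):
#
#   @profiled_function
#   def synthesis_forward(ws):
#       ...
#
# makes each call show up as a named range ("synthesis_forward") in autograd profiler traces.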
97 | 98 | def profiled_function(fn): 99 | def decorator(*args, **kwargs): 100 | with torch.autograd.profiler.record_function(fn.__name__): 101 | return fn(*args, **kwargs) 102 | decorator.__name__ = fn.__name__ 103 | return decorator 104 | 105 | #---------------------------------------------------------------------------- 106 | # Sampler for torch.utils.data.DataLoader that loops over the dataset 107 | # indefinitely, shuffling items as it goes. 108 | 109 | class InfiniteSampler(torch.utils.data.Sampler): 110 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): 111 | assert len(dataset) > 0 112 | assert num_replicas > 0 113 | assert 0 <= rank < num_replicas 114 | assert 0 <= window_size <= 1 115 | super().__init__(dataset) 116 | self.dataset = dataset 117 | self.rank = rank 118 | self.num_replicas = num_replicas 119 | self.shuffle = shuffle 120 | self.seed = seed 121 | self.window_size = window_size 122 | 123 | def __iter__(self): 124 | order = np.arange(len(self.dataset)) 125 | rnd = None 126 | window = 0 127 | if self.shuffle: 128 | rnd = np.random.RandomState(self.seed) 129 | rnd.shuffle(order) 130 | window = int(np.rint(order.size * self.window_size)) 131 | 132 | idx = 0 133 | while True: 134 | i = idx % order.size 135 | if idx % self.num_replicas == self.rank: 136 | yield order[i] 137 | if window >= 2: 138 | j = (i - rnd.randint(window)) % order.size 139 | order[i], order[j] = order[j], order[i] 140 | idx += 1 141 | 142 | #---------------------------------------------------------------------------- 143 | # Utilities for operating with torch.nn.Module parameters and buffers. 144 | 145 | def params_and_buffers(module): 146 | assert isinstance(module, torch.nn.Module) 147 | return list(module.parameters()) + list(module.buffers()) 148 | 149 | def named_params_and_buffers(module): 150 | assert isinstance(module, torch.nn.Module) 151 | return list(module.named_parameters()) + list(module.named_buffers()) 152 | 153 | def copy_params_and_buffers(src_module, dst_module, require_all=False): 154 | assert isinstance(src_module, torch.nn.Module) 155 | assert isinstance(dst_module, torch.nn.Module) 156 | src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)} 157 | for name, tensor in named_params_and_buffers(dst_module): 158 | assert (name in src_tensors) or (not require_all) 159 | if name in src_tensors: 160 | tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) 161 | 162 | #---------------------------------------------------------------------------- 163 | # Context manager for easily enabling/disabling DistributedDataParallel 164 | # synchronization. 165 | 166 | @contextlib.contextmanager 167 | def ddp_sync(module, sync): 168 | assert isinstance(module, torch.nn.Module) 169 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 170 | yield 171 | else: 172 | with module.no_sync(): 173 | yield 174 | 175 | #---------------------------------------------------------------------------- 176 | # Check DistributedDataParallel consistency across processes. 177 | 178 | def check_ddp_consistency(module, ignore_regex=None): 179 | assert isinstance(module, torch.nn.Module) 180 | for name, tensor in named_params_and_buffers(module): 181 | fullname = type(module).__name__ + '.' 
+ name 182 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): 183 | continue 184 | tensor = tensor.detach() 185 | other = tensor.clone() 186 | torch.distributed.broadcast(tensor=other, src=0) 187 | assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname 188 | 189 | #---------------------------------------------------------------------------- 190 | # Print summary table of module hierarchy. 191 | 192 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): 193 | assert isinstance(module, torch.nn.Module) 194 | assert not isinstance(module, torch.jit.ScriptModule) 195 | assert isinstance(inputs, (tuple, list)) 196 | 197 | # Register hooks. 198 | entries = [] 199 | nesting = [0] 200 | def pre_hook(_mod, _inputs): 201 | nesting[0] += 1 202 | def post_hook(mod, _inputs, outputs): 203 | nesting[0] -= 1 204 | if nesting[0] <= max_nesting: 205 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 206 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 207 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) 208 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 209 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 210 | 211 | # Run module. 212 | outputs = module(*inputs) 213 | for hook in hooks: 214 | hook.remove() 215 | 216 | # Identify unique outputs, parameters, and buffers. 217 | tensors_seen = set() 218 | for e in entries: 219 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] 220 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 221 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 222 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} 223 | 224 | # Filter out redundant entries. 225 | if skip_redundant: 226 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] 227 | 228 | # Construct table. 229 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 230 | rows += [['---'] * len(rows[0])] 231 | param_total = 0 232 | buffer_total = 0 233 | submodule_names = {mod: name for name, mod in module.named_modules()} 234 | for e in entries: 235 | name = '' if e.mod is module else submodule_names[e.mod] 236 | param_size = sum(t.numel() for t in e.unique_params) 237 | buffer_size = sum(t.numel() for t in e.unique_buffers) 238 | output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs] 239 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 240 | rows += [[ 241 | name + (':0' if len(e.outputs) >= 2 else ''), 242 | str(param_size) if param_size else '-', 243 | str(buffer_size) if buffer_size else '-', 244 | (output_shapes + ['-'])[0], 245 | (output_dtypes + ['-'])[0], 246 | ]] 247 | for idx in range(1, len(e.outputs)): 248 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 249 | param_total += param_size 250 | buffer_total += buffer_size 251 | rows += [['---'] * len(rows[0])] 252 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 253 | 254 | # Print table. 
255 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 256 | print() 257 | for row in rows: 258 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) 259 | print() 260 | return outputs 261 | 262 | #---------------------------------------------------------------------------- 263 | -------------------------------------------------------------------------------- /torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments. 
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include "bias_act.h" 11 | 12 | //------------------------------------------------------------------------ 13 | // Helpers. 14 | 15 | template struct InternalType; 16 | template <> struct InternalType { typedef double scalar_t; }; 17 | template <> struct InternalType { typedef float scalar_t; }; 18 | template <> struct InternalType { typedef float scalar_t; }; 19 | 20 | //------------------------------------------------------------------------ 21 | // CUDA kernel. 22 | 23 | template 24 | __global__ void bias_act_kernel(bias_act_kernel_params p) 25 | { 26 | typedef typename InternalType::scalar_t scalar_t; 27 | int G = p.grad; 28 | scalar_t alpha = (scalar_t)p.alpha; 29 | scalar_t gain = (scalar_t)p.gain; 30 | scalar_t clamp = (scalar_t)p.clamp; 31 | scalar_t one = (scalar_t)1; 32 | scalar_t two = (scalar_t)2; 33 | scalar_t expRange = (scalar_t)80; 34 | scalar_t halfExpRange = (scalar_t)40; 35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 37 | 38 | // Loop over elements. 39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 41 | { 42 | // Load. 43 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 48 | scalar_t yy = (gain != 0) ? yref / gain : 0; 49 | scalar_t y = 0; 50 | 51 | // Apply bias. 52 | ((G == 0) ? x : xref) += b; 53 | 54 | // linear 55 | if (A == 1) 56 | { 57 | if (G == 0) y = x; 58 | if (G == 1) y = x; 59 | } 60 | 61 | // relu 62 | if (A == 2) 63 | { 64 | if (G == 0) y = (x > 0) ? x : 0; 65 | if (G == 1) y = (yy > 0) ? x : 0; 66 | } 67 | 68 | // lrelu 69 | if (A == 3) 70 | { 71 | if (G == 0) y = (x > 0) ? x : x * alpha; 72 | if (G == 1) y = (yy > 0) ? x : x * alpha; 73 | } 74 | 75 | // tanh 76 | if (A == 4) 77 | { 78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } 79 | if (G == 1) y = x * (one - yy * yy); 80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 81 | } 82 | 83 | // sigmoid 84 | if (A == 5) 85 | { 86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); 87 | if (G == 1) y = x * yy * (one - yy); 88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 89 | } 90 | 91 | // elu 92 | if (A == 6) 93 | { 94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 97 | } 98 | 99 | // selu 100 | if (A == 7) 101 | { 102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); 104 | if (G == 2) y = (yy >= 0) ? 
0 : x * (yy + seluScale * seluAlpha); 105 | } 106 | 107 | // softplus 108 | if (A == 8) 109 | { 110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); 111 | if (G == 1) y = x * (one - exp(-yy)); 112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 113 | } 114 | 115 | // swish 116 | if (A == 9) 117 | { 118 | if (G == 0) 119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 120 | else 121 | { 122 | scalar_t c = exp(xref); 123 | scalar_t d = c + one; 124 | if (G == 1) 125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 126 | else 127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 129 | } 130 | } 131 | 132 | // Apply gain. 133 | y *= gain * dy; 134 | 135 | // Clamp. 136 | if (clamp >= 0) 137 | { 138 | if (G == 0) 139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; 140 | else 141 | y = (yref > -clamp & yref < clamp) ? y : 0; 142 | } 143 | 144 | // Store. 145 | ((T*)p.y)[xi] = (T)y; 146 | } 147 | } 148 | 149 | //------------------------------------------------------------------------ 150 | // CUDA kernel selection. 151 | 152 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p) 153 | { 154 | if (p.act == 1) return (void*)bias_act_kernel; 155 | if (p.act == 2) return (void*)bias_act_kernel; 156 | if (p.act == 3) return (void*)bias_act_kernel; 157 | if (p.act == 4) return (void*)bias_act_kernel; 158 | if (p.act == 5) return (void*)bias_act_kernel; 159 | if (p.act == 6) return (void*)bias_act_kernel; 160 | if (p.act == 7) return (void*)bias_act_kernel; 161 | if (p.act == 8) return (void*)bias_act_kernel; 162 | if (p.act == 9) return (void*)bias_act_kernel; 163 | return NULL; 164 | } 165 | 166 | //------------------------------------------------------------------------ 167 | // Template specializations. 168 | 169 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 170 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 171 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 172 | 173 | //------------------------------------------------------------------------ 174 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 
35 | 36 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom PyTorch ops for efficient bias and activation.""" 10 | 11 | import os 12 | import warnings 13 | import numpy as np 14 | import torch 15 | import dnnlib 16 | import traceback 17 | 18 | from .. import custom_ops 19 | from .. import misc 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | activation_funcs = { 24 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), 25 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), 26 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), 27 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), 28 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), 29 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), 30 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), 31 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), 32 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), 33 | } 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | _inited = False 38 | _plugin = None 39 | _null_tensor = torch.empty([0]) 40 | 41 | def _init(): 42 | global _inited, _plugin 43 | if not _inited: 44 | _inited = True 45 | sources = ['bias_act.cpp', 'bias_act.cu'] 46 | sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] 47 | try: 48 | _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) 49 | except: 50 | warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) 51 | return _plugin is not None 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): 56 | r"""Fused bias and activation function. 
57 | 58 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`, 59 | and scales the result by `gain`. Each of the steps is optional. In most cases, 60 | the fused op is considerably more efficient than performing the same calculation 61 | using standard PyTorch ops. It supports first and second order gradients, 62 | but not third order gradients. 63 | 64 | Args: 65 | x: Input activation tensor. Can be of any shape. 66 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 67 | as `x`. The shape must be known, and it must match the dimension of `x` 68 | corresponding to `dim`. 69 | dim: The dimension in `x` corresponding to the elements of `b`. 70 | The value of `dim` is ignored if `b` is not specified. 71 | act: Name of the activation function to evaluate, or `"linear"` to disable. 72 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. 73 | See `activation_funcs` for a full list. `None` is not allowed. 74 | alpha: Shape parameter for the activation function, or `None` to use the default. 75 | gain: Scaling factor for the output tensor, or `None` to use default. 76 | See `activation_funcs` for the default scaling of each activation function. 77 | If unsure, consider specifying 1. 78 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable 79 | the clamping (default). 80 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 81 | 82 | Returns: 83 | Tensor of the same shape and datatype as `x`. 84 | """ 85 | assert isinstance(x, torch.Tensor) 86 | assert impl in ['ref', 'cuda'] 87 | if impl == 'cuda' and x.device.type == 'cuda' and _init(): 88 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) 89 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) 90 | 91 | #---------------------------------------------------------------------------- 92 | 93 | @misc.profiled_function 94 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): 95 | """Slow reference implementation of `bias_act()` using standard TensorFlow ops. 96 | """ 97 | assert isinstance(x, torch.Tensor) 98 | assert clamp is None or clamp >= 0 99 | spec = activation_funcs[act] 100 | alpha = float(alpha if alpha is not None else spec.def_alpha) 101 | gain = float(gain if gain is not None else spec.def_gain) 102 | clamp = float(clamp if clamp is not None else -1) 103 | 104 | # Add bias. 105 | if b is not None: 106 | assert isinstance(b, torch.Tensor) and b.ndim == 1 107 | assert 0 <= dim < x.ndim 108 | assert b.shape[0] == x.shape[dim] 109 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) 110 | 111 | # Evaluate activation function. 112 | alpha = float(alpha) 113 | x = spec.func(x, alpha=alpha) 114 | 115 | # Scale by gain. 116 | gain = float(gain) 117 | if gain != 1: 118 | x = x * gain 119 | 120 | # Clamp. 121 | if clamp >= 0: 122 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type 123 | return x 124 | 125 | #---------------------------------------------------------------------------- 126 | 127 | _bias_act_cuda_cache = dict() 128 | 129 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): 130 | """Fast CUDA implementation of `bias_act()` using custom ops. 131 | """ 132 | # Parse arguments. 
133 | assert clamp is None or clamp >= 0 134 | spec = activation_funcs[act] 135 | alpha = float(alpha if alpha is not None else spec.def_alpha) 136 | gain = float(gain if gain is not None else spec.def_gain) 137 | clamp = float(clamp if clamp is not None else -1) 138 | 139 | # Lookup from cache. 140 | key = (dim, act, alpha, gain, clamp) 141 | if key in _bias_act_cuda_cache: 142 | return _bias_act_cuda_cache[key] 143 | 144 | # Forward op. 145 | class BiasActCuda(torch.autograd.Function): 146 | @staticmethod 147 | def forward(ctx, x, b): # pylint: disable=arguments-differ 148 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format 149 | x = x.contiguous(memory_format=ctx.memory_format) 150 | b = b.contiguous() if b is not None else _null_tensor 151 | y = x 152 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: 153 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) 154 | ctx.save_for_backward( 155 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 156 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 157 | y if 'y' in spec.ref else _null_tensor) 158 | return y 159 | 160 | @staticmethod 161 | def backward(ctx, dy): # pylint: disable=arguments-differ 162 | dy = dy.contiguous(memory_format=ctx.memory_format) 163 | x, b, y = ctx.saved_tensors 164 | dx = None 165 | db = None 166 | 167 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 168 | dx = dy 169 | if act != 'linear' or gain != 1 or clamp >= 0: 170 | dx = BiasActCudaGrad.apply(dy, x, b, y) 171 | 172 | if ctx.needs_input_grad[1]: 173 | db = dx.sum([i for i in range(dx.ndim) if i != dim]) 174 | 175 | return dx, db 176 | 177 | # Backward op. 178 | class BiasActCudaGrad(torch.autograd.Function): 179 | @staticmethod 180 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ 181 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format 182 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) 183 | ctx.save_for_backward( 184 | dy if spec.has_2nd_grad else _null_tensor, 185 | x, b, y) 186 | return dx 187 | 188 | @staticmethod 189 | def backward(ctx, d_dx): # pylint: disable=arguments-differ 190 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format) 191 | dy, x, b, y = ctx.saved_tensors 192 | d_dy = None 193 | d_x = None 194 | d_b = None 195 | d_y = None 196 | 197 | if ctx.needs_input_grad[0]: 198 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) 199 | 200 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): 201 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) 202 | 203 | if spec.has_2nd_grad and ctx.needs_input_grad[2]: 204 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) 205 | 206 | return d_dy, d_x, d_b, d_y 207 | 208 | # Add to cache. 209 | _bias_act_cuda_cache[key] = BiasActCuda 210 | return BiasActCuda 211 | 212 | #---------------------------------------------------------------------------- 213 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
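A minimal usage sketch of the bias_act() API documented above; the tensor shapes, the 'lrelu' choice, and the clamp value are illustrative assumptions. On CPU, or whenever the CUDA plugin fails to build, the call falls through to the reference path shown in _bias_act_ref(), so the fused result should match the hand-written PyTorch ops:

import numpy as np
import torch
from torch_utils.ops import bias_act

x = torch.randn(4, 8, 16, 16)   # arbitrary activation tensor [N, C, H, W]
b = torch.randn(8)              # one bias element per channel (dim=1)

# Fused bias + leaky ReLU, scaled by the default lrelu gain of sqrt(2) and clamped.
y = bias_act.bias_act(x, b, dim=1, act='lrelu', clamp=256)

# The same computation spelled out with standard PyTorch ops.
y_ref = torch.nn.functional.leaky_relu(x + b.reshape(1, -1, 1, 1), 0.2)
y_ref = (y_ref * np.sqrt(2)).clamp(-256, 256)
assert torch.allclose(y, y_ref, atol=1e-6)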
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.conv2d` that supports 10 | arbitrarily high order gradients with zero performance penalty.""" 11 | 12 | import warnings 13 | import contextlib 14 | import torch 15 | 16 | # pylint: disable=redefined-builtin 17 | # pylint: disable=arguments-differ 18 | # pylint: disable=protected-access 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | enabled = False # Enable the custom op by setting this to true. 23 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 24 | 25 | @contextlib.contextmanager 26 | def no_weight_gradients(): 27 | global weight_gradients_disabled 28 | old = weight_gradients_disabled 29 | weight_gradients_disabled = True 30 | yield 31 | weight_gradients_disabled = old 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 36 | if _should_use_custom_op(input): 37 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) 38 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 39 | 40 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): 41 | if _should_use_custom_op(input): 42 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) 43 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) 44 | 45 | #---------------------------------------------------------------------------- 46 | 47 | def _should_use_custom_op(input): 48 | assert isinstance(input, torch.Tensor) 49 | if (not enabled) or (not torch.backends.cudnn.enabled): 50 | return False 51 | if input.device.type != 'cuda': 52 | return False 53 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): 54 | return True 55 | warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') 56 | return False 57 | 58 | def _tuple_of_ints(xs, ndim): 59 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 60 | assert len(xs) == ndim 61 | assert all(isinstance(x, int) for x in xs) 62 | return xs 63 | 64 | #---------------------------------------------------------------------------- 65 | 66 | _conv2d_gradfix_cache = dict() 67 | 68 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): 69 | # Parse arguments. 
70 | ndim = 2 71 | weight_shape = tuple(weight_shape) 72 | stride = _tuple_of_ints(stride, ndim) 73 | padding = _tuple_of_ints(padding, ndim) 74 | output_padding = _tuple_of_ints(output_padding, ndim) 75 | dilation = _tuple_of_ints(dilation, ndim) 76 | 77 | # Lookup from cache. 78 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 79 | if key in _conv2d_gradfix_cache: 80 | return _conv2d_gradfix_cache[key] 81 | 82 | # Validate arguments. 83 | assert groups >= 1 84 | assert len(weight_shape) == ndim + 2 85 | assert all(stride[i] >= 1 for i in range(ndim)) 86 | assert all(padding[i] >= 0 for i in range(ndim)) 87 | assert all(dilation[i] >= 0 for i in range(ndim)) 88 | if not transpose: 89 | assert all(output_padding[i] == 0 for i in range(ndim)) 90 | else: # transpose 91 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) 92 | 93 | # Helpers. 94 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) 95 | def calc_output_padding(input_shape, output_shape): 96 | if transpose: 97 | return [0, 0] 98 | return [ 99 | input_shape[i + 2] 100 | - (output_shape[i + 2] - 1) * stride[i] 101 | - (1 - 2 * padding[i]) 102 | - dilation[i] * (weight_shape[i + 2] - 1) 103 | for i in range(ndim) 104 | ] 105 | 106 | # Forward & backward. 107 | class Conv2d(torch.autograd.Function): 108 | @staticmethod 109 | def forward(ctx, input, weight, bias): 110 | assert weight.shape == weight_shape 111 | if not transpose: 112 | output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 113 | else: # transpose 114 | output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) 115 | ctx.save_for_backward(input, weight) 116 | return output 117 | 118 | @staticmethod 119 | def backward(ctx, grad_output): 120 | input, weight = ctx.saved_tensors 121 | grad_input = None 122 | grad_weight = None 123 | grad_bias = None 124 | 125 | if ctx.needs_input_grad[0]: 126 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 127 | grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None) 128 | assert grad_input.shape == input.shape 129 | 130 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 131 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 132 | assert grad_weight.shape == weight_shape 133 | 134 | if ctx.needs_input_grad[2]: 135 | grad_bias = grad_output.sum([0, 2, 3]) 136 | 137 | return grad_input, grad_weight, grad_bias 138 | 139 | # Gradient with respect to the weights. 
140 | class Conv2dGradWeight(torch.autograd.Function): 141 | @staticmethod 142 | def forward(ctx, grad_output, input): 143 | op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') 144 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] 145 | grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) 146 | assert grad_weight.shape == weight_shape 147 | ctx.save_for_backward(grad_output, input) 148 | return grad_weight 149 | 150 | @staticmethod 151 | def backward(ctx, grad2_grad_weight): 152 | grad_output, input = ctx.saved_tensors 153 | grad2_grad_output = None 154 | grad2_input = None 155 | 156 | if ctx.needs_input_grad[0]: 157 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) 158 | assert grad2_grad_output.shape == grad_output.shape 159 | 160 | if ctx.needs_input_grad[1]: 161 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 162 | grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) 163 | assert grad2_input.shape == input.shape 164 | 165 | return grad2_grad_output, grad2_input 166 | 167 | _conv2d_gradfix_cache[key] = Conv2d 168 | return Conv2d 169 | 170 | #---------------------------------------------------------------------------- 171 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """2D convolution with optional up/downsampling.""" 10 | 11 | import torch 12 | 13 | from .. import misc 14 | from . import conv2d_gradfix 15 | from . import upfirdn2d 16 | from .upfirdn2d import _parse_padding 17 | from .upfirdn2d import _get_filter_size 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def _get_weight_shape(w): 22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 23 | shape = [int(sz) for sz in w.shape] 24 | misc.assert_shape(w, shape) 25 | return shape 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 31 | """ 32 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 33 | 34 | # Flip weight if requested. 35 | if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 36 | w = w.flip([2, 3]) 37 | 38 | # Workaround performance pitfall in cuDNN 8.0.5, triggered when using 39 | # 1x1 kernel + memory_format=channels_last + less than 64 channels. 
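A short sketch of the conv2d_gradfix wrapper defined above; the tensor shapes and the R1-style gradient penalty are illustrative assumptions. With enabled left at False, or for CPU tensors and unsupported PyTorch versions, the calls simply fall back to torch.nn.functional:

import torch
from torch_utils.ops import conv2d_gradfix

conv2d_gradfix.enabled = True   # opt in; only takes effect for CUDA tensors on PyTorch 1.7-1.9

x = torch.randn(2, 3, 32, 32, requires_grad=True)
w = torch.randn(8, 3, 3, 3, requires_grad=True)
y = conv2d_gradfix.conv2d(x, w, padding=1)

# Differentiate w.r.t. the input only, skipping the cost of the weight gradients
# (the pattern used for gradient penalties such as R1 regularization).
with conv2d_gradfix.no_weight_gradients():
    grad_x = torch.autograd.grad(y.sum(), [x], create_graph=True)[0]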
40 | if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: 41 | if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: 42 | if out_channels <= 4 and groups == 1: 43 | in_shape = x.shape 44 | x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) 45 | x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) 46 | else: 47 | x = x.to(memory_format=torch.contiguous_format) 48 | w = w.to(memory_format=torch.contiguous_format) 49 | x = conv2d_gradfix.conv2d(x, w, groups=groups) 50 | return x.to(memory_format=torch.channels_last) 51 | 52 | # Otherwise => execute using conv2d_gradfix. 53 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 54 | return op(x, w, stride=stride, padding=padding, groups=groups) 55 | 56 | #---------------------------------------------------------------------------- 57 | 58 | @misc.profiled_function 59 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 60 | r"""2D convolution with optional up/downsampling. 61 | 62 | Padding is performed only once at the beginning, not between the operations. 63 | 64 | Args: 65 | x: Input tensor of shape 66 | `[batch_size, in_channels, in_height, in_width]`. 67 | w: Weight tensor of shape 68 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 69 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 70 | calling upfirdn2d.setup_filter(). None = identity (default). 71 | up: Integer upsampling factor (default: 1). 72 | down: Integer downsampling factor (default: 1). 73 | padding: Padding with respect to the upsampled image. Can be a single number 74 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 75 | (default: 0). 76 | groups: Split input channels into N groups (default: 1). 77 | flip_weight: False = convolution, True = correlation (default: True). 78 | flip_filter: False = convolution, True = correlation (default: False). 79 | 80 | Returns: 81 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 82 | """ 83 | # Validate arguments. 84 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 85 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 86 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 87 | assert isinstance(up, int) and (up >= 1) 88 | assert isinstance(down, int) and (down >= 1) 89 | assert isinstance(groups, int) and (groups >= 1) 90 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 91 | fw, fh = _get_filter_size(f) 92 | px0, px1, py0, py1 = _parse_padding(padding) 93 | 94 | # Adjust padding to account for up/downsampling. 95 | if up > 1: 96 | px0 += (fw + up - 1) // 2 97 | px1 += (fw - up) // 2 98 | py0 += (fh + up - 1) // 2 99 | py1 += (fh - up) // 2 100 | if down > 1: 101 | px0 += (fw - down + 1) // 2 102 | px1 += (fw - down) // 2 103 | py0 += (fh - down + 1) // 2 104 | py1 += (fh - down) // 2 105 | 106 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 107 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 108 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 109 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 110 | return x 111 | 112 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 
113 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 114 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 115 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 116 | return x 117 | 118 | # Fast path: downsampling only => use strided convolution. 119 | if down > 1 and up == 1: 120 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 121 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 122 | return x 123 | 124 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 125 | if up > 1: 126 | if groups == 1: 127 | w = w.transpose(0, 1) 128 | else: 129 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 130 | w = w.transpose(1, 2) 131 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 132 | px0 -= kw - 1 133 | px1 -= kw - up 134 | py0 -= kh - 1 135 | py1 -= kh - up 136 | pxt = max(min(-px0, -px1), 0) 137 | pyt = max(min(-py0, -py1), 0) 138 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) 139 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) 140 | if down > 1: 141 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 142 | return x 143 | 144 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 145 | if up == 1 and down == 1: 146 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 147 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) 148 | 149 | # Fallback: Generic reference implementation. 150 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 151 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 152 | if down > 1: 153 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 154 | return x 155 | 156 | #---------------------------------------------------------------------------- 157 | -------------------------------------------------------------------------------- /torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
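A sketch of conv2d_resample() for a 2x upsampled convolution, following its docstring; the filter taps, tensor shapes, and padding are illustrative assumptions. Internally the call routes through the conv2d_gradfix and upfirdn2d ops, so it also runs on CPU via their reference fallbacks:

import torch
from torch_utils.ops import conv2d_resample, upfirdn2d

x = torch.randn(1, 16, 32, 32)            # [batch, in_channels, H, W]
w = torch.randn(8, 16, 3, 3)              # [out_channels, in_channels, kH, kW]
f = upfirdn2d.setup_filter([1, 3, 3, 1])  # low-pass filter prepared as the docstring requires

# Upsample by 2x and convolve; padding is interpreted w.r.t. the upsampled image.
y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
print(y.shape)   # expected: torch.Size([1, 8, 64, 64])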
8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import warnings 15 | import torch 16 | 17 | # pylint: disable=redefined-builtin 18 | # pylint: disable=arguments-differ 19 | # pylint: disable=protected-access 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | enabled = False # Enable the custom op by setting this to true. 
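A tiny sketch of the fused multiply-add helper above; the broadcasted shapes are illustrative assumptions. The backward pass un-broadcasts the gradients to the original operand shapes via _unbroadcast():

import torch
from torch_utils.ops import fma

a = torch.randn(4, 8, 1, 1, requires_grad=True)
b = torch.randn(4, 8, 16, 16)
c = torch.randn(1, 8, 1, 1)

out = fma.fma(a, b, c)                  # a * b + c
assert torch.allclose(out, a * b + c)

out.sum().backward()
print(a.grad.shape)                     # torch.Size([4, 8, 1, 1]), matching a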
24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | def grid_sample(input, grid): 28 | if _should_use_custom_op(): 29 | return _GridSample2dForward.apply(input, grid) 30 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def _should_use_custom_op(): 35 | if not enabled: 36 | return False 37 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): 38 | return True 39 | warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().') 40 | return False 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | class _GridSample2dForward(torch.autograd.Function): 45 | @staticmethod 46 | def forward(ctx, input, grid): 47 | assert input.ndim == 4 48 | assert grid.ndim == 4 49 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 50 | ctx.save_for_backward(input, grid) 51 | return output 52 | 53 | @staticmethod 54 | def backward(ctx, grad_output): 55 | input, grid = ctx.saved_tensors 56 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 57 | return grad_input, grad_grid 58 | 59 | #---------------------------------------------------------------------------- 60 | 61 | class _GridSample2dBackward(torch.autograd.Function): 62 | @staticmethod 63 | def forward(ctx, grad_output, input, grid): 64 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 65 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 66 | ctx.save_for_backward(grid) 67 | return grad_input, grad_grid 68 | 69 | @staticmethod 70 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 71 | _ = grad2_grad_grid # unused 72 | grid, = ctx.saved_tensors 73 | grad2_grad_output = None 74 | grad2_input = None 75 | grad2_grid = None 76 | 77 | if ctx.needs_input_grad[0]: 78 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 79 | 80 | assert not ctx.needs_input_grad[2] 81 | return grad2_grad_output, grad2_input, grad2_grid 82 | 83 | #---------------------------------------------------------------------------- 84 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments. 
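A minimal sketch of the grid_sample wrapper above; the identity sampling grid is an illustrative assumption. With enabled left at False it behaves exactly like torch.nn.functional.grid_sample, while enabling it on PyTorch 1.7-1.9 additionally supports second-order gradients through the custom backward op:

import torch
from torch_utils.ops import grid_sample_gradfix

x = torch.randn(1, 3, 8, 8, requires_grad=True)
theta = torch.tensor([[[1.0, 0.0, 0.0],
                       [0.0, 1.0, 0.0]]])     # identity affine transform
grid = torch.nn.functional.affine_grid(theta, size=(1, 3, 8, 8), align_corners=False)
y = grid_sample_gradfix.grid_sample(x, grid)  # bilinear, zero padding, align_corners=False
assert torch.allclose(y, x, atol=1e-6)        # the identity grid reproduces the input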
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 25 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 26 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 27 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 28 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 29 | 30 | // Create output tensor. 31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 32 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 33 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 34 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 35 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 36 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 37 | 38 | // Initialize CUDA kernel parameters. 39 | upfirdn2d_kernel_params p; 40 | p.x = x.data_ptr(); 41 | p.f = f.data_ptr(); 42 | p.y = y.data_ptr(); 43 | p.up = make_int2(upx, upy); 44 | p.down = make_int2(downx, downy); 45 | p.pad0 = make_int2(padx0, pady0); 46 | p.flip = (flip) ? 1 : 0; 47 | p.gain = gain; 48 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 49 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 50 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 51 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 52 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 53 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 54 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 55 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 56 | 57 | // Choose CUDA kernel. 58 | upfirdn2d_kernel_spec spec; 59 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 60 | { 61 | spec = choose_upfirdn2d_kernel(p); 62 | }); 63 | 64 | // Set looping options. 65 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 66 | p.loopMinor = spec.loopMinor; 67 | p.loopX = spec.loopX; 68 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 69 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 70 | 71 | // Compute grid size. 72 | dim3 blockSize, gridSize; 73 | if (spec.tileOutW < 0) // large 74 | { 75 | blockSize = dim3(4, 32, 1); 76 | gridSize = dim3( 77 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 78 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 79 | p.launchMajor); 80 | } 81 | else // small 82 | { 83 | blockSize = dim3(256, 1, 1); 84 | gridSize = dim3( 85 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 86 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 87 | p.launchMajor); 88 | } 89 | 90 | // Launch CUDA kernel. 
91 | void* args[] = {&p}; 92 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 93 | return y; 94 | } 95 | 96 | //------------------------------------------------------------------------ 97 | 98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 99 | { 100 | m.def("upfirdn2d", &upfirdn2d); 101 | } 102 | 103 | //------------------------------------------------------------------------ 104 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for pickling Python code alongside other data. 10 | 11 | The pickled code is automatically imported into a separate Python module 12 | during unpickling. 
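A small sketch of what this op computes, using the Python-level upfirdn2d wrapper that conv2d_resample.py above imports; the filter taps and padding are illustrative assumptions. The output size follows the same relation the C++ code above uses to allocate the output tensor, outW = (inW*upx + padx0 + padx1 - filterW + downx) // downx:

import torch
from torch_utils.ops import upfirdn2d

x = torch.randn(1, 1, 4, 4)
f = upfirdn2d.setup_filter([1, 1])    # normalized 2-tap box filter

# Insert zeros to upsample by 2, pad, apply the FIR filter, keep every output sample.
y = upfirdn2d.upfirdn2d(x, f, up=2, padding=[1, 0, 1, 0], gain=4)
print(y.shape)   # expected: torch.Size([1, 1, 8, 8]) = (4*2 + 1 + 0 - 2 + 1)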
This way, any previously exported pickles will remain 13 | usable even if the original code is no longer available, or if the current 14 | version of the code is not consistent with what was originally pickled.""" 15 | 16 | import sys 17 | import pickle 18 | import io 19 | import inspect 20 | import copy 21 | import uuid 22 | import types 23 | import dnnlib 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | _version = 6 # internal version number 28 | _decorators = set() # {decorator_class, ...} 29 | _import_hooks = [] # [hook_function, ...] 30 | _module_to_src_dict = dict() # {module: src, ...} 31 | _src_to_module_dict = dict() # {src: module, ...} 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def persistent_class(orig_class): 36 | r"""Class decorator that extends a given class to save its source code 37 | when pickled. 38 | 39 | Example: 40 | 41 | from torch_utils import persistence 42 | 43 | @persistence.persistent_class 44 | class MyNetwork(torch.nn.Module): 45 | def __init__(self, num_inputs, num_outputs): 46 | super().__init__() 47 | self.fc = MyLayer(num_inputs, num_outputs) 48 | ... 49 | 50 | @persistence.persistent_class 51 | class MyLayer(torch.nn.Module): 52 | ... 53 | 54 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 55 | source code alongside other internal state (e.g., parameters, buffers, 56 | and submodules). This way, any previously exported pickle will remain 57 | usable even if the class definitions have been modified or are no 58 | longer available. 59 | 60 | The decorator saves the source code of the entire Python module 61 | containing the decorated class. It does *not* save the source code of 62 | any imported modules. Thus, the imported modules must be available 63 | during unpickling, also including `torch_utils.persistence` itself. 64 | 65 | It is ok to call functions defined in the same module from the 66 | decorated class. However, if the decorated class depends on other 67 | classes defined in the same module, they must be decorated as well. 68 | This is illustrated in the above example in the case of `MyLayer`. 69 | 70 | It is also possible to employ the decorator just-in-time before 71 | calling the constructor. For example: 72 | 73 | cls = MyLayer 74 | if want_to_make_it_persistent: 75 | cls = persistence.persistent_class(cls) 76 | layer = cls(num_inputs, num_outputs) 77 | 78 | As an additional feature, the decorator also keeps track of the 79 | arguments that were used to construct each instance of the decorated 80 | class. The arguments can be queried via `obj.init_args` and 81 | `obj.init_kwargs`, and they are automatically pickled alongside other 82 | object state. 
A typical use case is to first unpickle a previous 83 | instance of a persistent class, and then upgrade it to use the latest 84 | version of the source code: 85 | 86 | with open('old_pickle.pkl', 'rb') as f: 87 | old_net = pickle.load(f) 88 | new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs) 89 | misc.copy_params_and_buffers(old_net, new_net, require_all=True) 90 | """ 91 | assert isinstance(orig_class, type) 92 | if is_persistent(orig_class): 93 | return orig_class 94 | 95 | assert orig_class.__module__ in sys.modules 96 | orig_module = sys.modules[orig_class.__module__] 97 | orig_module_src = _module_to_src(orig_module) 98 | 99 | class Decorator(orig_class): 100 | _orig_module_src = orig_module_src 101 | _orig_class_name = orig_class.__name__ 102 | 103 | def __init__(self, *args, **kwargs): 104 | super().__init__(*args, **kwargs) 105 | self._init_args = copy.deepcopy(args) 106 | self._init_kwargs = copy.deepcopy(kwargs) 107 | assert orig_class.__name__ in orig_module.__dict__ 108 | _check_pickleable(self.__reduce__()) 109 | 110 | @property 111 | def init_args(self): 112 | return copy.deepcopy(self._init_args) 113 | 114 | @property 115 | def init_kwargs(self): 116 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) 117 | 118 | def __reduce__(self): 119 | fields = list(super().__reduce__()) 120 | fields += [None] * max(3 - len(fields), 0) 121 | if fields[0] is not _reconstruct_persistent_obj: 122 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) 123 | fields[0] = _reconstruct_persistent_obj # reconstruct func 124 | fields[1] = (meta,) # reconstruct args 125 | fields[2] = None # state dict 126 | return tuple(fields) 127 | 128 | Decorator.__name__ = orig_class.__name__ 129 | _decorators.add(Decorator) 130 | return Decorator 131 | 132 | #---------------------------------------------------------------------------- 133 | 134 | def is_persistent(obj): 135 | r"""Test whether the given object or class is persistent, i.e., 136 | whether it will save its source code when pickled. 137 | """ 138 | try: 139 | if obj in _decorators: 140 | return True 141 | except TypeError: 142 | pass 143 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck 144 | 145 | #---------------------------------------------------------------------------- 146 | 147 | def import_hook(hook): 148 | r"""Register an import hook that is called whenever a persistent object 149 | is being unpickled. A typical use case is to patch the pickled source 150 | code to avoid errors and inconsistencies when the API of some imported 151 | module has changed. 152 | 153 | The hook should have the following signature: 154 | 155 | hook(meta) -> modified meta 156 | 157 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 158 | 159 | type: Type of the persistent object, e.g. `'class'`. 160 | version: Internal version number of `torch_utils.persistence`. 161 | module_src Original source code of the Python module. 162 | class_name: Class name in the original Python module. 163 | state: Internal state of the object. 164 | 165 | Example: 166 | 167 | @persistence.import_hook 168 | def wreck_my_network(meta): 169 | if meta.class_name == 'MyNetwork': 170 | print('MyNetwork is being imported. 
I will wreck it!') 171 | meta.module_src = meta.module_src.replace("True", "False") 172 | return meta 173 | """ 174 | assert callable(hook) 175 | _import_hooks.append(hook) 176 | 177 | #---------------------------------------------------------------------------- 178 | 179 | def _reconstruct_persistent_obj(meta): 180 | r"""Hook that is called internally by the `pickle` module to unpickle 181 | a persistent object. 182 | """ 183 | meta = dnnlib.EasyDict(meta) 184 | meta.state = dnnlib.EasyDict(meta.state) 185 | for hook in _import_hooks: 186 | meta = hook(meta) 187 | assert meta is not None 188 | 189 | assert meta.version == _version 190 | module = _src_to_module(meta.module_src) 191 | 192 | assert meta.type == 'class' 193 | orig_class = module.__dict__[meta.class_name] 194 | decorator_class = persistent_class(orig_class) 195 | obj = decorator_class.__new__(decorator_class) 196 | 197 | setstate = getattr(obj, '__setstate__', None) 198 | if callable(setstate): 199 | setstate(meta.state) # pylint: disable=not-callable 200 | else: 201 | obj.__dict__.update(meta.state) 202 | return obj 203 | 204 | #---------------------------------------------------------------------------- 205 | 206 | def _module_to_src(module): 207 | r"""Query the source code of a given Python module. 208 | """ 209 | src = _module_to_src_dict.get(module, None) 210 | if src is None: 211 | src = inspect.getsource(module) 212 | _module_to_src_dict[module] = src 213 | _src_to_module_dict[src] = module 214 | return src 215 | 216 | def _src_to_module(src): 217 | r"""Get or create a Python module for the given source code. 218 | """ 219 | module = _src_to_module_dict.get(src, None) 220 | if module is None: 221 | module_name = "_imported_module_" + uuid.uuid4().hex 222 | module = types.ModuleType(module_name) 223 | sys.modules[module_name] = module 224 | _module_to_src_dict[module] = src 225 | _src_to_module_dict[src] = module 226 | exec(src, module.__dict__) # pylint: disable=exec-used 227 | return module 228 | 229 | #---------------------------------------------------------------------------- 230 | 231 | def _check_pickleable(obj): 232 | r"""Check that the given object is pickleable, raising an exception if 233 | it is not. This function is expected to be considerably more efficient 234 | than actually pickling the object. 235 | """ 236 | def recurse(obj): 237 | if isinstance(obj, (list, tuple, set)): 238 | return [recurse(x) for x in obj] 239 | if isinstance(obj, dict): 240 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 241 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 242 | return None # Python primitive types are pickleable. 243 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']: 244 | return None # NumPy arrays and PyTorch tensors are pickleable. 245 | if is_persistent(obj): 246 | return None # Persistent objects are pickleable, by virtue of the constructor check. 247 | return obj 248 | with io.BytesIO() as f: 249 | pickle.dump(recurse(obj), f) 250 | 251 | #---------------------------------------------------------------------------- 252 | -------------------------------------------------------------------------------- /torch_utils/training_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
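A minimal sketch of the persistent_class decorator documented above, mirroring the MyNetwork example from its docstring; the class, layer sizes, and file name are illustrative assumptions, and the decorated class must live in a module whose source inspect.getsource() can read (e.g. a script or an imported file):

import pickle
import torch
from torch_utils import persistence

@persistence.persistent_class
class TinyNet(torch.nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super().__init__()
        self.fc = torch.nn.Linear(num_inputs, num_outputs)

    def forward(self, x):
        return self.fc(x)

net = TinyNet(8, 4)
with open('tiny_net.pkl', 'wb') as f:
    pickle.dump(net, f)   # the defining module's source code travels inside the pickle

# Constructor arguments are recorded automatically, so an old pickle can later be
# re-instantiated against the current code as the docstring describes.
print(net.init_args, dict(net.init_kwargs))   # (8, 4) {}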
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for reporting and collecting training statistics across 10 | multiple processes and devices. The interface is designed to minimize 11 | synchronization overhead as well as the amount of boilerplate in user 12 | code.""" 13 | 14 | import re 15 | import numpy as np 16 | import torch 17 | import dnnlib 18 | 19 | from . import misc 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 24 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 25 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 26 | _rank = 0 # Rank of the current process. 27 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 28 | _sync_called = False # Has _sync() been called yet? 29 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 30 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def init_multiprocessing(rank, sync_device): 35 | r"""Initializes `torch_utils.training_stats` for collecting statistics 36 | across multiple processes. 37 | 38 | This function must be called after 39 | `torch.distributed.init_process_group()` and before `Collector.update()`. 40 | The call is not necessary if multi-process collection is not needed. 41 | 42 | Args: 43 | rank: Rank of the current process. 44 | sync_device: PyTorch device to use for inter-process 45 | communication, or None to disable multi-process 46 | collection. Typically `torch.device('cuda', rank)`. 47 | """ 48 | global _rank, _sync_device 49 | assert not _sync_called 50 | _rank = rank 51 | _sync_device = sync_device 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | @misc.profiled_function 56 | def report(name, value): 57 | r"""Broadcasts the given set of scalars to all interested instances of 58 | `Collector`, across device and process boundaries. 59 | 60 | This function is expected to be extremely cheap and can be safely 61 | called from anywhere in the training loop, loss function, or inside a 62 | `torch.nn.Module`. 63 | 64 | Warning: The current implementation expects the set of unique names to 65 | be consistent across processes. Please make sure that `report()` is 66 | called at least once for each unique name by each process, and in the 67 | same order. If a given process has no scalars to broadcast, it can do 68 | `report(name, [])` (empty list). 69 | 70 | Args: 71 | name: Arbitrary string specifying the name of the statistic. 72 | Averages are accumulated separately for each unique name. 73 | value: Arbitrary set of scalars. Can be a list, tuple, 74 | NumPy array, PyTorch tensor, or Python scalar. 75 | 76 | Returns: 77 | The same `value` that was passed in. 
78 | """ 79 | if name not in _counters: 80 | _counters[name] = dict() 81 | 82 | elems = torch.as_tensor(value) 83 | if elems.numel() == 0: 84 | return value 85 | 86 | elems = elems.detach().flatten().to(_reduce_dtype) 87 | moments = torch.stack([ 88 | torch.ones_like(elems).sum(), 89 | elems.sum(), 90 | elems.square().sum(), 91 | ]) 92 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 93 | moments = moments.to(_counter_dtype) 94 | 95 | device = moments.device 96 | if device not in _counters[name]: 97 | _counters[name][device] = torch.zeros_like(moments) 98 | _counters[name][device].add_(moments) 99 | return value 100 | 101 | #---------------------------------------------------------------------------- 102 | 103 | def report0(name, value): 104 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 105 | but ignores any scalars provided by the other processes. 106 | See `report()` for further details. 107 | """ 108 | report(name, value if _rank == 0 else []) 109 | return value 110 | 111 | #---------------------------------------------------------------------------- 112 | 113 | class Collector: 114 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 115 | computes their long-term averages (mean and standard deviation) over 116 | user-defined periods of time. 117 | 118 | The averages are first collected into internal counters that are not 119 | directly visible to the user. They are then copied to the user-visible 120 | state as a result of calling `update()` and can then be queried using 121 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 122 | internal counters for the next round, so that the user-visible state 123 | effectively reflects averages collected between the last two calls to 124 | `update()`. 125 | 126 | Args: 127 | regex: Regular expression defining which statistics to 128 | collect. The default is to collect everything. 129 | keep_previous: Whether to retain the previous averages if no 130 | scalars were collected on a given round 131 | (default: True). 132 | """ 133 | def __init__(self, regex='.*', keep_previous=True): 134 | self._regex = re.compile(regex) 135 | self._keep_previous = keep_previous 136 | self._cumulative = dict() 137 | self._moments = dict() 138 | self.update() 139 | self._moments.clear() 140 | 141 | def names(self): 142 | r"""Returns the names of all statistics broadcasted so far that 143 | match the regular expression specified at construction time. 144 | """ 145 | return [name for name in _counters if self._regex.fullmatch(name)] 146 | 147 | def update(self): 148 | r"""Copies current values of the internal counters to the 149 | user-visible state and resets them for the next round. 150 | 151 | If `keep_previous=True` was specified at construction time, the 152 | operation is skipped for statistics that have received no scalars 153 | since the last update, retaining their previous averages. 154 | 155 | This method performs a number of GPU-to-CPU transfers and one 156 | `torch.distributed.all_reduce()`. It is intended to be called 157 | periodically in the main training loop, typically once every 158 | N training steps. 
159 | """ 160 | if not self._keep_previous: 161 | self._moments.clear() 162 | for name, cumulative in _sync(self.names()): 163 | if name not in self._cumulative: 164 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 165 | delta = cumulative - self._cumulative[name] 166 | self._cumulative[name].copy_(cumulative) 167 | if float(delta[0]) != 0: 168 | self._moments[name] = delta 169 | 170 | def _get_delta(self, name): 171 | r"""Returns the raw moments that were accumulated for the given 172 | statistic between the last two calls to `update()`, or zero if 173 | no scalars were collected. 174 | """ 175 | assert self._regex.fullmatch(name) 176 | if name not in self._moments: 177 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 178 | return self._moments[name] 179 | 180 | def num(self, name): 181 | r"""Returns the number of scalars that were accumulated for the given 182 | statistic between the last two calls to `update()`, or zero if 183 | no scalars were collected. 184 | """ 185 | delta = self._get_delta(name) 186 | return int(delta[0]) 187 | 188 | def mean(self, name): 189 | r"""Returns the mean of the scalars that were accumulated for the 190 | given statistic between the last two calls to `update()`, or NaN if 191 | no scalars were collected. 192 | """ 193 | delta = self._get_delta(name) 194 | if int(delta[0]) == 0: 195 | return float('nan') 196 | return float(delta[1] / delta[0]) 197 | 198 | def std(self, name): 199 | r"""Returns the standard deviation of the scalars that were 200 | accumulated for the given statistic between the last two calls to 201 | `update()`, or NaN if no scalars were collected. 202 | """ 203 | delta = self._get_delta(name) 204 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): 205 | return float('nan') 206 | if int(delta[0]) == 1: 207 | return float(0) 208 | mean = float(delta[1] / delta[0]) 209 | raw_var = float(delta[2] / delta[0]) 210 | return np.sqrt(max(raw_var - np.square(mean), 0)) 211 | 212 | def as_dict(self): 213 | r"""Returns the averages accumulated between the last two calls to 214 | `update()` as an `dnnlib.EasyDict`. The contents are as follows: 215 | 216 | dnnlib.EasyDict( 217 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), 218 | ... 219 | ) 220 | """ 221 | stats = dnnlib.EasyDict() 222 | for name in self.names(): 223 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) 224 | return stats 225 | 226 | def __getitem__(self, name): 227 | r"""Convenience getter. 228 | `collector[name]` is a synonym for `collector.mean(name)`. 229 | """ 230 | return self.mean(name) 231 | 232 | #---------------------------------------------------------------------------- 233 | 234 | def _sync(names): 235 | r"""Synchronize the global cumulative counters across devices and 236 | processes. Called internally by `Collector.update()`. 237 | """ 238 | if len(names) == 0: 239 | return [] 240 | global _sync_called 241 | _sync_called = True 242 | 243 | # Collect deltas within current rank. 244 | deltas = [] 245 | device = _sync_device if _sync_device is not None else torch.device('cpu') 246 | for name in names: 247 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) 248 | for counter in _counters[name].values(): 249 | delta.add_(counter.to(device)) 250 | counter.copy_(torch.zeros_like(counter)) 251 | deltas.append(delta) 252 | deltas = torch.stack(deltas) 253 | 254 | # Sum deltas across ranks. 
255 | if _sync_device is not None: 256 | torch.distributed.all_reduce(deltas) 257 | 258 | # Update cumulative values. 259 | deltas = deltas.cpu() 260 | for idx, name in enumerate(names): 261 | if name not in _cumulative: 262 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 263 | _cumulative[name].add_(deltas[idx]) 264 | 265 | # Return name-value pairs. 266 | return [(name, _cumulative[name]) for name in names] 267 | 268 | #---------------------------------------------------------------------------- 269 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /training/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import numpy as np 11 | import zipfile 12 | import PIL.Image 13 | import json 14 | import torch 15 | import dnnlib 16 | 17 | try: 18 | import pyspng 19 | except ImportError: 20 | pyspng = None 21 | 22 | #---------------------------------------------------------------------------- 23 | 24 | class Dataset(torch.utils.data.Dataset): 25 | def __init__(self, 26 | name, # Name of the dataset. 27 | raw_shape, # Shape of the raw image data (NCHW). 28 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. 29 | use_labels = False, # Enable conditioning labels? False = label dimension is zero. 30 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. 31 | random_seed = 0, # Random seed to use when applying max_size. 32 | ): 33 | self._name = name 34 | self._raw_shape = list(raw_shape) 35 | self._use_labels = use_labels 36 | self._raw_labels = None 37 | self._label_shape = None 38 | 39 | # Apply max_size. 40 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) 41 | if (max_size is not None) and (self._raw_idx.size > max_size): 42 | np.random.RandomState(random_seed).shuffle(self._raw_idx) 43 | self._raw_idx = np.sort(self._raw_idx[:max_size]) 44 | 45 | # Apply xflip. 
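        # (Note: when xflip is enabled, the index list is duplicated and the second copy is
        # flagged in self._xflip; __getitem__ mirrors those items by reversing the width axis.)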
46 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) 47 | if xflip: 48 | self._raw_idx = np.tile(self._raw_idx, 2) 49 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) 50 | 51 | def _get_raw_labels(self): 52 | if self._raw_labels is None: 53 | self._raw_labels = self._load_raw_labels() if self._use_labels else None 54 | if self._raw_labels is None: 55 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) 56 | assert isinstance(self._raw_labels, np.ndarray) 57 | assert self._raw_labels.shape[0] == self._raw_shape[0] 58 | assert self._raw_labels.dtype in [np.float32, np.int64] 59 | if self._raw_labels.dtype == np.int64: 60 | assert self._raw_labels.ndim == 1 61 | assert np.all(self._raw_labels >= 0) 62 | return self._raw_labels 63 | 64 | def close(self): # to be overridden by subclass 65 | pass 66 | 67 | def _load_raw_image(self, raw_idx): # to be overridden by subclass 68 | raise NotImplementedError 69 | 70 | def _load_raw_labels(self): # to be overridden by subclass 71 | raise NotImplementedError 72 | 73 | def __getstate__(self): 74 | return dict(self.__dict__, _raw_labels=None) 75 | 76 | def __del__(self): 77 | try: 78 | self.close() 79 | except: 80 | pass 81 | 82 | def __len__(self): 83 | return self._raw_idx.size 84 | 85 | def __getitem__(self, idx): 86 | image = self._load_raw_image(self._raw_idx[idx]) 87 | assert isinstance(image, np.ndarray) 88 | assert list(image.shape) == self.image_shape 89 | assert image.dtype == np.uint8 90 | if self._xflip[idx]: 91 | assert image.ndim == 3 # CHW 92 | image = image[:, :, ::-1] 93 | return image.copy(), self.get_label(idx) 94 | 95 | def get_label(self, idx): 96 | label = self._get_raw_labels()[self._raw_idx[idx]] 97 | if label.dtype == np.int64: 98 | onehot = np.zeros(self.label_shape, dtype=np.float32) 99 | onehot[label] = 1 100 | label = onehot 101 | return label.copy() 102 | 103 | def get_details(self, idx): 104 | d = dnnlib.EasyDict() 105 | d.raw_idx = int(self._raw_idx[idx]) 106 | d.xflip = (int(self._xflip[idx]) != 0) 107 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy() 108 | return d 109 | 110 | @property 111 | def name(self): 112 | return self._name 113 | 114 | @property 115 | def image_shape(self): 116 | return list(self._raw_shape[1:]) 117 | 118 | @property 119 | def num_channels(self): 120 | assert len(self.image_shape) == 3 # CHW 121 | return self.image_shape[0] 122 | 123 | @property 124 | def resolution(self): 125 | assert len(self.image_shape) == 3 # CHW 126 | assert self.image_shape[1] == self.image_shape[2] 127 | return self.image_shape[1] 128 | 129 | @property 130 | def label_shape(self): 131 | if self._label_shape is None: 132 | raw_labels = self._get_raw_labels() 133 | if raw_labels.dtype == np.int64: 134 | self._label_shape = [int(np.max(raw_labels)) + 1] 135 | else: 136 | self._label_shape = raw_labels.shape[1:] 137 | return list(self._label_shape) 138 | 139 | @property 140 | def label_dim(self): 141 | assert len(self.label_shape) == 1 142 | return self.label_shape[0] 143 | 144 | @property 145 | def has_labels(self): 146 | return any(x != 0 for x in self.label_shape) 147 | 148 | @property 149 | def has_onehot_labels(self): 150 | return self._get_raw_labels().dtype == np.int64 151 | 152 | #---------------------------------------------------------------------------- 153 | 154 | class ImageFolderDataset(Dataset): 155 | def __init__(self, 156 | path, # Path to directory or zip. 157 | resolution = None, # Ensure specific resolution, None = highest available. 
158 | **super_kwargs, # Additional arguments for the Dataset base class. 159 | ): 160 | self._path = path 161 | self._zipfile = None 162 | 163 | if os.path.isdir(self._path): 164 | self._type = 'dir' 165 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} 166 | elif self._file_ext(self._path) == '.zip': 167 | self._type = 'zip' 168 | self._all_fnames = set(self._get_zipfile().namelist()) 169 | else: 170 | raise IOError('Path must point to a directory or zip') 171 | 172 | PIL.Image.init() 173 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) 174 | if len(self._image_fnames) == 0: 175 | raise IOError('No image files found in the specified path') 176 | 177 | name = os.path.splitext(os.path.basename(self._path))[0] 178 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) 179 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution): 180 | raise IOError('Image files do not match the specified resolution') 181 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs) 182 | 183 | @staticmethod 184 | def _file_ext(fname): 185 | return os.path.splitext(fname)[1].lower() 186 | 187 | def _get_zipfile(self): 188 | assert self._type == 'zip' 189 | if self._zipfile is None: 190 | self._zipfile = zipfile.ZipFile(self._path) 191 | return self._zipfile 192 | 193 | def _open_file(self, fname): 194 | if self._type == 'dir': 195 | return open(os.path.join(self._path, fname), 'rb') 196 | if self._type == 'zip': 197 | return self._get_zipfile().open(fname, 'r') 198 | return None 199 | 200 | def close(self): 201 | try: 202 | if self._zipfile is not None: 203 | self._zipfile.close() 204 | finally: 205 | self._zipfile = None 206 | 207 | def __getstate__(self): 208 | return dict(super().__getstate__(), _zipfile=None) 209 | 210 | def _load_raw_image(self, raw_idx): 211 | fname = self._image_fnames[raw_idx] 212 | with self._open_file(fname) as f: 213 | if pyspng is not None and self._file_ext(fname) == '.png': 214 | image = pyspng.load(f.read()) 215 | else: 216 | image = np.array(PIL.Image.open(f)) 217 | if image.ndim == 2: 218 | image = image[:, :, np.newaxis] # HW => HWC 219 | image = image.transpose(2, 0, 1) # HWC => CHW 220 | return image 221 | 222 | def _load_raw_labels(self): 223 | fname = 'dataset.json' 224 | if fname not in self._all_fnames: 225 | return None 226 | with self._open_file(fname) as f: 227 | labels = json.load(f)['labels'] 228 | if labels is None: 229 | return None 230 | labels = dict(labels) 231 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] 232 | labels = np.array(labels) 233 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) 234 | return labels 235 | 236 | #---------------------------------------------------------------------------- 237 | -------------------------------------------------------------------------------- /training/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import numpy as np 10 | import torch 11 | from torch_utils import training_stats 12 | from torch_utils import misc 13 | from torch_utils.ops import conv2d_gradfix 14 | 15 | #---------------------------------------------------------------------------- 16 | 17 | class Loss: 18 | def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain): # to be overridden by subclass 19 | raise NotImplementedError() 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | class StyleGAN2Loss(Loss): 24 | def __init__(self, device, G_mapping, G_synthesis, D, augment_pipe=None, style_mixing_prob=0.9, r1_gamma=10, pl_batch_shrink=2, pl_decay=0.01, pl_weight=2): 25 | super().__init__() 26 | self.device = device 27 | self.G_mapping = G_mapping 28 | self.G_synthesis = G_synthesis 29 | self.D = D 30 | self.augment_pipe = augment_pipe 31 | self.style_mixing_prob = style_mixing_prob 32 | self.r1_gamma = r1_gamma 33 | self.pl_batch_shrink = pl_batch_shrink 34 | self.pl_decay = pl_decay 35 | self.pl_weight = pl_weight 36 | self.pl_mean = torch.zeros([], device=device) 37 | 38 | def run_G(self, z, c, sync): 39 | with misc.ddp_sync(self.G_mapping, sync): 40 | ws = self.G_mapping(z, c) 41 | if self.style_mixing_prob > 0: 42 | with torch.autograd.profiler.record_function('style_mixing'): 43 | cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1]) 44 | cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1])) 45 | ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True)[:, cutoff:] 46 | with misc.ddp_sync(self.G_synthesis, sync): 47 | img = self.G_synthesis(ws) 48 | return img, ws 49 | 50 | def run_D(self, img, c, sync): 51 | if self.augment_pipe is not None: 52 | img = self.augment_pipe(img) 53 | with misc.ddp_sync(self.D, sync): 54 | logits = self.D(img, c) 55 | return logits 56 | 57 | def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain): 58 | assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth'] 59 | do_Gmain = (phase in ['Gmain', 'Gboth']) 60 | do_Dmain = (phase in ['Dmain', 'Dboth']) 61 | do_Gpl = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0) 62 | do_Dr1 = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0) 63 | 64 | # Gmain: Maximize logits for generated images. 65 | if do_Gmain: 66 | with torch.autograd.profiler.record_function('Gmain_forward'): 67 | gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=(sync and not do_Gpl)) # May get synced by Gpl. 68 | gen_logits = self.run_D(gen_img, gen_c, sync=False) 69 | training_stats.report('Loss/scores/fake', gen_logits) 70 | training_stats.report('Loss/signs/fake', gen_logits.sign()) 71 | loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits)) 72 | training_stats.report('Loss/G/loss', loss_Gmain) 73 | with torch.autograd.profiler.record_function('Gmain_backward'): 74 | loss_Gmain.mean().mul(gain).backward() 75 | 76 | # Gpl: Apply path length regularization. 
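        # (Note: the block below estimates per-sample path lengths as the norm of the
        # Jacobian-vector product obtained by backpropagating (gen_img * pl_noise).sum()
        # to gen_ws, tracks their running mean in self.pl_mean via an exponential moving
        # average with decay self.pl_decay, and penalizes squared deviations from that mean.)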
77 | if do_Gpl: 78 | with torch.autograd.profiler.record_function('Gpl_forward'): 79 | batch_size = gen_z.shape[0] // self.pl_batch_shrink 80 | gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync=sync) 81 | pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3]) 82 | with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients(): 83 | pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0] 84 | pl_lengths = pl_grads.square().sum(2).mean(1).sqrt() 85 | pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay) 86 | self.pl_mean.copy_(pl_mean.detach()) 87 | pl_penalty = (pl_lengths - pl_mean).square() 88 | training_stats.report('Loss/pl_penalty', pl_penalty) 89 | loss_Gpl = pl_penalty * self.pl_weight 90 | training_stats.report('Loss/G/reg', loss_Gpl) 91 | with torch.autograd.profiler.record_function('Gpl_backward'): 92 | (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward() 93 | 94 | # Dmain: Minimize logits for generated images. 95 | loss_Dgen = 0 96 | if do_Dmain: 97 | with torch.autograd.profiler.record_function('Dgen_forward'): 98 | gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False) 99 | gen_logits = self.run_D(gen_img, gen_c, sync=False) # Gets synced by loss_Dreal. 100 | training_stats.report('Loss/scores/fake', gen_logits) 101 | training_stats.report('Loss/signs/fake', gen_logits.sign()) 102 | loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits)) 103 | with torch.autograd.profiler.record_function('Dgen_backward'): 104 | loss_Dgen.mean().mul(gain).backward() 105 | 106 | # Dmain: Maximize logits for real images. 107 | # Dr1: Apply R1 regularization. 108 | if do_Dmain or do_Dr1: 109 | name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1' 110 | with torch.autograd.profiler.record_function(name + '_forward'): 111 | real_img_tmp = real_img.detach().requires_grad_(do_Dr1) 112 | real_logits = self.run_D(real_img_tmp, real_c, sync=sync) 113 | training_stats.report('Loss/scores/real', real_logits) 114 | training_stats.report('Loss/signs/real', real_logits.sign()) 115 | 116 | loss_Dreal = 0 117 | if do_Dmain: 118 | loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits)) 119 | training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal) 120 | 121 | loss_Dr1 = 0 122 | if do_Dr1: 123 | with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients(): 124 | r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0] 125 | r1_penalty = r1_grads.square().sum([1,2,3]) 126 | loss_Dr1 = r1_penalty * (self.r1_gamma / 2) 127 | training_stats.report('Loss/r1_penalty', r1_penalty) 128 | training_stats.report('Loss/D/reg', loss_Dr1) 129 | 130 | with torch.autograd.profiler.record_function(name + '_backward'): 131 | (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward() 132 | 133 | #---------------------------------------------------------------------------- 134 | --------------------------------------------------------------------------------
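
As a closing note on the statistics utilities listed above, the snippet below is a minimal single-process usage sketch and is not part of the repository: it reuses the 'Loss/G/loss' statistic name reported by training/loss.py, while the loop bounds and random loss values are illustrative placeholders. In a multi-GPU run, training_stats.init_multiprocessing(rank, torch.device('cuda', rank)) would be called once after torch.distributed.init_process_group() so that Collector.update() aggregates across processes.

import torch
from torch_utils import training_stats

collector = training_stats.Collector(regex='Loss/.*')      # collect only loss statistics

for step in range(1, 101):
    per_sample_loss = torch.rand([8])                       # placeholder for a real loss tensor
    training_stats.report('Loss/G/loss', per_sample_loss)   # cheap; safe to call anywhere

    if step % 50 == 0:
        collector.update()                                  # fold internal counters into averages
        stats = collector.as_dict()                         # EasyDict of num/mean/std per name
        print(step, stats['Loss/G/loss'].mean, stats['Loss/G/loss'].std)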