├── .gitignore
├── FreeDrag--diffusion--version
│   ├── __pycache__
│   │   └── drag_pipeline.cpython-38.pyc
│   ├── drag_pipeline.py
│   ├── drag_ui.py
│   ├── environment.yaml
│   ├── local_pretrained_models
│   │   └── dummy.txt
│   ├── lora
│   │   ├── lora_ckpt
│   │   │   └── dummy.txt
│   │   ├── train_dreambooth_lora.py
│   │   └── train_lora.sh
│   ├── lora_tmp
│   │   └── pytorch_lora_weights.bin
│   ├── readme.txt
│   └── utils
│       ├── __pycache__
│       │   ├── attn_utils.cpython-38.pyc
│       │   ├── drag_utils.cpython-38.pyc
│       │   ├── freedrag_utils.cpython-38.pyc
│       │   ├── freeu_utils.cpython-38.pyc
│       │   ├── lora_utils.cpython-310.pyc
│       │   ├── lora_utils.cpython-38.pyc
│       │   └── ui_utils.cpython-38.pyc
│       ├── attn_utils.py
│       ├── freedrag_utils.py
│       ├── freeu_utils.py
│       ├── lora_utils.py
│       └── ui_utils.py
├── FreeDrag_gradio.py
├── README.md
├── arial.ttf
├── dnnlib
│   ├── __init__.py
│   └── util.py
├── download_models.sh
├── functions.py
├── legacy.py
├── requirements.txt
├── resources
│   ├── Teaser.png
│   ├── comparison_diffusion_1.png
│   ├── comparison_diffusion_2.png
│   ├── comparison_gan.png
│   ├── fig1.png
│   ├── logo2.png
│   └── style.css
├── torch_utils
│   ├── __init__.py
│   ├── custom_ops.py
│   ├── misc.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── bias_act.cpp
│   │   ├── bias_act.cu
│   │   ├── bias_act.h
│   │   ├── bias_act.py
│   │   ├── conv2d_gradfix.py
│   │   ├── conv2d_resample.py
│   │   ├── fma.py
│   │   ├── grid_sample_gradfix.py
│   │   ├── upfirdn2d.cpp
│   │   ├── upfirdn2d.cu
│   │   ├── upfirdn2d.h
│   │   └── upfirdn2d.py
│   ├── persistence.py
│   └── training_stats.py
└── training
    ├── __init__.py
    ├── augment.py
    ├── dataset.py
    ├── loss.py
    └── networks.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 |
--------------------------------------------------------------------------------
/FreeDrag--diffusion--version/__pycache__/drag_pipeline.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/FreeDrag--diffusion--version/__pycache__/drag_pipeline.cpython-38.pyc
--------------------------------------------------------------------------------
/FreeDrag--diffusion--version/drag_ui.py:
--------------------------------------------------------------------------------
1 | # The diffusion-based implementation of FreeDrag is based on DragDiffusion (https://arxiv.org/abs/2306.14435)
2 |
3 | import os
4 | import gradio as gr
5 |
6 | from utils.ui_utils import get_points, undo_points
7 | from utils.ui_utils import clear_all, store_img, train_lora_interface, run_drag
8 | from utils.ui_utils import clear_all_gen, store_img_gen, gen_img, run_drag_gen
9 |
10 | LENGTH=480 # length of the square area displaying/editing images
11 |
12 | with gr.Blocks() as demo:
13 | # layout definition
14 | with gr.Row():
15 | gr.Markdown("""
16 | # Official Implementation of [FreeDrag--diffusion](https://arxiv.org/abs/2307.04684)
17 | """)
18 |
19 | # UI components for editing real images
20 | with gr.Tab(label="Editing Real Image"):
21 | mask = gr.State(value=None) # store mask
22 | selected_points = gr.State([]) # store points
23 | original_image = gr.State(value=None) # store original input image
24 | with gr.Row():
25 | with gr.Column():
26 | gr.Markdown("""
Draw Mask
""")
27 | canvas = gr.Image(type="numpy", tool="sketch", label="Draw Mask",
28 | show_label=True, height=LENGTH, width=LENGTH) # for mask painting
29 | train_lora_button = gr.Button("Train LoRA")
30 | with gr.Column():
31 | gr.Markdown("""Click Points
""")
32 | input_image = gr.Image(type="numpy", label="Click Points",
33 | show_label=True, height=LENGTH, width=LENGTH) # for points clicking
34 | undo_button = gr.Button("Undo point")
35 | with gr.Column():
36 | gr.Markdown("""Editing Results
""")
37 | output_image = gr.Image(type="numpy", label="Editing Results",
38 | show_label=True, height=LENGTH, width=LENGTH)
39 | with gr.Row():
40 | run_button = gr.Button("Run")
41 | clear_all_button = gr.Button("Clear All")
42 |
43 | # general parameters
44 | with gr.Row():
45 | prompt = gr.Textbox(label="Prompt")
46 | lora_path = gr.Textbox(value="./lora_tmp", label="LoRA path")
47 | lora_status_bar = gr.Textbox(label="display LoRA training status")
48 |
49 | # algorithm specific parameters
50 | with gr.Tab("Drag Config"):
51 | with gr.Row():
52 | max_step = gr.Number(
53 | value=300,
54 | label="max steps",
55 | info="Number of maximum dragging steps.",
56 | precision=0)
57 | lam = gr.Number(value=10, label="lam", info="regularization strength on unmasked areas")
58 | l_expected = gr.Number(value=1, label="l_expected", info="the expected initial loss for each dragging")
59 | d_max = gr.Number(value=5, label="d_max", info="the max motion distance for each dragging")
60 | # n_actual_inference_step = gr.Number(value=40, label="optimize latent step", precision=0)
61 | inversion_strength = gr.Slider(0, 1.0,
62 | value=0.7,
63 | label="inversion strength",
64 | info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.")
65 | latent_lr = gr.Number(value=0.01, label="latent lr")
66 | start_step = gr.Number(value=0, label="start_step", precision=0, visible=False)
67 | start_layer = gr.Number(value=10, label="start_layer", precision=0, visible=False)
68 |
69 | with gr.Tab("Base Model Config"):
70 | with gr.Row():
71 | local_models_dir = './local_pretrained_models'  # folder for optional locally stored checkpoints (see local_pretrained_models/dummy.txt)
72 | local_models_choice = \
73 | [os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))]
74 | model_path = gr.Dropdown(value="runwayml/stable-diffusion-v1-5",
75 | label="Diffusion Model Path",
76 | choices=[
77 | "runwayml/stable-diffusion-v1-5",
78 | ] + local_models_choice
79 | )
80 | vae_path = gr.Dropdown(value="default",
81 | label="VAE choice",
82 | choices=["default",
83 | "stabilityai/sd-vae-ft-mse"] + local_models_choice
84 | )
85 |
86 | with gr.Tab("LoRA Parameters"):
87 | with gr.Row():
88 | lora_step = gr.Number(value=200, label="LoRA training steps", precision=0)
89 | lora_lr = gr.Number(value=0.0002, label="LoRA learning rate")
90 | lora_batch_size = gr.Number(value=4, label="LoRA batch size", precision=0)
91 | lora_rank = gr.Number(value=16, label="LoRA rank", precision=0)
92 |
93 | # UI components for editing generated images
94 | with gr.Tab(label="Editing Generated Image"):
95 | mask_gen = gr.State(value=None) # store mask
96 | selected_points_gen = gr.State([]) # store points
97 | original_image_gen = gr.State(value=None) # store the diffusion-generated image
98 | intermediate_latents_gen = gr.State(value=None) # store the intermediate diffusion latent during generation
99 | with gr.Row():
100 | with gr.Column():
101 | gr.Markdown("""Draw Mask
""")
102 | canvas_gen = gr.Image(type="numpy", tool="sketch", label="Draw Mask",
103 | show_label=True, height=LENGTH, width=LENGTH) # for mask painting
104 | gen_img_button = gr.Button("Generate Image")
105 | with gr.Column():
106 | gr.Markdown("""Click Points
""")
107 | input_image_gen = gr.Image(type="numpy", label="Click Points",
108 | show_label=True, height=LENGTH, width=LENGTH) # for points clicking
109 | undo_button_gen = gr.Button("Undo point")
110 | with gr.Column():
111 | gr.Markdown("""Editing Results
""")
112 | output_image_gen = gr.Image(type="numpy", label="Editing Results",
113 | show_label=True, height=LENGTH, width=LENGTH)
114 | with gr.Row():
115 | run_button_gen = gr.Button("Run")
116 | clear_all_button_gen = gr.Button("Clear All")
117 |
118 | # general parameters
119 | with gr.Row():
120 | pos_prompt_gen = gr.Textbox(label="Positive Prompt")
121 | neg_prompt_gen = gr.Textbox(label="Negative Prompt")
122 |
123 | with gr.Tab("Generation Config"):
124 | with gr.Row():
125 | local_models_dir = './local_pretrained_models'  # folder for optional locally stored checkpoints
126 | local_models_choice = \
127 | [os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))]
128 | model_path_gen = gr.Dropdown(value="runwayml/stable-diffusion-v1-5",
129 | label="Diffusion Model Path",
130 | choices=[
131 | "runwayml/stable-diffusion-v1-5",
132 | "gsdf/Counterfeit-V2.5",
133 | "emilianJR/majicMIX_realistic",
134 | "SG161222/Realistic_Vision_V2.0",
135 | "stablediffusionapi/interiordesignsuperm",
136 | "stablediffusionapi/dvarch",
137 | ] + local_models_choice
138 | )
139 | vae_path_gen = gr.Dropdown(value="default",
140 | label="VAE choice",
141 | choices=["default",
142 | "stabilityai/sd-vae-ft-mse"] + local_models_choice
143 | )
144 | lora_path_gen = gr.Textbox(value="", label="LoRA path")
145 | gen_seed = gr.Number(value=65536, label="Generation Seed", precision=0)
146 | height = gr.Number(value=512, label="Height", precision=0)
147 | width = gr.Number(value=512, label="Width", precision=0)
148 | guidance_scale = gr.Number(value=7.5, label="CFG Scale")
149 | scheduler_name_gen = gr.Dropdown(
150 | value="DDIM",
151 | label="Scheduler",
152 | choices=[
153 | "DDIM",
154 | "DPM++2M",
155 | "DPM++2M_karras"
156 | ]
157 | )
158 | n_inference_step_gen = gr.Number(value=50, label="Total Sampling Steps", precision=0)
159 |
160 | with gr.Tab("FreeU Parameters"):
161 | with gr.Row():
162 | b1_gen = gr.Slider(label='b1',
163 | info='1st stage backbone factor',
164 | minimum=1,
165 | maximum=1.6,
166 | step=0.05,
167 | value=1.1)
168 | b2_gen = gr.Slider(label='b2',
169 | info='2nd stage backbone factor',
170 | minimum=1,
171 | maximum=1.6,
172 | step=0.05,
173 | value=1.1)
174 | s1_gen = gr.Slider(label='s1',
175 | info='1st stage skip factor',
176 | minimum=0,
177 | maximum=1,
178 | step=0.05,
179 | value=0.8)
180 | s2_gen = gr.Slider(label='s2',
181 | info='2nd stage skip factor',
182 | minimum=0,
183 | maximum=1,
184 | step=0.05,
185 | value=0.8)
186 |
187 | with gr.Tab(label="Drag Config"):
188 | with gr.Row():
189 | max_step_gen = gr.Number(
190 | value=300,
191 | label="Number of Max Dragging step",
192 | info="Number of Max Dragging step.",
193 | precision=0)
194 | lam_gen = gr.Number(value=10, label="lam", info="regularization strength on unmasked areas")
195 | l_expected_gen = gr.Number(value=1, label="l_expected", info="the expected initial loss for each dragging")
196 | d_max_gen = gr.Number(value=5, label="d_max", info="the max motion distance for each dragging")
197 |
198 | # n_actual_inference_step_gen = gr.Number(value=40, label="optimize latent step", precision=0)
199 | inversion_strength_gen = gr.Slider(0, 1.0,
200 | value=0.7,
201 | label="Inversion Strength",
202 | info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.")
203 | latent_lr_gen = gr.Number(value=0.01, label="latent lr")
204 | start_step_gen = gr.Number(value=0, label="start_step", precision=0, visible=False)
205 | start_layer_gen = gr.Number(value=10, label="start_layer", precision=0, visible=False)
206 |
207 | # event definition
208 | # event for dragging user-input real image
209 | canvas.edit(
210 | store_img,
211 | [canvas],
212 | [original_image, selected_points, input_image, mask]
213 | )
214 | input_image.select(
215 | get_points,
216 | [input_image, selected_points],
217 | [input_image],
218 | )
219 | undo_button.click(
220 | undo_points,
221 | [original_image, mask],
222 | [input_image, selected_points]
223 | )
224 | train_lora_button.click(
225 | train_lora_interface,
226 | [original_image,
227 | prompt,
228 | model_path,
229 | vae_path,
230 | lora_path,
231 | lora_step,
232 | lora_lr,
233 | lora_batch_size,
234 | lora_rank],
235 | [lora_status_bar]
236 | )
237 | run_button.click(
238 | run_drag,
239 | [original_image,
240 | input_image,
241 | mask,
242 | prompt,
243 | selected_points,
244 | inversion_strength,
245 | lam,
246 | l_expected,
247 | d_max,
248 | latent_lr,
249 | max_step,
250 | model_path,
251 | vae_path,
252 | lora_path,
253 | start_step,
254 | start_layer,
255 | ],
256 | [output_image]
257 | )
258 | clear_all_button.click(
259 | clear_all,
260 | [gr.Number(value=LENGTH, visible=False, precision=0)],
261 | [canvas,
262 | input_image,
263 | output_image,
264 | selected_points,
265 | original_image,
266 | mask]
267 | )
268 |
269 | # event for dragging generated image
270 | canvas_gen.edit(
271 | store_img_gen,
272 | [canvas_gen],
273 | [original_image_gen, selected_points_gen, input_image_gen, mask_gen]
274 | )
275 | input_image_gen.select(
276 | get_points,
277 | [input_image_gen, selected_points_gen],
278 | [input_image_gen],
279 | )
280 | gen_img_button.click(
281 | gen_img,
282 | [
283 | gr.Number(value=LENGTH, visible=False, precision=0),
284 | height,
285 | width,
286 | n_inference_step_gen,
287 | scheduler_name_gen,
288 | gen_seed,
289 | guidance_scale,
290 | pos_prompt_gen,
291 | neg_prompt_gen,
292 | model_path_gen,
293 | vae_path_gen,
294 | lora_path_gen,
295 | b1_gen,
296 | b2_gen,
297 | s1_gen,
298 | s2_gen,
299 | ],
300 | [canvas_gen, input_image_gen, output_image_gen, mask_gen, intermediate_latents_gen]
301 | )
302 | undo_button_gen.click(
303 | undo_points,
304 | [original_image_gen, mask_gen],
305 | [input_image_gen, selected_points_gen]
306 | )
307 | run_button_gen.click(
308 | run_drag_gen,
309 | [
310 | n_inference_step_gen,
311 | scheduler_name_gen,
312 | original_image_gen, # the original image generated by the diffusion model
313 | input_image_gen, # image with clicking, masking, etc.
314 | intermediate_latents_gen,
315 | guidance_scale,
316 | mask_gen,
317 | pos_prompt_gen,
318 | neg_prompt_gen,
319 | selected_points_gen,
320 | inversion_strength_gen,
321 | lam_gen,
322 | latent_lr_gen,
323 | max_step_gen,
324 | l_expected_gen,
325 | d_max_gen,
326 | model_path_gen,
327 | vae_path_gen,
328 | lora_path_gen,
329 | start_step_gen,
330 | start_layer_gen,
331 | b1_gen,
332 | b2_gen,
333 | s1_gen,
334 | s2_gen,
335 | ],
336 | [output_image_gen]
337 | )
338 | clear_all_button_gen.click(
339 | clear_all_gen,
340 | [gr.Number(value=LENGTH, visible=False, precision=0)],
341 | [canvas_gen,
342 | input_image_gen,
343 | output_image_gen,
344 | selected_points_gen,
345 | original_image_gen,
346 | mask_gen,
347 | intermediate_latents_gen,
348 | ]
349 | )
350 |
351 | demo.queue().launch(share=True, debug=True,server_name="0.0.0.0",server_port=15322)
--------------------------------------------------------------------------------
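
The UI above follows one Gradio pattern throughout: persistent values (mask, selected points, original image, intermediate latents) live in gr.State objects, and every button binds a callback with explicit input/output component lists. Below is a minimal sketch of that pattern, using a hypothetical add_point callback rather than the repo's run_drag, and assuming the Gradio 3.41 API pinned in environment.yaml.

import gradio as gr

def add_point(points, x, y):
    # callbacks receive the current State values as inputs and return updated ones as outputs
    points = points + [(int(x), int(y))]
    return points, f"{len(points)} point(s) selected"

with gr.Blocks() as sketch:
    points = gr.State([])  # persists across callbacks, like selected_points in drag_ui.py
    x = gr.Number(value=0, label="x", precision=0)
    y = gr.Number(value=0, label="y", precision=0)
    status = gr.Textbox(label="status")
    add_button = gr.Button("Add point")
    # explicit [inputs], [outputs] lists, exactly as run_button.click(...) is wired above
    add_button.click(add_point, [points, x, y], [points, status])

# sketch.queue().launch()  # mirrors demo.queue().launch(...) at the end of drag_ui.py
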
/FreeDrag--diffusion--version/environment.yaml:
--------------------------------------------------------------------------------
1 | name: freedragdif
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=22.3.1
8 | - cudatoolkit=11.7
9 | - pip:
10 | - torch==2.0.0
11 | - torchvision==0.15.1
12 | - gradio==3.41.1
13 | - pydantic==2.0.2
14 | - albumentations==1.3.0
15 | - opencv-contrib-python==4.3.0.36
16 | - imageio==2.9.0
17 | - imageio-ffmpeg==0.4.2
18 | - pytorch-lightning==1.5.0
19 | - omegaconf==2.3.0
20 | - test-tube>=0.7.5
21 | - streamlit==1.12.1
22 | - einops==0.6.0
23 | - transformers==4.27.0
24 | - webdataset==0.2.5
25 | - kornia==0.6
26 | - open_clip_torch==2.16.0
27 | - invisible-watermark>=0.1.5
28 | - streamlit-drawable-canvas==0.8.0
29 | - torchmetrics==0.6.0
30 | - timm==0.6.12
31 | - addict==2.4.0
32 | - yapf==0.32.0
33 | - prettytable==3.6.0
34 | - safetensors==0.2.7
35 | - basicsr==1.4.2
36 | - accelerate==0.17.0
37 | - decord==0.6.0
38 | - diffusers==0.17.1
39 | - moviepy==1.0.3
40 | - opencv_python==4.7.0.68
41 | - Pillow==9.4.0
42 | - scikit_image==0.19.3
43 | - scipy==1.10.1
44 | - tensorboardX==2.6
45 | - tqdm==4.64.1
46 | - numpy==1.24.1
47 |
--------------------------------------------------------------------------------
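
The pins above fix the whole stack (PyTorch 2.0.0, diffusers 0.17.1, Gradio 3.41.1, transformers 4.27.0, ...), and several of them matter downstream, e.g. utils/lora_utils.py calls check_min_version("0.17.0") on diffusers. Below is a small sanity-check sketch, assuming the conda environment has already been created from this file; the subset of packages checked is an arbitrary choice.

# compare a few pins from environment.yaml against the active environment
from importlib.metadata import version

pinned = {
    "torch": "2.0.0",
    "torchvision": "0.15.1",
    "gradio": "3.41.1",
    "transformers": "4.27.0",
    "diffusers": "0.17.1",
}

for package, expected in pinned.items():
    installed = version(package)
    status = "OK" if installed == expected else "MISMATCH"
    print(f"{package}: pinned {expected}, installed {installed} [{status}]")
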
/FreeDrag--diffusion--version/local_pretrained_models/dummy.txt:
--------------------------------------------------------------------------------
1 | You may put your pretrained model here.
--------------------------------------------------------------------------------
/FreeDrag--diffusion--version/lora/lora_ckpt/dummy.txt:
--------------------------------------------------------------------------------
1 | lora checkpoints will be saved in this folder
2 |
--------------------------------------------------------------------------------
/FreeDrag--diffusion--version/lora/train_lora.sh:
--------------------------------------------------------------------------------
1 | export SAMPLE_DIR="lora/samples/sculpture"
2 | export OUTPUT_DIR="lora/lora_ckpt/sculpture_lora"
3 |
4 | export MODEL_NAME="runwayml/stable-diffusion-v1-5"
5 | export LORA_RANK=16
6 |
7 | accelerate launch lora/train_dreambooth_lora.py \
8 | --pretrained_model_name_or_path=$MODEL_NAME \
9 | --instance_data_dir=$SAMPLE_DIR \
10 | --output_dir=$OUTPUT_DIR \
11 | --instance_prompt="a photo of a sculpture" \
12 | --resolution=512 \
13 | --train_batch_size=1 \
14 | --gradient_accumulation_steps=1 \
15 | --checkpointing_steps=100 \
16 | --learning_rate=2e-4 \
17 | --lr_scheduler="constant" \
18 | --lr_warmup_steps=0 \
19 | --max_train_steps=200 \
20 | --lora_rank=$LORA_RANK \
21 | --seed="0"
22 |
--------------------------------------------------------------------------------
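
train_lora.sh is the standalone DreamBooth-LoRA route; the UI's Train LoRA button instead goes through utils.ui_utils.train_lora_interface, which presumably wraps the train_lora function in utils/lora_utils.py further down. If you prefer driving the script from Python, the sketch below assembles the same accelerate launch command; it assumes the sample directory exported above actually exists.

import subprocess

# mirrors the exports at the top of train_lora.sh
sample_dir = "lora/samples/sculpture"
output_dir = "lora/lora_ckpt/sculpture_lora"
model_name = "runwayml/stable-diffusion-v1-5"
lora_rank = 16

subprocess.run(
    [
        "accelerate", "launch", "lora/train_dreambooth_lora.py",
        f"--pretrained_model_name_or_path={model_name}",
        f"--instance_data_dir={sample_dir}",
        f"--output_dir={output_dir}",
        "--instance_prompt=a photo of a sculpture",
        "--resolution=512",
        "--train_batch_size=1",
        "--gradient_accumulation_steps=1",
        "--checkpointing_steps=100",
        "--learning_rate=2e-4",
        "--lr_scheduler=constant",
        "--lr_warmup_steps=0",
        "--max_train_steps=200",
        f"--lora_rank={lora_rank}",
        "--seed=0",
    ],
    check=True,
)
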
/FreeDrag--diffusion--version/lora_tmp/pytorch_lora_weights.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/FreeDrag--diffusion--version/lora_tmp/pytorch_lora_weights.bin
--------------------------------------------------------------------------------
/FreeDrag--diffusion--version/readme.txt:
--------------------------------------------------------------------------------
1 | The diffusion-based FreeDrag is based on DragDiffusion (https://arxiv.org/abs/2306.14435) with permission.
2 |
--------------------------------------------------------------------------------
/FreeDrag--diffusion--version/utils/__pycache__/attn_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/FreeDrag--diffusion--version/utils/__pycache__/attn_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/FreeDrag--diffusion--version/utils/__pycache__/drag_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/FreeDrag--diffusion--version/utils/__pycache__/drag_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/FreeDrag--diffusion--version/utils/__pycache__/freedrag_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/FreeDrag--diffusion--version/utils/__pycache__/freedrag_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/FreeDrag--diffusion--version/utils/__pycache__/freeu_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/FreeDrag--diffusion--version/utils/__pycache__/freeu_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/FreeDrag--diffusion--version/utils/__pycache__/lora_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/FreeDrag--diffusion--version/utils/__pycache__/lora_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/FreeDrag--diffusion--version/utils/__pycache__/lora_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/FreeDrag--diffusion--version/utils/__pycache__/lora_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/FreeDrag--diffusion--version/utils/__pycache__/ui_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/FreeDrag--diffusion--version/utils/__pycache__/ui_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/FreeDrag--diffusion--version/utils/attn_utils.py:
--------------------------------------------------------------------------------
1 | # *************************************************************************
2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4 | # ytedance Inc..
5 | # *************************************************************************
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from einops import rearrange, repeat
12 |
13 |
14 | class AttentionBase:
15 |
16 | def __init__(self):
17 | self.cur_step = 0
18 | self.num_att_layers = -1
19 | self.cur_att_layer = 0
20 |
21 | def after_step(self):
22 | pass
23 |
24 | def __call__(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
25 | out = self.forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs)
26 | self.cur_att_layer += 1
27 | if self.cur_att_layer == self.num_att_layers:
28 | self.cur_att_layer = 0
29 | self.cur_step += 1
30 | # after step
31 | self.after_step()
32 | return out
33 |
34 | def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
35 | out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
36 | out = rearrange(out, 'b h n d -> b n (h d)')
37 | return out
38 |
39 | def reset(self):
40 | self.cur_step = 0
41 | self.cur_att_layer = 0
42 |
43 |
44 | class MutualSelfAttentionControl(AttentionBase):
45 |
46 | def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, guidance_scale=7.5):
47 | """
48 | Mutual self-attention control for Stable-Diffusion model
49 | Args:
50 | start_step: the step to start mutual self-attention control
51 | start_layer: the layer to start mutual self-attention control
52 | layer_idx: list of the layers to apply mutual self-attention control
53 | step_idx: list of the steps to apply mutual self-attention control
54 | total_steps: the total number of steps
55 | """
56 | super().__init__()
57 | self.total_steps = total_steps
58 | self.start_step = start_step
59 | self.start_layer = start_layer
60 | self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, 16))
61 | self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
62 | # store the guidance scale to decide whether there is an unconditional branch
63 | self.guidance_scale = guidance_scale
64 | print("step_idx: ", self.step_idx)
65 | print("layer_idx: ", self.layer_idx)
66 |
67 | def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
68 | """
69 | Attention forward function
70 | """
71 | if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
72 | return super().forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs)
73 |
74 | if self.guidance_scale > 1.0:
75 | qu, qc = q[0:2], q[2:4]
76 | ku, kc = k[0:2], k[2:4]
77 | vu, vc = v[0:2], v[2:4]
78 |
79 | # merge queries of source and target branch into one so we can use torch API
80 | qu = torch.cat([qu[0:1], qu[1:2]], dim=2)
81 | qc = torch.cat([qc[0:1], qc[1:2]], dim=2)
82 |
83 | out_u = F.scaled_dot_product_attention(qu, ku[0:1], vu[0:1], attn_mask=None, dropout_p=0.0, is_causal=False)
84 | out_u = torch.cat(out_u.chunk(2, dim=2), dim=0) # split the queries into source and target batch
85 | out_u = rearrange(out_u, 'b h n d -> b n (h d)')
86 |
87 | out_c = F.scaled_dot_product_attention(qc, kc[0:1], vc[0:1], attn_mask=None, dropout_p=0.0, is_causal=False)
88 | out_c = torch.cat(out_c.chunk(2, dim=2), dim=0) # split the queries into source and target batch
89 | out_c = rearrange(out_c, 'b h n d -> b n (h d)')
90 |
91 | out = torch.cat([out_u, out_c], dim=0)
92 | else:
93 | q = torch.cat([q[0:1], q[1:2]], dim=2)
94 | out = F.scaled_dot_product_attention(q, k[0:1], v[0:1], attn_mask=None, dropout_p=0.0, is_causal=False)
95 | out = torch.cat(out.chunk(2, dim=2), dim=0) # split the queries into source and target batch
96 | out = rearrange(out, 'b h n d -> b n (h d)')
97 | return out
98 |
99 | # forward function for default attention processor
100 | # modified from __call__ function of AttnProcessor in diffusers
101 | def override_attn_proc_forward(attn, editor, place_in_unet):
102 | def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
103 | """
104 | The attention computation is similar to the original implementation of the LDM CrossAttention class,
105 | except for some modifications applied to the attention
106 | """
107 | if encoder_hidden_states is not None:
108 | context = encoder_hidden_states
109 | if attention_mask is not None:
110 | mask = attention_mask
111 |
112 | to_out = attn.to_out
113 | if isinstance(to_out, nn.modules.container.ModuleList):
114 | to_out = attn.to_out[0]
115 | else:
116 | to_out = attn.to_out
117 |
118 | h = attn.heads
119 | q = attn.to_q(x)
120 | is_cross = context is not None
121 | context = context if is_cross else x
122 | k = attn.to_k(context)
123 | v = attn.to_v(context)
124 |
125 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
126 |
127 | # the only difference
128 | out = editor(
129 | q, k, v, is_cross, place_in_unet,
130 | attn.heads, scale=attn.scale)
131 |
132 | return to_out(out)
133 |
134 | return forward
135 |
136 | # forward function for lora attention processor
137 | # modified from __call__ function of LoRAAttnProcessor2_0 in diffusers v0.17.1
138 | def override_lora_attn_proc_forward(attn, editor, place_in_unet):
139 | def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, lora_scale=1.0):
140 | residual = hidden_states
141 | input_ndim = hidden_states.ndim
142 | is_cross = encoder_hidden_states is not None
143 |
144 | if input_ndim == 4:
145 | batch_size, channel, height, width = hidden_states.shape
146 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
147 |
148 | batch_size, sequence_length, _ = (
149 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
150 | )
151 |
152 | if attention_mask is not None:
153 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
154 | # scaled_dot_product_attention expects attention_mask shape to be
155 | # (batch, heads, source_length, target_length)
156 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
157 |
158 | if attn.group_norm is not None:
159 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
160 |
161 | query = attn.to_q(hidden_states) + lora_scale * attn.processor.to_q_lora(hidden_states)
162 |
163 | if encoder_hidden_states is None:
164 | encoder_hidden_states = hidden_states
165 | elif attn.norm_cross:
166 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
167 |
168 | key = attn.to_k(encoder_hidden_states) + lora_scale * attn.processor.to_k_lora(encoder_hidden_states)
169 | value = attn.to_v(encoder_hidden_states) + lora_scale * attn.processor.to_v_lora(encoder_hidden_states)
170 |
171 | query, key, value = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=attn.heads), (query, key, value))
172 |
173 | # the only difference
174 | hidden_states = editor(
175 | query, key, value, is_cross, place_in_unet,
176 | attn.heads, scale=attn.scale)
177 |
178 | # linear proj
179 | hidden_states = attn.to_out[0](hidden_states) + lora_scale * attn.processor.to_out_lora(hidden_states)
180 | # dropout
181 | hidden_states = attn.to_out[1](hidden_states)
182 |
183 | if input_ndim == 4:
184 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
185 |
186 | if attn.residual_connection:
187 | hidden_states = hidden_states + residual
188 |
189 | hidden_states = hidden_states / attn.rescale_output_factor
190 |
191 | return hidden_states
192 |
193 | return forward
194 |
195 | def register_attention_editor_diffusers(model, editor: AttentionBase, attn_processor='attn_proc'):
196 | """
197 | Register an attention editor to a diffusers pipeline; adapted from [Prompt-to-Prompt]
198 | """
199 | def register_editor(net, count, place_in_unet):
200 | for name, subnet in net.named_children():
201 | if net.__class__.__name__ == 'Attention': # spatial Transformer layer
202 | if attn_processor == 'attn_proc':
203 | net.forward = override_attn_proc_forward(net, editor, place_in_unet)
204 | elif attn_processor == 'lora_attn_proc':
205 | net.forward = override_lora_attn_proc_forward(net, editor, place_in_unet)
206 | else:
207 | raise NotImplementedError("not implemented")
208 | return count + 1
209 | elif hasattr(net, 'children'):
210 | count = register_editor(subnet, count, place_in_unet)
211 | return count
212 |
213 | cross_att_count = 0
214 | for net_name, net in model.unet.named_children():
215 | if "down" in net_name:
216 | cross_att_count += register_editor(net, 0, "down")
217 | elif "mid" in net_name:
218 | cross_att_count += register_editor(net, 0, "mid")
219 | elif "up" in net_name:
220 | cross_att_count += register_editor(net, 0, "up")
221 | editor.num_att_layers = cross_att_count
222 |
--------------------------------------------------------------------------------
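
register_attention_editor_diffusers walks model.unet.named_children(), counts every Attention module in the down/mid/up blocks, and swaps its forward for one of the two overrides above, so the editor is invoked on every attention call. Below is a minimal usage sketch, assuming pipe is a diffusers StableDiffusionPipeline-like object exposing .unet and that the script is run from the FreeDrag--diffusion--version directory so that utils is importable (presumably how drag_pipeline.py consumes these helpers).

from diffusers import StableDiffusionPipeline

from utils.attn_utils import MutualSelfAttentionControl, register_attention_editor_diffusers

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# share keys/values between the source and target branches from step 0, layer 10 onward,
# matching the hidden start_step/start_layer defaults in drag_ui.py
editor = MutualSelfAttentionControl(start_step=0, start_layer=10,
                                    total_steps=50, guidance_scale=7.5)
register_attention_editor_diffusers(pipe, editor, attn_processor='attn_proc')

# ... run the dragging / denoising loop here ...

editor.reset()  # clear the step/layer counters before the next edit
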
/FreeDrag--diffusion--version/utils/freeu_utils.py:
--------------------------------------------------------------------------------
1 | # *************************************************************************
2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4 | # ytedance Inc..
5 | # *************************************************************************
6 |
7 | import torch
8 | import torch.fft as fft
9 | from diffusers.models.unet_2d_condition import logger
10 | from diffusers.utils import is_torch_version
11 | from typing import Any, Dict, List, Optional, Tuple, Union
12 |
13 |
14 | def isinstance_str(x: object, cls_name: str):
15 | """
16 | Checks whether x has any class *named* cls_name in its ancestry.
17 | Doesn't require access to the class's implementation.
18 |
19 | Useful for patching!
20 | """
21 |
22 | for _cls in x.__class__.__mro__:
23 | if _cls.__name__ == cls_name:
24 | return True
25 |
26 | return False
27 |
28 |
29 | def Fourier_filter(x, threshold, scale):
30 | dtype = x.dtype
31 | x = x.type(torch.float32)
32 | # FFT
33 | x_freq = fft.fftn(x, dim=(-2, -1))
34 | x_freq = fft.fftshift(x_freq, dim=(-2, -1))
35 |
36 | B, C, H, W = x_freq.shape
37 | mask = torch.ones((B, C, H, W), device=x_freq.device)  # keep the mask on the same device as the input
38 |
39 | crow, ccol = H // 2, W //2
40 | mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
41 | x_freq = x_freq * mask
42 |
43 | # IFFT
44 | x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
45 | x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
46 |
47 | x_filtered = x_filtered.type(dtype)
48 | return x_filtered
49 |
50 |
51 | def register_upblock2d(model):
52 | def up_forward(self):
53 | def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
54 | for resnet in self.resnets:
55 | # pop res hidden states
56 | res_hidden_states = res_hidden_states_tuple[-1]
57 | res_hidden_states_tuple = res_hidden_states_tuple[:-1]
58 | #print(f"in upblock2d, hidden states shape: {hidden_states.shape}")
59 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
60 |
61 | if self.training and self.gradient_checkpointing:
62 |
63 | def create_custom_forward(module):
64 | def custom_forward(*inputs):
65 | return module(*inputs)
66 |
67 | return custom_forward
68 |
69 | if is_torch_version(">=", "1.11.0"):
70 | hidden_states = torch.utils.checkpoint.checkpoint(
71 | create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
72 | )
73 | else:
74 | hidden_states = torch.utils.checkpoint.checkpoint(
75 | create_custom_forward(resnet), hidden_states, temb
76 | )
77 | else:
78 | hidden_states = resnet(hidden_states, temb)
79 |
80 | if self.upsamplers is not None:
81 | for upsampler in self.upsamplers:
82 | hidden_states = upsampler(hidden_states, upsample_size)
83 |
84 | return hidden_states
85 |
86 | return forward
87 |
88 | for i, upsample_block in enumerate(model.unet.up_blocks):
89 | if isinstance_str(upsample_block, "UpBlock2D"):
90 | upsample_block.forward = up_forward(upsample_block)
91 |
92 |
93 | def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
94 | def up_forward(self):
95 | def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
96 | for resnet in self.resnets:
97 | # pop res hidden states
98 | res_hidden_states = res_hidden_states_tuple[-1]
99 | res_hidden_states_tuple = res_hidden_states_tuple[:-1]
100 | #print(f"in free upblock2d, hidden states shape: {hidden_states.shape}")
101 |
102 | # --------------- FreeU code -----------------------
103 | # Only operate on the first two stages
104 | if hidden_states.shape[1] == 1280:
105 | hidden_states[:,:640] = hidden_states[:,:640] * self.b1
106 | res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
107 | if hidden_states.shape[1] == 640:
108 | hidden_states[:,:320] = hidden_states[:,:320] * self.b2
109 | res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
110 | # ---------------------------------------------------------
111 |
112 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
113 |
114 | if self.training and self.gradient_checkpointing:
115 |
116 | def create_custom_forward(module):
117 | def custom_forward(*inputs):
118 | return module(*inputs)
119 |
120 | return custom_forward
121 |
122 | if is_torch_version(">=", "1.11.0"):
123 | hidden_states = torch.utils.checkpoint.checkpoint(
124 | create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
125 | )
126 | else:
127 | hidden_states = torch.utils.checkpoint.checkpoint(
128 | create_custom_forward(resnet), hidden_states, temb
129 | )
130 | else:
131 | hidden_states = resnet(hidden_states, temb)
132 |
133 | if self.upsamplers is not None:
134 | for upsampler in self.upsamplers:
135 | hidden_states = upsampler(hidden_states, upsample_size)
136 |
137 | return hidden_states
138 |
139 | return forward
140 |
141 | for i, upsample_block in enumerate(model.unet.up_blocks):
142 | if isinstance_str(upsample_block, "UpBlock2D"):
143 | upsample_block.forward = up_forward(upsample_block)
144 | setattr(upsample_block, 'b1', b1)
145 | setattr(upsample_block, 'b2', b2)
146 | setattr(upsample_block, 's1', s1)
147 | setattr(upsample_block, 's2', s2)
148 |
149 |
150 | def register_crossattn_upblock2d(model):
151 | def up_forward(self):
152 | def forward(
153 | hidden_states: torch.FloatTensor,
154 | res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
155 | temb: Optional[torch.FloatTensor] = None,
156 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
157 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
158 | upsample_size: Optional[int] = None,
159 | attention_mask: Optional[torch.FloatTensor] = None,
160 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
161 | ):
162 | for resnet, attn in zip(self.resnets, self.attentions):
163 | # pop res hidden states
164 | #print(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}")
165 | res_hidden_states = res_hidden_states_tuple[-1]
166 | res_hidden_states_tuple = res_hidden_states_tuple[:-1]
167 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
168 |
169 | if self.training and self.gradient_checkpointing:
170 |
171 | def create_custom_forward(module, return_dict=None):
172 | def custom_forward(*inputs):
173 | if return_dict is not None:
174 | return module(*inputs, return_dict=return_dict)
175 | else:
176 | return module(*inputs)
177 |
178 | return custom_forward
179 |
180 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
181 | hidden_states = torch.utils.checkpoint.checkpoint(
182 | create_custom_forward(resnet),
183 | hidden_states,
184 | temb,
185 | **ckpt_kwargs,
186 | )
187 | hidden_states = torch.utils.checkpoint.checkpoint(
188 | create_custom_forward(attn, return_dict=False),
189 | hidden_states,
190 | encoder_hidden_states,
191 | None, # timestep
192 | None, # class_labels
193 | cross_attention_kwargs,
194 | attention_mask,
195 | encoder_attention_mask,
196 | **ckpt_kwargs,
197 | )[0]
198 | else:
199 | hidden_states = resnet(hidden_states, temb)
200 | hidden_states = attn(
201 | hidden_states,
202 | encoder_hidden_states=encoder_hidden_states,
203 | cross_attention_kwargs=cross_attention_kwargs,
204 | attention_mask=attention_mask,
205 | encoder_attention_mask=encoder_attention_mask,
206 | return_dict=False,
207 | )[0]
208 |
209 | if self.upsamplers is not None:
210 | for upsampler in self.upsamplers:
211 | hidden_states = upsampler(hidden_states, upsample_size)
212 |
213 | return hidden_states
214 |
215 | return forward
216 |
217 | for i, upsample_block in enumerate(model.unet.up_blocks):
218 | if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
219 | upsample_block.forward = up_forward(upsample_block)
220 |
221 |
222 | def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
223 | def up_forward(self):
224 | def forward(
225 | hidden_states: torch.FloatTensor,
226 | res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
227 | temb: Optional[torch.FloatTensor] = None,
228 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
229 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
230 | upsample_size: Optional[int] = None,
231 | attention_mask: Optional[torch.FloatTensor] = None,
232 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
233 | ):
234 | for resnet, attn in zip(self.resnets, self.attentions):
235 | # pop res hidden states
236 | #print(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}")
237 | res_hidden_states = res_hidden_states_tuple[-1]
238 | res_hidden_states_tuple = res_hidden_states_tuple[:-1]
239 |
240 | # --------------- FreeU code -----------------------
241 | # Only operate on the first two stages
242 | if hidden_states.shape[1] == 1280:
243 | hidden_states[:,:640] = hidden_states[:,:640] * self.b1
244 | res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
245 | if hidden_states.shape[1] == 640:
246 | hidden_states[:,:320] = hidden_states[:,:320] * self.b2
247 | res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
248 | # ---------------------------------------------------------
249 |
250 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
251 |
252 | if self.training and self.gradient_checkpointing:
253 |
254 | def create_custom_forward(module, return_dict=None):
255 | def custom_forward(*inputs):
256 | if return_dict is not None:
257 | return module(*inputs, return_dict=return_dict)
258 | else:
259 | return module(*inputs)
260 |
261 | return custom_forward
262 |
263 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
264 | hidden_states = torch.utils.checkpoint.checkpoint(
265 | create_custom_forward(resnet),
266 | hidden_states,
267 | temb,
268 | **ckpt_kwargs,
269 | )
270 | hidden_states = torch.utils.checkpoint.checkpoint(
271 | create_custom_forward(attn, return_dict=False),
272 | hidden_states,
273 | encoder_hidden_states,
274 | None, # timestep
275 | None, # class_labels
276 | cross_attention_kwargs,
277 | attention_mask,
278 | encoder_attention_mask,
279 | **ckpt_kwargs,
280 | )[0]
281 | else:
282 | hidden_states = resnet(hidden_states, temb)
283 | # hidden_states = attn(
284 | # hidden_states,
285 | # encoder_hidden_states=encoder_hidden_states,
286 | # cross_attention_kwargs=cross_attention_kwargs,
287 | # encoder_attention_mask=encoder_attention_mask,
288 | # return_dict=False,
289 | # )[0]
290 | hidden_states = attn(
291 | hidden_states,
292 | encoder_hidden_states=encoder_hidden_states,
293 | cross_attention_kwargs=cross_attention_kwargs,
294 | )[0]
295 |
296 | if self.upsamplers is not None:
297 | for upsampler in self.upsamplers:
298 | hidden_states = upsampler(hidden_states, upsample_size)
299 |
300 | return hidden_states
301 |
302 | return forward
303 |
304 | for i, upsample_block in enumerate(model.unet.up_blocks):
305 | if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
306 | upsample_block.forward = up_forward(upsample_block)
307 | setattr(upsample_block, 'b1', b1)
308 | setattr(upsample_block, 'b2', b2)
309 | setattr(upsample_block, 's1', s1)
310 | setattr(upsample_block, 's2', s2)
311 |
--------------------------------------------------------------------------------
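
register_free_upblock2d and register_free_crossattn_upblock2d patch the UNet's up-blocks in place: the first half of the backbone channels is scaled by b1/b2 and the skip connections are Fourier-filtered with s1/s2, exactly as in the FreeU blocks above. Below is a minimal sketch of enabling FreeU on a pipeline, again assuming a diffusers pipeline object exposing .unet and using the default slider values from drag_ui.py (b1=b2=1.1, s1=s2=0.8).

from diffusers import StableDiffusionPipeline

from utils.freeu_utils import register_free_upblock2d, register_free_crossattn_upblock2d

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# patch both kinds of up-blocks in place; the factors are stored on each block as b1/b2/s1/s2
register_free_upblock2d(pipe, b1=1.1, b2=1.1, s1=0.8, s2=0.8)
register_free_crossattn_upblock2d(pipe, b1=1.1, b2=1.1, s1=0.8, s2=0.8)

# register_upblock2d(pipe) / register_crossattn_upblock2d(pipe) restore the stock forwards
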
/FreeDrag--diffusion--version/utils/lora_utils.py:
--------------------------------------------------------------------------------
1 | # *************************************************************************
2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4 | # ytedance Inc..
5 | # *************************************************************************
6 |
7 | from PIL import Image
8 | import os
9 | import numpy as np
10 | from einops import rearrange
11 | import torch
12 | import torch.nn.functional as F
13 | from torchvision import transforms
14 | from accelerate import Accelerator
15 | from accelerate.utils import set_seed
16 | from PIL import Image
17 |
18 | from transformers import AutoTokenizer, PretrainedConfig
19 |
20 | import diffusers
21 | from diffusers import (
22 | AutoencoderKL,
23 | DDPMScheduler,
24 | DiffusionPipeline,
25 | DPMSolverMultistepScheduler,
26 | StableDiffusionPipeline,
27 | UNet2DConditionModel,
28 | )
29 | from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
30 | from diffusers.models.attention_processor import (
31 | AttnAddedKVProcessor,
32 | AttnAddedKVProcessor2_0,
33 | LoRAAttnAddedKVProcessor,
34 | LoRAAttnProcessor,
35 | LoRAAttnProcessor2_0,
36 | SlicedAttnAddedKVProcessor,
37 | )
38 | from diffusers.optimization import get_scheduler
39 | from diffusers.utils import check_min_version
40 | from diffusers.utils.import_utils import is_xformers_available
41 |
42 | # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
43 | check_min_version("0.17.0")
44 |
45 |
46 | def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
47 | text_encoder_config = PretrainedConfig.from_pretrained(
48 | pretrained_model_name_or_path,
49 | subfolder="text_encoder",
50 | revision=revision,
51 | )
52 | model_class = text_encoder_config.architectures[0]
53 |
54 | if model_class == "CLIPTextModel":
55 | from transformers import CLIPTextModel
56 |
57 | return CLIPTextModel
58 | elif model_class == "RobertaSeriesModelWithTransformation":
59 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
60 |
61 | return RobertaSeriesModelWithTransformation
62 | elif model_class == "T5EncoderModel":
63 | from transformers import T5EncoderModel
64 |
65 | return T5EncoderModel
66 | else:
67 | raise ValueError(f"{model_class} is not supported.")
68 |
69 | def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
70 | if tokenizer_max_length is not None:
71 | max_length = tokenizer_max_length
72 | else:
73 | max_length = tokenizer.model_max_length
74 |
75 | text_inputs = tokenizer(
76 | prompt,
77 | truncation=True,
78 | padding="max_length",
79 | max_length=max_length,
80 | return_tensors="pt",
81 | )
82 |
83 | return text_inputs
84 |
85 | def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=False):
86 | text_input_ids = input_ids.to(text_encoder.device)
87 |
88 | if text_encoder_use_attention_mask:
89 | attention_mask = attention_mask.to(text_encoder.device)
90 | else:
91 | attention_mask = None
92 |
93 | prompt_embeds = text_encoder(
94 | text_input_ids,
95 | attention_mask=attention_mask,
96 | )
97 | prompt_embeds = prompt_embeds[0]
98 |
99 | return prompt_embeds
100 |
101 | # model_path: path of the base diffusion model
102 | # image: input image (not yet pre-processed)
103 | # save_lora_path: the path to save the LoRA weights
104 | # prompt: the user-input prompt
105 | # lora_step: number of LoRA training steps
106 | # lora_lr: learning rate of LoRA training
107 | # lora_rank: the rank of the LoRA layers
108 | # save_interval: how often (in steps) to save intermediate LoRA checkpoints
109 | def train_lora(image,
110 | prompt,
111 | model_path,
112 | vae_path,
113 | save_lora_path,
114 | lora_step,
115 | lora_lr,
116 | lora_batch_size,
117 | lora_rank,
118 | progress,
119 | save_interval=-1):
120 | # initialize accelerator
121 | accelerator = Accelerator(
122 | gradient_accumulation_steps=1,
123 | mixed_precision='fp16'
124 | )
125 | set_seed(0)
126 |
127 | # Load the tokenizer
128 | tokenizer = AutoTokenizer.from_pretrained(
129 | model_path,
130 | subfolder="tokenizer",
131 | revision=None,
132 | use_fast=False,
133 | )
134 | # initialize the model
135 | noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
136 | text_encoder_cls = import_model_class_from_model_name_or_path(model_path, revision=None)
137 | text_encoder = text_encoder_cls.from_pretrained(
138 | model_path, subfolder="text_encoder", revision=None
139 | )
140 | if vae_path == "default":
141 | vae = AutoencoderKL.from_pretrained(
142 | model_path, subfolder="vae", revision=None
143 | )
144 | else:
145 | vae = AutoencoderKL.from_pretrained(vae_path)
146 | unet = UNet2DConditionModel.from_pretrained(
147 | model_path, subfolder="unet", revision=None
148 | )
149 |
150 | # set device and dtype
151 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
152 |
153 | vae.requires_grad_(False)
154 | text_encoder.requires_grad_(False)
155 | unet.requires_grad_(False)
156 |
157 | unet.to(device, dtype=torch.float16)
158 | vae.to(device, dtype=torch.float16)
159 | text_encoder.to(device, dtype=torch.float16)
160 |
161 | # initialize UNet LoRA
162 | unet_lora_attn_procs = {}
163 | for name, attn_processor in unet.attn_processors.items():
164 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
165 | if name.startswith("mid_block"):
166 | hidden_size = unet.config.block_out_channels[-1]
167 | elif name.startswith("up_blocks"):
168 | block_id = int(name[len("up_blocks.")])
169 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
170 | elif name.startswith("down_blocks"):
171 | block_id = int(name[len("down_blocks.")])
172 | hidden_size = unet.config.block_out_channels[block_id]
173 | else:
174 | raise NotImplementedError("name must start with up_blocks, mid_blocks, or down_blocks")
175 |
176 | if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
177 | lora_attn_processor_class = LoRAAttnAddedKVProcessor
178 | else:
179 | lora_attn_processor_class = (
180 | LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
181 | )
182 | unet_lora_attn_procs[name] = lora_attn_processor_class(
183 | hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
184 | )
185 |
186 | unet.set_attn_processor(unet_lora_attn_procs)
187 | unet_lora_layers = AttnProcsLayers(unet.attn_processors)
188 |
189 | # Optimizer creation
190 | params_to_optimize = (unet_lora_layers.parameters())
191 | optimizer = torch.optim.AdamW(
192 | params_to_optimize,
193 | lr=lora_lr,
194 | betas=(0.9, 0.999),
195 | weight_decay=1e-2,
196 | eps=1e-08,
197 | )
198 |
199 | lr_scheduler = get_scheduler(
200 | "constant",
201 | optimizer=optimizer,
202 | num_warmup_steps=0,
203 | num_training_steps=lora_step,
204 | num_cycles=1,
205 | power=1.0,
206 | )
207 |
208 | # prepare accelerator
209 | unet_lora_layers = accelerator.prepare_model(unet_lora_layers)
210 | optimizer = accelerator.prepare_optimizer(optimizer)
211 | lr_scheduler = accelerator.prepare_scheduler(lr_scheduler)
212 |
213 | # initialize text embeddings
214 | with torch.no_grad():
215 | text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None)
216 | text_embedding = encode_prompt(
217 | text_encoder,
218 | text_inputs.input_ids,
219 | text_inputs.attention_mask,
220 | text_encoder_use_attention_mask=False
221 | )
222 | text_embedding = text_embedding.repeat(lora_batch_size, 1, 1)
223 |
224 | # initialize latent distribution
225 | image_transforms = transforms.Compose(
226 | [
227 | transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
228 | transforms.RandomCrop(512),
229 | transforms.ToTensor(),
230 | transforms.Normalize([0.5], [0.5]),
231 | ]
232 | )
233 |
234 | for step in progress.tqdm(range(lora_step), desc="training LoRA"):
235 | unet.train()
236 | image_batch = []
237 | for _ in range(lora_batch_size):
238 | image_transformed = image_transforms(Image.fromarray(image)).to(device, dtype=torch.float16)
239 | image_transformed = image_transformed.unsqueeze(dim=0)
240 | image_batch.append(image_transformed)
241 |
242 | # repeat the image_transformed to enable multi-batch training
243 | image_batch = torch.cat(image_batch, dim=0)
244 |
245 | latents_dist = vae.encode(image_batch).latent_dist
246 | model_input = latents_dist.sample() * vae.config.scaling_factor
247 | # Sample noise that we'll add to the latents
248 | noise = torch.randn_like(model_input)
249 | bsz, channels, height, width = model_input.shape
250 | # Sample a random timestep for each image
251 | timesteps = torch.randint(
252 | 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
253 | )
254 | timesteps = timesteps.long()
255 |
256 | # Add noise to the model input according to the noise magnitude at each timestep
257 | # (this is the forward diffusion process)
258 | noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
259 |
260 | # Predict the noise residual
261 | model_pred = unet(noisy_model_input, timesteps, text_embedding).sample
262 |
263 | # Get the target for loss depending on the prediction type
264 | if noise_scheduler.config.prediction_type == "epsilon":
265 | target = noise
266 | elif noise_scheduler.config.prediction_type == "v_prediction":
267 | target = noise_scheduler.get_velocity(model_input, noise, timesteps)
268 | else:
269 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
270 |
271 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
272 | accelerator.backward(loss)
273 | optimizer.step()
274 | lr_scheduler.step()
275 | optimizer.zero_grad()
276 |
277 | if save_interval > 0 and (step + 1) % save_interval == 0:
278 | save_lora_path_intermediate = os.path.join(save_lora_path, str(step+1))
279 | if not os.path.isdir(save_lora_path_intermediate):
280 | os.mkdir(save_lora_path_intermediate)
281 | # unet = unet.to(torch.float32)
282 | # unwrap_model is used to remove all special modules added when doing distributed training
283 | # so here, there is no need to call unwrap_model
284 | # unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
285 | LoraLoaderMixin.save_lora_weights(
286 | save_directory=save_lora_path_intermediate,
287 | unet_lora_layers=unet_lora_layers,
288 | text_encoder_lora_layers=None,
289 | )
290 | # unet = unet.to(torch.float16)
291 |
292 | # save the trained lora
293 | # unet = unet.to(torch.float32)
294 | # unwrap_model is used to remove all special modules added when doing distributed training
295 | # so here, there is no need to call unwrap_model
296 | # unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
297 | LoraLoaderMixin.save_lora_weights(
298 | save_directory=save_lora_path,
299 | unet_lora_layers=unet_lora_layers,
300 | text_encoder_lora_layers=None,
301 | )
302 |
303 | return
304 |
--------------------------------------------------------------------------------
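
train_lora above fine-tunes only the UNet's LoRA attention processors on a single image and writes pytorch_lora_weights.bin into save_lora_path (the ./lora_tmp default in drag_ui.py). Its progress argument only needs a gr.Progress-style .tqdm wrapper. Below is a minimal command-line sketch, assuming it is run from the FreeDrag--diffusion--version directory and that ./example.png is a placeholder path for any RGB image; the hyperparameters mirror the LoRA Parameters tab defaults.

import numpy as np
from PIL import Image
from tqdm import tqdm

from utils.lora_utils import train_lora

class _Progress:
    """Stand-in for the gr.Progress object the UI normally passes in;
    train_lora only calls progress.tqdm(range(lora_step), desc=...)."""
    def tqdm(self, iterable, desc=None):
        return tqdm(iterable, desc=desc)

image = np.array(Image.open("./example.png").convert("RGB"))  # HxWx3 uint8 array

train_lora(
    image=image,
    prompt="a photo of a sculpture",
    model_path="runwayml/stable-diffusion-v1-5",
    vae_path="default",        # "default" makes train_lora reuse the base model's VAE
    save_lora_path="./lora_tmp",
    lora_step=200,
    lora_lr=2e-4,
    lora_batch_size=4,
    lora_rank=16,
    progress=_Progress(),
    save_interval=-1,          # <= 0: only save the final weights
)
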
/FreeDrag_gradio.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | import torch
3 | import numpy as np
4 | from functions import to_image, draw_handle_target_points, free_drag, image_inversion, add_watermark_np
5 | import dnnlib
6 | from training import networks
7 | import legacy
8 | import cv2
9 |
10 | # export CUDA_LAUNCH_BLOCKING=1
11 | def load_model(model_name, device):
12 |
13 | path = './checkpoints/' + str(model_name)
14 | with dnnlib.util.open_url(path) as f:
15 | G = legacy.load_network_pkl(f)['G_ema'].to(device)
16 | G_copy = networks.Generator(z_dim=G.z_dim, c_dim= G.c_dim, w_dim =G.w_dim,
17 | img_resolution = G.img_resolution,
18 | img_channels = G.img_channels,
19 | mapping_kwargs = G.init_kwargs['mapping_kwargs'])
20 |
21 | G_copy.load_state_dict(G.state_dict())
22 | G_copy.to(device)
23 | del(G)
24 | for param in G_copy.parameters():
25 | param.requires_grad = False
26 | return G_copy, model_name
27 |
28 | def draw_mask(image,mask):
29 |
30 | image_mask = image*(1-mask) +mask*(0.7*image+0.3*255.0)
31 |
32 | return image_mask
33 |
34 |
35 | class ModelWrapper:
36 | def __init__(self, model,model_name):
37 | self.g = model
38 | self.name = model_name
39 | self.res = CKPT_SIZE[model_name][0]
40 | self.l = CKPT_SIZE[model_name][1]
41 | self.d = CKPT_SIZE[model_name][2]
42 |
43 |
44 | # model, points, mask, feature_size, train_layer_index,max_step, device,seed=2023,max_distance=3, d=0.5
45 | # img_show, current_target, step_number
46 | def on_drag(model, points, mask, max_iters,latent,sample_interval,l_expected,d_max,save_video,add_watermark):
47 |
48 | if len(points['handle']) == 0:
49 | raise gr.Error('You must select at least one handle point and target point.')
50 | if len(points['handle']) != len(points['target']):
51 | raise gr.Error('You have an incomplete handle point; select a target point for it or undo the handle point.')
52 | max_iters = int(max_iters)
53 |
54 | handle_size = 128
55 | train_layer_index=6
56 | l_expected = torch.tensor(l_expected,device=latent.device)
57 | d_max = torch.tensor(d_max,device=latent.device)
58 | mask[mask>0] = 1
59 | global stop_flag
60 | stop_flag = False
61 | images_total = []
62 | for img_show, current_target, step_number,full_res, latent_optimized in free_drag(model.g,points,mask[:,:,0],handle_size, \
63 | train_layer_index,latent,max_iters,l_expected,d_max,sample_interval,device=latent.device):
64 | image = to_image(img_show)
65 |
66 | points['handle'] = [current_target[p,:].cpu().numpy().astype('int') for p in range(len(current_target[:,0]))]
67 | image_show = add_points_to_image(image, points, size=RES_TO_CLICK_SIZE[full_res],color="yellow")
68 |
69 | if np.any(mask[:,:,0]>0):
70 | image_show = draw_mask(image_show,mask)
71 | image_show = np.uint8(image_show)
72 |
73 | if add_watermark:
74 | image_show = add_watermark_np(np.array(image_show))[:,:,[0,1,2]]
75 | image_clear = add_watermark_np(np.array(image))[:,:,[0,1,2]]
76 | else:
77 | image_clear = image
78 |
79 | if save_video:
80 | images_total.append(image_show)
81 | yield (image_show, step_number, latent_optimized,image_clear,images_total,gr.Button.update(interactive=True))
82 |
83 | if stop_flag:
84 | break
85 |
86 | def add_points_to_image(image, points, size=5,color="red"):
87 | image = draw_handle_target_points(image, points['handle'], points['target'], size, color)
88 | return image
89 |
90 | def on_show_save():
91 | return gr.update(visible=True)
92 |
93 | def on_click(image, target_point, points, res, evt: gr.SelectData):
94 | if target_point:
95 | points['target'].append([evt.index[1], evt.index[0]])
96 | image = add_points_to_image(image, points, size=RES_TO_CLICK_SIZE[res])
97 | return image, not target_point
98 | points['handle'].append([evt.index[1], evt.index[0]])
99 | image = add_points_to_image(image, points, size=RES_TO_CLICK_SIZE[res])
100 | return image, not target_point
101 |
102 | def new_image(model,seed=-1):
103 | if seed == -1:
104 | seed = np.random.randint(1,1e6)
105 | z1 = torch.from_numpy(np.random.RandomState(int(seed)).randn(1, model.g.z_dim)).to(device)
106 | label = torch.zeros([1, model.g.c_dim], device=device)
107 | ws_original= model.g.get_ws(z1,label,truncation_psi=0.7)
108 | _, img_show_original = model.g.synthesis(ws=ws_original,noise_mode='const')
109 |
110 | return to_image(img_show_original), to_image(img_show_original), ws_original, seed
111 |
112 | def new_model(model_name):
113 | model_load, _ = load_model(model_name, device)
114 | model = ModelWrapper(model_load,model_name)
115 |
116 | return model, model.res, model.l, model.d
117 |
118 | def reset_all(image,mask,add_watermark=0):
119 | points = {'target': [], 'handle': []}
120 | target_point = False
121 | mask = np.zeros_like(mask,dtype=np.uint8)
122 |
123 | return points, target_point, image, None,mask, add_watermark
124 |
125 | def add_mask(image_show,mask):
126 | image_show = draw_mask(image_show,mask)
127 | return image_show
128 |
129 | def update_mask(image,mask_show):
130 | mask = np.zeros_like(image)
131 |     if mask_show is not None and np.any(mask_show['mask'][:,:,0]>1):
132 | mask[mask_show['mask'][:,:,:3]>0] =1
133 | image_mask = add_mask(image,mask)
134 | return np.uint8(image_mask), mask
135 | else:
136 | return image, mask
137 |
138 | def on_select_mask_tab(image):
139 | return image
140 |
141 | def change_stop_state():
142 | global stop_flag
143 | stop_flag = True
144 |
145 | def save_video(imgs_show_list,frame):
146 | if len(imgs_show_list)>0:
147 | video_name = './process.mp4'
148 |         fourcc = cv2.VideoWriter_fourcc(*'mp4v')
149 |         full_res = imgs_show_list[0].shape[0]
150 |         video_output = cv2.VideoWriter(video_name, fourcc=fourcc, fps=frame, frameSize=(full_res, full_res))
151 | for k in range(len(imgs_show_list)):
152 | video_output.write(imgs_show_list[k][:,:,::-1])
153 | video_output.release()
154 | return []
155 |
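# Per-model settings used by ModelWrapper and the sliders below:
# [image resolution, expected initial loss per sub-motion, max feature-map distance per sub-motion]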
156 | CKPT_SIZE = {
157 | 'faces.pkl':[512, 0.3, 3],
158 | 'horses.pkl': [256, 0.3, 3],
159 | 'elephants.pkl': [512, 0.4, 4],
160 | 'lions.pkl':[512, 0.4, 4],
161 | 'dogs.pkl':[1024, 0.4, 4],
162 | 'bicycles.pkl':[256, 0.3, 3],
163 | 'giraffes.pkl':[512, 0.4, 4],
164 | 'cats.pkl':[512, 0.3, 3],
165 | 'cars.pkl':[512, 0.3, 3],
166 | 'churches.pkl':[256, 0.3, 3],
167 | 'metfaces.pkl':[1024, 0.3, 3],
168 | }
169 | RES_TO_CLICK_SIZE = {
170 | 1024: 10,
171 | 512: 5,
172 | 256: 3,
173 | }
174 |
175 | if torch.cuda.is_available():
176 | device = 'cuda'
177 | else:
178 | device = 'cpu'
179 |
180 | demo = gr.Blocks()
181 |
182 | with demo:
183 |
184 | points = gr.State({'target': [], 'handle': []})
185 | target_point = gr.State(False)
186 | state = gr.State({})
187 |
188 | gr.Markdown(
189 | """
190 | # **FreeDrag**
191 |
192 | Official implementation of [FreeDrag: Point Tracking is Not What You Need for Interactive Point-based Image Editing](https://github.com/LPengYang/FreeDrag)
193 |
194 |
195 | ## Parameter Description
196 | **max_step**: the maximum number of optimization steps.
197 |
198 | **sample_interval**: the interval between sampled optimization steps.
199 | This parameter only affects the visualization of intermediate results and has no impact on the final outcome.
200 | For high-resolution images (such as the dog model), a larger sample_interval can significantly accelerate the dragging process.
201 |
202 | **Expected initial loss and Max distance**: In the current version, both of these values are empirically set for each model.
203 | Generally, for precise editing needs (e.g., merging eyes), smaller values are recommended, which may cause longer processing times.
204 | Users can set these values according to their practical editing requirements. We are currently seeking an automated solution.
205 |
206 | **frame_rate**: the frame rate of the saved video.
207 |
208 | ## Hints
209 | - Handle points (Blue): the points you want to drag.
210 | - Target points (Red): the destinations you want to drag the handle points towards.
211 | - **Localized points (Yellow)**: the localized points in each sub-motion.
212 | """,
213 | )
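    # Note (illustrative, not authoritative): with the defaults below, up to max_step=2000
    # optimization steps are run and intermediate results are shown every sample_interval=5
    # steps; l_expected and d_max are initialised per model from CKPT_SIZE,
    # e.g. 'faces.pkl' -> (0.3, 3).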
214 |
215 | with gr.Row():
216 | with gr.Column(scale=0.4):
217 | with gr.Accordion("Model"):
218 | with gr.Row():
219 | with gr.Column(min_width=100):
220 | seed = gr.Number(label='Seed',value=0)
221 | with gr.Column(min_width=100):
222 | button_new = gr.Button('Image Generate', variant='primary')
223 | button_rand = gr.Button('Rand Generate')
224 | model_name = gr.Dropdown(label="Model name",choices=list(CKPT_SIZE.keys()),value = list(CKPT_SIZE.keys())[0])
225 |
226 | with gr.Accordion('Optional Parameters'):
227 | with gr.Row():
228 | with gr.Column(min_width=100):
229 | max_step = gr.Number(label='Max step',value=2000)
230 | with gr.Column(min_width=100):
231 | sample_interval = gr.Number(label='Interval',value=5,info="Sampling interval")
232 |
233 | model_load, _ = load_model(model_name.value, device)
234 | model = gr.State(ModelWrapper(model_load,model_name.value))
235 | l_expected = gr.Slider(0.1,0.5,label='Expected initial loss for each sub-motion',value = model.value.l,step=0.05)
236 | d_max= gr.Slider(1.0,6.0,label='Max distance for each sub-motion (in the feature map)',value = model.value.d,step=0.5)
237 |
238 | res = gr.State(model.value.res)
239 | z1 = torch.from_numpy(np.random.RandomState(int(seed.value)).randn(1, model.value.g.z_dim)).to(device)
240 | label = torch.zeros([1, model.value.g.c_dim], device=device)
241 | ws_original= model.value.g.get_ws(z1,label,truncation_psi=0.7)
242 | latent = gr.State(ws_original)
243 | add_watermark = gr.State(torch.zeros(1,device=device))
244 |
245 | _, img_show_original = model.value.g.synthesis(ws=ws_original,noise_mode='const')
246 |
247 | with gr.Accordion('Video'):
248 | images_total = gr.State([])
249 | with gr.Row():
250 | with gr.Column(min_width=100):
251 | if_save_video = gr.Radio(["True","False"],value="False",label="if save video")
252 | with gr.Column(min_width=100):
253 | frame_rate = gr.Number(label="Frame rate",value=5)
254 | with gr.Row():
255 | with gr.Column(min_width=100):
256 | button_video = gr.Button('Save video', variant='primary')
257 |
258 | with gr.Accordion('Drag'):
259 |
260 | with gr.Row():
261 | with gr.Column(min_width=200):
262 | reset_btn = gr.Button('Reset points and mask')
263 | with gr.Row():
264 | with gr.Column(min_width=100):
265 | button_drag = gr.Button('Drag it', variant='primary')
266 | with gr.Column(min_width=100):
267 | button_stop = gr.Button('Stop')
268 |
269 | progress = gr.Number(value=0, label='Steps', interactive=False)
270 |
271 | with gr.Column(scale=0.53):
272 | with gr.Tabs() as Tabs:
273 | image_show = to_image(img_show_original)
274 | image_clear = gr.State(image_show)
275 | mask = gr.State(np.zeros_like(image_clear.value))
276 | with gr.Tab('Setup Handle Points', id='input') as imagetab:
277 | image = gr.Image(image_show).style(height=768, width=768)
278 | with gr.Tab('Draw a Mask', id='mask') as masktab:
279 | mask_show = gr.ImageMask(image_show).style(height=768, width=768)
280 |
281 | image.select(on_click, [image, target_point, points, res], [image, target_point]).then(on_show_save)
282 |
283 | image.upload(image_inversion,[image,res,model],[latent,image,image_clear,add_watermark]).then(reset_all,
284 | inputs=[image_clear,mask,add_watermark],outputs=[points,target_point,image,mask_show,mask,add_watermark])
285 |
286 | button_drag.click(on_drag, inputs=[model, points, mask, max_step,latent,sample_interval,l_expected,d_max,if_save_video,add_watermark], \
287 | outputs=[image, progress, latent, image_clear, images_total, button_stop])
288 | button_stop.click(change_stop_state)
289 |
290 | button_video.click(save_video,inputs=[images_total,frame_rate],outputs=[images_total])
291 | reset_btn.click(reset_all,inputs=[image_clear,mask,add_watermark],outputs= [points,target_point,image,mask_show,mask,add_watermark]).then(on_show_save)
292 |
293 | button_new.click(new_image, inputs = [model,seed],outputs = [image, image_clear, latent,seed]).then(reset_all,
294 | inputs=[image_clear,mask],outputs=[points,target_point,image,mask_show,mask,add_watermark])
295 |
296 | button_rand.click(new_image, inputs = [model],outputs = [image, image_clear, latent,seed]).then(reset_all,
297 | inputs=[image_clear,mask],outputs=[points,target_point,image,mask_show,mask,add_watermark])
298 |
299 | model_name.change(new_model,inputs=[model_name],outputs=[model,res,l_expected,d_max]).then \
300 | (new_image, inputs = [model,seed],outputs = [image, image_clear, latent,seed]).then \
301 | (reset_all,inputs=[image_clear,mask],outputs=[points,target_point,image,mask_show,mask,add_watermark])
302 |
303 | imagetab.select(update_mask,[image,mask_show],[image,mask])
304 | masktab.select(on_select_mask_tab, inputs=[image], outputs=[mask_show])
305 |
306 |
307 | if __name__ == "__main__":
308 |
309 | demo.queue(concurrency_count=3,max_size=20).launch(share=True)
310 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # FreeDrag: Feature Dragging for Reliable Point-based Image Editing
6 |
7 |
8 |
9 |
10 |
11 | ## Visualization
12 | [Demo video](https://user-images.githubusercontent.com/58554846/253733958-c97629a0-5928-476b-99f2-79d5f92762e7.mp4)
13 |
14 |
15 | ## Web Demo (online dragging editing in 11 different StyleGAN2 models)
16 | [](https://openxlab.org.cn/apps/detail/LPengYang/FreeDrag)
17 |
18 | Official implementation of **FreeDrag: Feature Dragging for Reliable Point-based Image Editing**.
19 | - *Authors*: Pengyang Ling*, [Lin Chen*](https://lin-chen.site), [Pan Zhang](https://panzhang0212.github.io/), Huaian Chen, Yi Jin, Jinjin Zheng,
20 | - *Institutes*: University of Science and Technology of China; Shanghai AI Laboratory
21 | - [[Paper]](https://arxiv.org/abs/2307.04684) [[Project Page]](https://lin-chen.site/projects/freedrag) [[Web Demo]](https://openxlab.org.cn/apps/detail/LPengYang/FreeDrag)
22 |
23 | This repo proposes FreeDrag, a novel interactive point-based image editing framework free of the laborious and unstable point tracking process🔥🔥🔥.
24 |
25 |
26 | ## Abstract
27 | To serve the intricate and varied demands of image editing, precise and flexible manipulation of image content is indispensable. Recently, drag-based editing methods have achieved impressive performance. However, these methods predominantly center on point dragging, resulting in two noteworthy drawbacks, namely "miss tracking", where difficulties arise in accurately tracking the predetermined handle points, and "ambiguous tracking", where tracked points are potentially positioned in wrong regions that closely resemble the handle points. To address the above issues, we propose **FreeDrag**, a feature dragging methodology designed to free the burden on point tracking. **FreeDrag** incorporates two key designs, i.e., template feature via adaptive updating and line search with backtracking: the former improves stability against drastic content change by elaborately controlling the feature updating scale after each dragging, while the latter alleviates the misguidance from similar points by actively restricting the search area to a line. These two technologies together contribute to more stable semantic dragging with higher efficiency. Comprehensive experimental results substantiate that our approach significantly outperforms pre-existing methodologies, offering reliable point-based editing even in various complex scenarios.
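
The sketch below is a minimal, toy illustration of these two ideas (adaptive template updating and line search with backtracking) on random tensors; it is not the repository's implementation, and names such as `update_template` and `search_point_on_line` are purely illustrative.

```python
import torch

def update_template(template_feat, current_feat, quality):
    # Blend the stored template towards the current feature; `quality` in [0, 1]
    # controls how far the template may move, keeping it stable under drastic changes.
    return (1.0 - quality) * template_feat + quality * current_feat

def search_point_on_line(feat_map, template_feat, handle, target, max_distance):
    # Restrict candidates to the segment from `handle` towards `target`
    # (at most `max_distance` pixels) and keep the best feature match.
    direction = (target - handle).float()
    length = direction.norm().clamp(min=1e-6)
    step = direction / length
    n_steps = int(min(float(max_distance), length.item()))
    best_point, best_err = handle, float("inf")
    for k in range(n_steps + 1):
        cand = (handle.float() + k * step).round().long()
        err = (feat_map[:, cand[0], cand[1]] - template_feat).abs().mean().item()
        if err < best_err:
            best_point, best_err = cand, err
    return best_point, best_err

# Toy usage on a random (C, H, W) feature map.
feat_map = torch.randn(64, 128, 128)
handle, target = torch.tensor([40, 40]), torch.tensor([80, 90])
template = feat_map[:, handle[0], handle[1]].clone()
new_handle, err = search_point_on_line(feat_map, template, handle, target, max_distance=3)
template = update_template(template, feat_map[:, new_handle[0], new_handle[1]], quality=0.5)
print(new_handle.tolist(), round(err, 4))
```

In the repository, the corresponding logic operates on StyleGAN feature maps and is driven by the `free_drag` generator used in `FreeDrag_gradio.py`.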
28 |
29 |
30 |
31 |
32 |
33 |
34 | ## 📜 News
35 | [2024/03/06] [FreeDrag](https://arxiv.org/abs/2307.04684) is accepted by CVPR 2024.
36 |
37 | [2023/12/11] The updated [FreeDrag](https://arxiv.org/abs/2307.04684), containing implementations for both StyleGAN and diffusion models, is available now.
38 |
39 | [2023/12/8] The diffusion-based FreeDrag is available now, which supports drag editing on both real and generated images.
40 |
41 | [2023/7/31] The web demo (StyleGAN) in [OpenXLab](https://openxlab.org.cn/apps/detail/LPengYang/FreeDrag) is available now.
42 |
43 | [2023/7/28] The function of real image editing is available now.
44 |
45 | [2023/7/15] Code of local demo is available now!💥
46 |
47 | [2023/7/11] The [paper](https://arxiv.org/abs/2307.04684) and [project page](https://lin-chen.site/projects/freedrag) are released!
48 |
49 | ## 💡 Highlights
50 | - [x] Local demo of FreeDrag
51 | - [x] Web demo of FreeDrag
52 | - [x] Diffusion-based FreeDrag
53 |
54 | ## 🛠️Usage
55 |
56 | First clone our repository
57 | ```
58 | git clone --depth=1 https://github.com/LPengYang/FreeDrag
59 | ```
60 | To create a new environment, please follow the requirements of [NVlabs/stylegan2-ada](https://github.com/NVlabs/stylegan2-ada-pytorch#requirements).
61 |
62 | **Notice:** The errors "Setting up PyTorch plugin "bias_act_plugin"... Failed" or "Setting up PyTorch plugin "upfirdn2d_plugin"... Failed" may appear on some devices; these potential solutions ([1](https://blog.csdn.net/qq_15969343/article/details/129190607), [2](https://github.com/NVlabs/stylegan2-ada-pytorch/issues/155), [3](https://github.com/NVlabs/stylegan3/issues/124), [4](https://github.com/XingangPan/DragGAN/issues/106)) may be helpful in this case.
63 |
64 | Then install the additional requirements
65 |
66 | ```
67 | pip install -r requirements.txt
68 | ```
69 |
70 | Then download the pre-trained models of stylegan2
71 | ```
72 | bash download_models.sh
73 | ```
74 | **Notice:** The first model (the face model) may download very slowly in some cases. If so, it is suggested to restart the download (this sometimes works) or to download it directly from this [link](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/research/models/stylegan2/files). Please download the correct model (ffhq-512×512), rename it to "faces.pkl", and manually place it in the created checkpoints directory (after all the other models have been downloaded), e.g. as sketched below.
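
A minimal sketch of that last step, assuming the file was manually saved as `stylegan2-ffhq-512x512.pkl` in the current directory (the path and filename here are illustrative):

```python
import shutil

# Rename the manually downloaded checkpoint and move it into ./checkpoints
shutil.move("stylegan2-ffhq-512x512.pkl", "checkpoints/faces.pkl")
```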
75 |
76 | Finally initialize the gradio platform for interactive point-based manipulation
77 |
78 | ```
79 | CUDA_LAUNCH_BLOCKING=1 python FreeDrag_gradio.py
80 | ```
81 | You can also upload your own images and then edit them. For a high-quality image inversion, it is suggested to make sure that the resolution and style (such as layout) of the uploaded images are consistent with the images generated by the corresponding model; a resizing sketch follows the table. The resolution of each model is listed as follows:
82 |
83 | |Model|face|horse|elephant|lion|dog|bicycle|giraffe|cat|car|church|metface|
84 | |:----:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
85 | |Resolution|512|256|512|512|1024|256|512|512|512|256|1024|
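
If needed, the uploaded image can be resized to the corresponding resolution beforehand; a minimal sketch assuming Pillow is installed (filenames are illustrative):

```python
from PIL import Image

img = Image.open("my_photo.jpg").convert("RGB")
img = img.resize((512, 512), Image.LANCZOS)  # e.g. 512x512 for the face model
img.save("my_photo_512.png")
```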
86 |
87 | The proposed **FreeDragBench Dataset** is available on the [website](https://drive.google.com/file/d/1p2muR6aW6fqEGW8yTcHl86DCuUkwgtNY/view?usp=sharing).
88 |
89 | ## ❤️Acknowledgments
90 | - [DragGAN](https://github.com/XingangPan/DragGAN/)
91 | - [DragDiffusion](https://yujun-shi.github.io/projects/dragdiffusion.html)
92 | - [StyleGAN2](https://github.com/NVlabs/stylegan2-ada-pytorch)
93 |
94 | ## License
95 | All code used or modified from [StyleGAN2](https://github.com/NVlabs/stylegan2-ada-pytorch) is under the [Nvidia Source Code License](https://github.com/NVlabs/stylegan3/blob/main/LICENSE.txt).
96 | The code related to the FreeDrag algorithm is only allowed for personal use. The diffusion-based FreeDrag is implemented based on [DragDiffusion](https://yujun-shi.github.io/projects/dragdiffusion.html).
97 |
98 | ## ✒️ Citation
99 | If you find our work helpful for your research, please consider citing the following BibTeX entry.
100 | ```bibtex
101 | @inproceedings{ling2024freedrag,
102 | title={Freedrag: Feature dragging for reliable point-based image editing},
103 | author={Ling, Pengyang and Chen, Lin and Zhang, Pan and Chen, Huaian and Jin, Yi and Zheng, Jinjin},
104 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
105 | pages={6860--6870},
106 | year={2024}
107 | }
108 | ```
109 |
--------------------------------------------------------------------------------
/arial.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/arial.ttf
--------------------------------------------------------------------------------
/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | from .util import EasyDict, make_cache_dir_path
10 |
--------------------------------------------------------------------------------
/download_models.sh:
--------------------------------------------------------------------------------
1 | mkdir checkpoints
2 | cd checkpoints
3 | rm *
4 | wget 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-512x512.pkl'
5 | mv stylegan2-ffhq-512x512.pkl faces.pkl
6 | curl -o cats.pkl https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqcat.pkl
7 | curl -o lions.pkl https://storage.googleapis.com/self-distilled-stylegan/lions_512_pytorch.pkl
8 | curl -o dogs.pkl https://storage.googleapis.com/self-distilled-stylegan/dogs_1024_pytorch.pkl
9 | curl -o horses.pkl https://storage.googleapis.com/self-distilled-stylegan/horses_256_pytorch.pkl
10 | curl -o elephants.pkl https://storage.googleapis.com/self-distilled-stylegan/elephants_512_pytorch.pkl
11 | curl -o bicycles.pkl https://storage.googleapis.com/self-distilled-stylegan/bicycles_256_pytorch.pkl
12 | curl -o giraffes.pkl https://storage.googleapis.com/self-distilled-stylegan/giraffes_512_pytorch.pkl
13 | curl -o cars.pkl http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-f.pkl
14 | curl -o churches.pkl http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-church-config-f.pkl
15 | curl -o metfaces.pkl https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
16 |
--------------------------------------------------------------------------------
/legacy.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import click
10 | import pickle
11 | import re
12 | import copy
13 | import numpy as np
14 | import torch
15 | import dnnlib
16 | from torch_utils import misc
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | def load_network_pkl(f, force_fp16=False):
21 | data = _LegacyUnpickler(f).load()
22 |
23 | # Legacy TensorFlow pickle => convert.
24 | if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
25 | tf_G, tf_D, tf_Gs = data
26 | G = convert_tf_generator(tf_G)
27 | D = convert_tf_discriminator(tf_D)
28 | G_ema = convert_tf_generator(tf_Gs)
29 | data = dict(G=G, D=D, G_ema=G_ema)
30 |
31 | # Add missing fields.
32 | if 'training_set_kwargs' not in data:
33 | data['training_set_kwargs'] = None
34 | if 'augment_pipe' not in data:
35 | data['augment_pipe'] = None
36 |
37 | # Validate contents.
38 | assert isinstance(data['G'], torch.nn.Module)
39 | assert isinstance(data['D'], torch.nn.Module)
40 | assert isinstance(data['G_ema'], torch.nn.Module)
41 | assert isinstance(data['training_set_kwargs'], (dict, type(None)))
42 | assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
43 |
44 | # Force FP16.
45 | if force_fp16:
46 | for key in ['G', 'D', 'G_ema']:
47 | old = data[key]
48 | kwargs = copy.deepcopy(old.init_kwargs)
49 | if key.startswith('G'):
50 | kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {}))
51 | kwargs.synthesis_kwargs.num_fp16_res = 4
52 | kwargs.synthesis_kwargs.conv_clamp = 256
53 | if key.startswith('D'):
54 | kwargs.num_fp16_res = 4
55 | kwargs.conv_clamp = 256
56 | if kwargs != old.init_kwargs:
57 | new = type(old)(**kwargs).eval().requires_grad_(False)
58 | misc.copy_params_and_buffers(old, new, require_all=True)
59 | data[key] = new
60 | return data
61 |
62 | #----------------------------------------------------------------------------
63 |
64 | class _TFNetworkStub(dnnlib.EasyDict):
65 | pass
66 |
67 | class _LegacyUnpickler(pickle.Unpickler):
68 | def find_class(self, module, name):
69 | if module == 'dnnlib.tflib.network' and name == 'Network':
70 | return _TFNetworkStub
71 | return super().find_class(module, name)
72 |
73 | #----------------------------------------------------------------------------
74 |
75 | def _collect_tf_params(tf_net):
76 | # pylint: disable=protected-access
77 | tf_params = dict()
78 | def recurse(prefix, tf_net):
79 | for name, value in tf_net.variables:
80 | tf_params[prefix + name] = value
81 | for name, comp in tf_net.components.items():
82 | recurse(prefix + name + '/', comp)
83 | recurse('', tf_net)
84 | return tf_params
85 |
86 | #----------------------------------------------------------------------------
87 |
88 | def _populate_module_params(module, *patterns):
89 | for name, tensor in misc.named_params_and_buffers(module):
90 | found = False
91 | value = None
92 | for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
93 | match = re.fullmatch(pattern, name)
94 | if match:
95 | found = True
96 | if value_fn is not None:
97 | value = value_fn(*match.groups())
98 | break
99 | try:
100 | assert found
101 | if value is not None:
102 | tensor.copy_(torch.from_numpy(np.array(value)))
103 | except:
104 | print(name, list(tensor.shape))
105 | raise
106 |
107 | #----------------------------------------------------------------------------
108 |
109 | def convert_tf_generator(tf_G):
110 | if tf_G.version < 4:
111 | raise ValueError('TensorFlow pickle version too low')
112 |
113 | # Collect kwargs.
114 | tf_kwargs = tf_G.static_kwargs
115 | known_kwargs = set()
116 | def kwarg(tf_name, default=None, none=None):
117 | known_kwargs.add(tf_name)
118 | val = tf_kwargs.get(tf_name, default)
119 | return val if val is not None else none
120 |
121 | # Convert kwargs.
122 | kwargs = dnnlib.EasyDict(
123 | z_dim = kwarg('latent_size', 512),
124 | c_dim = kwarg('label_size', 0),
125 | w_dim = kwarg('dlatent_size', 512),
126 | img_resolution = kwarg('resolution', 1024),
127 | img_channels = kwarg('num_channels', 3),
128 | mapping_kwargs = dnnlib.EasyDict(
129 | num_layers = kwarg('mapping_layers', 8),
130 | embed_features = kwarg('label_fmaps', None),
131 | layer_features = kwarg('mapping_fmaps', None),
132 | activation = kwarg('mapping_nonlinearity', 'lrelu'),
133 | lr_multiplier = kwarg('mapping_lrmul', 0.01),
134 | w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
135 | ),
136 | synthesis_kwargs = dnnlib.EasyDict(
137 | channel_base = kwarg('fmap_base', 16384) * 2,
138 | channel_max = kwarg('fmap_max', 512),
139 | num_fp16_res = kwarg('num_fp16_res', 0),
140 | conv_clamp = kwarg('conv_clamp', None),
141 | architecture = kwarg('architecture', 'skip'),
142 | resample_filter = kwarg('resample_kernel', [1,3,3,1]),
143 | use_noise = kwarg('use_noise', True),
144 | activation = kwarg('nonlinearity', 'lrelu'),
145 | ),
146 | )
147 |
148 | # Check for unknown kwargs.
149 | kwarg('truncation_psi')
150 | kwarg('truncation_cutoff')
151 | kwarg('style_mixing_prob')
152 | kwarg('structure')
153 | unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
154 | if len(unknown_kwargs) > 0:
155 | raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
156 |
157 | # Collect params.
158 | tf_params = _collect_tf_params(tf_G)
159 | for name, value in list(tf_params.items()):
160 | match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
161 | if match:
162 | r = kwargs.img_resolution // (2 ** int(match.group(1)))
163 | tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
164 |             kwargs.synthesis_kwargs.architecture = 'orig'
165 | #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
166 |
167 | # Convert params.
168 | from training import networks
169 | G = networks.Generator(**kwargs).eval().requires_grad_(False)
170 | # pylint: disable=unnecessary-lambda
171 | _populate_module_params(G,
172 | r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
173 | r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
174 | r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
175 | r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
176 | r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
177 | r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
178 | r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
179 | r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
180 | r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
181 | r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
182 | r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
183 | r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
184 | r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
185 | r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
186 | r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
187 | r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
188 | r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
189 | r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
190 | r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
191 | r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
192 | r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
193 | r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
194 | r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
195 | r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
196 | r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
197 | r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
198 | r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
199 | r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
200 | r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
201 | r'.*\.resample_filter', None,
202 | )
203 | return G
204 |
205 | #----------------------------------------------------------------------------
206 |
207 | def convert_tf_discriminator(tf_D):
208 | if tf_D.version < 4:
209 | raise ValueError('TensorFlow pickle version too low')
210 |
211 | # Collect kwargs.
212 | tf_kwargs = tf_D.static_kwargs
213 | known_kwargs = set()
214 | def kwarg(tf_name, default=None):
215 | known_kwargs.add(tf_name)
216 | return tf_kwargs.get(tf_name, default)
217 |
218 | # Convert kwargs.
219 | kwargs = dnnlib.EasyDict(
220 | c_dim = kwarg('label_size', 0),
221 | img_resolution = kwarg('resolution', 1024),
222 | img_channels = kwarg('num_channels', 3),
223 | architecture = kwarg('architecture', 'resnet'),
224 | channel_base = kwarg('fmap_base', 16384) * 2,
225 | channel_max = kwarg('fmap_max', 512),
226 | num_fp16_res = kwarg('num_fp16_res', 0),
227 | conv_clamp = kwarg('conv_clamp', None),
228 | cmap_dim = kwarg('mapping_fmaps', None),
229 | block_kwargs = dnnlib.EasyDict(
230 | activation = kwarg('nonlinearity', 'lrelu'),
231 | resample_filter = kwarg('resample_kernel', [1,3,3,1]),
232 | freeze_layers = kwarg('freeze_layers', 0),
233 | ),
234 | mapping_kwargs = dnnlib.EasyDict(
235 | num_layers = kwarg('mapping_layers', 0),
236 | embed_features = kwarg('mapping_fmaps', None),
237 | layer_features = kwarg('mapping_fmaps', None),
238 | activation = kwarg('nonlinearity', 'lrelu'),
239 | lr_multiplier = kwarg('mapping_lrmul', 0.1),
240 | ),
241 | epilogue_kwargs = dnnlib.EasyDict(
242 | mbstd_group_size = kwarg('mbstd_group_size', None),
243 | mbstd_num_channels = kwarg('mbstd_num_features', 1),
244 | activation = kwarg('nonlinearity', 'lrelu'),
245 | ),
246 | )
247 |
248 | # Check for unknown kwargs.
249 | kwarg('structure')
250 | unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
251 | if len(unknown_kwargs) > 0:
252 | raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
253 |
254 | # Collect params.
255 | tf_params = _collect_tf_params(tf_D)
256 | for name, value in list(tf_params.items()):
257 | match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
258 | if match:
259 | r = kwargs.img_resolution // (2 ** int(match.group(1)))
260 | tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
261 | kwargs.architecture = 'orig'
262 | #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
263 |
264 | # Convert params.
265 | from training import networks
266 | D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
267 | # pylint: disable=unnecessary-lambda
268 | _populate_module_params(D,
269 | r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
270 | r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
271 | r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
272 | r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
273 | r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
274 | r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
275 | r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
276 | r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
277 | r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
278 | r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
279 | r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
280 | r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
281 | r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
282 | r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
283 | r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
284 | r'.*\.resample_filter', None,
285 | )
286 | return D
287 |
288 | #----------------------------------------------------------------------------
289 |
290 | @click.command()
291 | @click.option('--source', help='Input pickle', required=True, metavar='PATH')
292 | @click.option('--dest', help='Output pickle', required=True, metavar='PATH')
293 | @click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
294 | def convert_network_pickle(source, dest, force_fp16):
295 | """Convert legacy network pickle into the native PyTorch format.
296 |
297 | The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
298 | It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
299 |
300 | Example:
301 |
302 | \b
303 | python legacy.py \\
304 | --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
305 | --dest=stylegan2-cat-config-f.pkl
306 | """
307 | print(f'Loading "{source}"...')
308 | with dnnlib.util.open_url(source) as f:
309 | data = load_network_pkl(f, force_fp16=force_fp16)
310 | print(f'Saving "{dest}"...')
311 | with open(dest, 'wb') as f:
312 | pickle.dump(data, f)
313 | print('Done.')
314 |
315 | #----------------------------------------------------------------------------
316 |
317 | if __name__ == "__main__":
318 | convert_network_pickle() # pylint: disable=no-value-for-parameter
319 |
320 | #----------------------------------------------------------------------------
321 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gradio>=3.28.1
2 | Pydantic
3 | opencv-python
4 | scipy>=1.7.3
5 | ninja==1.11.1
6 | lpips
7 |
8 |
--------------------------------------------------------------------------------
/resources/Teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/resources/Teaser.png
--------------------------------------------------------------------------------
/resources/comparison_diffusion_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/resources/comparison_diffusion_1.png
--------------------------------------------------------------------------------
/resources/comparison_diffusion_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/resources/comparison_diffusion_2.png
--------------------------------------------------------------------------------
/resources/comparison_gan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/resources/comparison_gan.png
--------------------------------------------------------------------------------
/resources/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/resources/fig1.png
--------------------------------------------------------------------------------
/resources/logo2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LPengYang/FreeDrag/a35d9882408a0405a6137a0dec13a308f5e5d3f7/resources/logo2.png
--------------------------------------------------------------------------------
/resources/style.css:
--------------------------------------------------------------------------------
1 | body {
2 | margin: 0 auto;
3 | font-family: "HelveticaNeue-Light", "Helvetica Neue Light", "Helvetica Neue", Helvetica, Arial, "Lucida Grande", sans-serif;
4 | font-size: 18px;
5 | }
6 |
7 | .container {
8 | margin: 0 auto;
9 | width: 100%;
10 | max-width: 1100px;
11 | text-align: center;
12 | display: block;
13 | }
14 |
15 | .title {
16 | font-size: 36px;
17 | margin-top: 20px;
18 | width: 90%;
19 | }
20 |
21 | .venue {
22 | font-size: 22px;
23 | margin-top: 20px;
24 | width: 90%;
25 | }
26 |
27 | .author {
28 | width: 95%;
29 | max-width: 300px;
30 | font-size: 20px;
31 | }
32 |
33 | .affiliation {
34 | font-size: 20px;
35 | width: 95%;
36 | max-width: 450px;
37 | }
38 |
39 | .links {
40 | font-size: 22px;
41 | width: 95%;
42 | max-width: 150px;
43 | }
44 |
45 | .video-container {
46 | position: relative;
47 | overflow: hidden;
48 | width: 80%;
49 | padding-top: 45%;
50 | /* Formula for 16:9 Aspect Ratio: width * 9 / 16 */
51 | }
52 |
53 | .video-container iframe {
54 | position: absolute;
55 | top: 0;
56 | left: 0;
57 | bottom: 0;
58 | right: 0;
59 | width: 100%;
60 | height: 100%;
61 | }
62 |
63 | .paper-thumbnail {
64 | margin: 0 auto;
65 | width: 40%;
66 | max-width: 250px;
67 | display: inline-block;
68 | vertical-align: top;
69 | padding: 2% 10% 4% 0;
70 | }
71 |
72 | .paper-info {
73 | width: 45%;
74 | display: inline-block;
75 | vertical-align: top;
76 | }
77 |
78 | @media (max-width: 999px) {
79 | .paper-thumbnail {
80 | width: 60%;
81 | }
82 |
83 | .paper-info {
84 | width: 80%;
85 | }
86 | }
87 |
88 |
89 | p {
90 | text-align: left;
91 | margin: 0 auto;
92 | margin-bottom: 10px;
93 | }
94 |
95 | h1 {
96 | font-weight: 300;
97 | text-align: center;
98 | }
99 |
100 | h2 {
101 | text-align: center;
102 | }
103 |
104 | h3 {
105 | text-align: left;
106 | }
107 |
108 | h4 {
109 | text-align: left;
110 | }
111 |
112 | h5 {
113 | text-align: left;
114 | }
115 |
116 | div {
117 | display: inline-block;
118 | }
119 |
120 | hr {
121 | border: 0;
122 | height: 1px;
123 | background-image: linear-gradient(to right, rgba(0, 0, 0, 0), rgba(0, 0, 0, 0.75), rgba(0, 0, 0, 0));
124 | width: 90%;
125 | }
126 |
127 | pre {
128 | overflow-x: auto;
129 | text-align: left;
130 | border: 1px solid grey;
131 | border-radius: 3px;
132 | background: #eeeeee;
133 | padding: 5px 5px 5px 10px;
134 | line-height: 1.2;
135 | white-space: pre-wrap;
136 | }
137 |
138 | pre code {
139 | text-align: left;
140 | word-wrap: normal;
141 | white-space: pre-wrap;
142 | font-size: 14px;
143 | }
144 |
145 | a:link,
146 | a:visited {
147 | color: #1367a7;
148 | text-decoration: none;
149 | }
150 |
151 | a:hover {
152 | color: #208799;
153 | }
154 |
155 | .layered-paper-big {
156 | /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */
157 | box-shadow:
158 | 0px 0px 1px 1px rgba(0, 0, 0, 0.35),
159 | /* The top layer shadow */
160 | 5px 5px 0 0px #fff,
161 | /* The second layer */
162 | 5px 5px 1px 1px rgba(0, 0, 0, 0.35),
163 | /* The second layer shadow */
164 | 10px 10px 0 0px #fff,
165 | /* The third layer */
166 | 10px 10px 1px 1px rgba(0, 0, 0, 0.35),
167 | /* The third layer shadow */
168 | 15px 15px 0 0px #fff,
169 | /* The fourth layer */
170 | 15px 15px 1px 1px rgba(0, 0, 0, 0.35),
171 | /* The fourth layer shadow */
172 | 20px 20px 0 0px #fff,
173 | /* The fifth layer */
174 | 20px 20px 1px 1px rgba(0, 0, 0, 0.35),
175 | /* The fifth layer shadow */
176 | 25px 25px 0 0px #fff,
177 | /* The fifth layer */
178 | 25px 25px 1px 1px rgba(0, 0, 0, 0.35);
179 | /* The fifth layer shadow */
180 | margin-left: 10px;
181 | margin-right: 45px;
182 | }
--------------------------------------------------------------------------------
/torch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/torch_utils/custom_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import glob
11 | import torch
12 | import torch.utils.cpp_extension
13 | import importlib
14 | import hashlib
15 | import shutil
16 | from pathlib import Path
17 |
18 | from torch.utils.file_baton import FileBaton
19 |
20 | #----------------------------------------------------------------------------
21 | # Global options.
22 |
23 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
24 |
25 | #----------------------------------------------------------------------------
26 | # Internal helper funcs.
27 |
28 | def _find_compiler_bindir():
29 | patterns = [
30 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
33 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
34 | ]
35 | for pattern in patterns:
36 | matches = sorted(glob.glob(pattern))
37 | if len(matches):
38 | return matches[-1]
39 | return None
40 |
41 | #----------------------------------------------------------------------------
42 | # Main entry point for compiling and loading C++/CUDA plugins.
43 |
44 | _cached_plugins = dict()
45 |
46 | def get_plugin(module_name, sources, **build_kwargs):
47 | assert verbosity in ['none', 'brief', 'full']
48 |
49 | # Already cached?
50 | if module_name in _cached_plugins:
51 | return _cached_plugins[module_name]
52 |
53 | # Print status.
54 | if verbosity == 'full':
55 | print(f'Setting up PyTorch plugin "{module_name}"...')
56 | elif verbosity == 'brief':
57 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
58 |
59 | try: # pylint: disable=too-many-nested-blocks
60 | # Make sure we can find the necessary compiler binaries.
61 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
62 | compiler_bindir = _find_compiler_bindir()
63 | if compiler_bindir is None:
64 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
65 | os.environ['PATH'] += ';' + compiler_bindir
66 |
67 | # Compile and load.
68 | verbose_build = (verbosity == 'full')
69 |
70 | # Incremental build md5sum trickery. Copies all the input source files
71 | # into a cached build directory under a combined md5 digest of the input
72 | # source files. Copying is done only if the combined digest has changed.
73 | # This keeps input file timestamps and filenames the same as in previous
74 | # extension builds, allowing for fast incremental rebuilds.
75 | #
76 | # This optimization is done only in case all the source files reside in
77 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
78 | # environment variable is set (we take this as a signal that the user
79 | # actually cares about this.)
80 | source_dirs_set = set(os.path.dirname(source) for source in sources)
81 | if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
82 | all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
83 |
84 | # Compute a combined hash digest for all source files in the same
85 | # custom op directory (usually .cu, .cpp, .py and .h files).
86 | hash_md5 = hashlib.md5()
87 | for src in all_source_files:
88 | with open(src, 'rb') as f:
89 | hash_md5.update(f.read())
90 | build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
91 | digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
92 |
93 | if not os.path.isdir(digest_build_dir):
94 | os.makedirs(digest_build_dir, exist_ok=True)
95 | baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
96 | if baton.try_acquire():
97 | try:
98 | for src in all_source_files:
99 | shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
100 | finally:
101 | baton.release()
102 | else:
103 | # Someone else is copying source files under the digest dir,
104 | # wait until done and continue.
105 | baton.wait()
106 | digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
107 | torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
108 | verbose=verbose_build, sources=digest_sources, **build_kwargs)
109 | else:
110 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
111 | module = importlib.import_module(module_name)
112 |
113 | except:
114 | if verbosity == 'brief':
115 | print('Failed!')
116 | raise
117 |
118 | # Print status and add to cache.
119 | if verbosity == 'full':
120 | print(f'Done setting up PyTorch plugin "{module_name}".')
121 | elif verbosity == 'brief':
122 | print('Done.')
123 | _cached_plugins[module_name] = module
124 | return module
125 |
126 | #----------------------------------------------------------------------------
127 |
--------------------------------------------------------------------------------
/torch_utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import re
10 | import contextlib
11 | import numpy as np
12 | import torch
13 | import warnings
14 | import dnnlib
15 |
16 | #----------------------------------------------------------------------------
17 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
18 | # same constant is used multiple times.
19 |
20 | _constant_cache = dict()
21 |
22 | def constant(value, shape=None, dtype=None, device=None, memory_format=None):
23 | value = np.asarray(value)
24 | if shape is not None:
25 | shape = tuple(shape)
26 | if dtype is None:
27 | dtype = torch.get_default_dtype()
28 | if device is None:
29 | device = torch.device('cpu')
30 | if memory_format is None:
31 | memory_format = torch.contiguous_format
32 |
33 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
34 | tensor = _constant_cache.get(key, None)
35 | if tensor is None:
36 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
37 | if shape is not None:
38 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
39 | tensor = tensor.contiguous(memory_format=memory_format)
40 | _constant_cache[key] = tensor
41 | return tensor
42 |
43 | #----------------------------------------------------------------------------
44 | # Replace NaN/Inf with specified numerical values.
45 |
46 | try:
47 | nan_to_num = torch.nan_to_num # 1.8.0a0
48 | except AttributeError:
49 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
50 | assert isinstance(input, torch.Tensor)
51 | if posinf is None:
52 | posinf = torch.finfo(input.dtype).max
53 | if neginf is None:
54 | neginf = torch.finfo(input.dtype).min
55 | assert nan == 0
56 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
57 |
58 | #----------------------------------------------------------------------------
59 | # Symbolic assert.
60 |
61 | try:
62 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
63 | except AttributeError:
64 | symbolic_assert = torch.Assert # 1.7.0
65 |
66 | #----------------------------------------------------------------------------
67 | # Context manager to suppress known warnings in torch.jit.trace().
68 |
69 | class suppress_tracer_warnings(warnings.catch_warnings):
70 | def __enter__(self):
71 | super().__enter__()
72 | warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
73 | return self
74 |
75 | #----------------------------------------------------------------------------
76 | # Assert that the shape of a tensor matches the given list of integers.
77 | # None indicates that the size of a dimension is allowed to vary.
78 | # Performs symbolic assertion when used in torch.jit.trace().
79 |
80 | def assert_shape(tensor, ref_shape):
81 | if tensor.ndim != len(ref_shape):
82 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
83 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
84 | if ref_size is None:
85 | pass
86 | elif isinstance(ref_size, torch.Tensor):
87 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
88 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
89 | elif isinstance(size, torch.Tensor):
90 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
91 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
92 | elif size != ref_size:
93 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
94 |
95 | #----------------------------------------------------------------------------
96 | # Function decorator that calls torch.autograd.profiler.record_function().
97 |
98 | def profiled_function(fn):
99 | def decorator(*args, **kwargs):
100 | with torch.autograd.profiler.record_function(fn.__name__):
101 | return fn(*args, **kwargs)
102 | decorator.__name__ = fn.__name__
103 | return decorator
104 |
105 | #----------------------------------------------------------------------------
106 | # Sampler for torch.utils.data.DataLoader that loops over the dataset
107 | # indefinitely, shuffling items as it goes.
108 |
109 | class InfiniteSampler(torch.utils.data.Sampler):
110 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
111 | assert len(dataset) > 0
112 | assert num_replicas > 0
113 | assert 0 <= rank < num_replicas
114 | assert 0 <= window_size <= 1
115 | super().__init__(dataset)
116 | self.dataset = dataset
117 | self.rank = rank
118 | self.num_replicas = num_replicas
119 | self.shuffle = shuffle
120 | self.seed = seed
121 | self.window_size = window_size
122 |
123 | def __iter__(self):
124 | order = np.arange(len(self.dataset))
125 | rnd = None
126 | window = 0
127 | if self.shuffle:
128 | rnd = np.random.RandomState(self.seed)
129 | rnd.shuffle(order)
130 | window = int(np.rint(order.size * self.window_size))
131 |
132 | idx = 0
133 | while True:
134 | i = idx % order.size
135 | if idx % self.num_replicas == self.rank:
136 | yield order[i]
137 | if window >= 2:
138 | j = (i - rnd.randint(window)) % order.size
139 | order[i], order[j] = order[j], order[i]
140 | idx += 1
141 |
142 | #----------------------------------------------------------------------------
143 | # Utilities for operating with torch.nn.Module parameters and buffers.
144 |
145 | def params_and_buffers(module):
146 | assert isinstance(module, torch.nn.Module)
147 | return list(module.parameters()) + list(module.buffers())
148 |
149 | def named_params_and_buffers(module):
150 | assert isinstance(module, torch.nn.Module)
151 | return list(module.named_parameters()) + list(module.named_buffers())
152 |
153 | def copy_params_and_buffers(src_module, dst_module, require_all=False):
154 | assert isinstance(src_module, torch.nn.Module)
155 | assert isinstance(dst_module, torch.nn.Module)
156 | src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
157 | for name, tensor in named_params_and_buffers(dst_module):
158 | assert (name in src_tensors) or (not require_all)
159 | if name in src_tensors:
160 | tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
161 |
162 | #----------------------------------------------------------------------------
163 | # Context manager for easily enabling/disabling DistributedDataParallel
164 | # synchronization.
165 |
166 | @contextlib.contextmanager
167 | def ddp_sync(module, sync):
168 | assert isinstance(module, torch.nn.Module)
169 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
170 | yield
171 | else:
172 | with module.no_sync():
173 | yield
174 |
175 | #----------------------------------------------------------------------------
176 | # Check DistributedDataParallel consistency across processes.
177 |
178 | def check_ddp_consistency(module, ignore_regex=None):
179 | assert isinstance(module, torch.nn.Module)
180 | for name, tensor in named_params_and_buffers(module):
181 | fullname = type(module).__name__ + '.' + name
182 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
183 | continue
184 | tensor = tensor.detach()
185 | other = tensor.clone()
186 | torch.distributed.broadcast(tensor=other, src=0)
187 | assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
188 |
189 | #----------------------------------------------------------------------------
190 | # Print summary table of module hierarchy.
191 |
192 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
193 | assert isinstance(module, torch.nn.Module)
194 | assert not isinstance(module, torch.jit.ScriptModule)
195 | assert isinstance(inputs, (tuple, list))
196 |
197 | # Register hooks.
198 | entries = []
199 | nesting = [0]
200 | def pre_hook(_mod, _inputs):
201 | nesting[0] += 1
202 | def post_hook(mod, _inputs, outputs):
203 | nesting[0] -= 1
204 | if nesting[0] <= max_nesting:
205 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
206 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
207 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
208 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
209 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
210 |
211 | # Run module.
212 | outputs = module(*inputs)
213 | for hook in hooks:
214 | hook.remove()
215 |
216 | # Identify unique outputs, parameters, and buffers.
217 | tensors_seen = set()
218 | for e in entries:
219 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
220 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
221 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
222 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
223 |
224 | # Filter out redundant entries.
225 | if skip_redundant:
226 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
227 |
228 | # Construct table.
229 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
230 | rows += [['---'] * len(rows[0])]
231 | param_total = 0
232 | buffer_total = 0
233 | submodule_names = {mod: name for name, mod in module.named_modules()}
234 | for e in entries:
235 | name = '' if e.mod is module else submodule_names[e.mod]
236 | param_size = sum(t.numel() for t in e.unique_params)
237 | buffer_size = sum(t.numel() for t in e.unique_buffers)
238 |         output_shapes = [str(list(t.shape)) for t in e.outputs]
239 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
240 | rows += [[
241 | name + (':0' if len(e.outputs) >= 2 else ''),
242 | str(param_size) if param_size else '-',
243 | str(buffer_size) if buffer_size else '-',
244 | (output_shapes + ['-'])[0],
245 | (output_dtypes + ['-'])[0],
246 | ]]
247 | for idx in range(1, len(e.outputs)):
248 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
249 | param_total += param_size
250 | buffer_total += buffer_size
251 | rows += [['---'] * len(rows[0])]
252 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
253 |
254 | # Print table.
255 | widths = [max(len(cell) for cell in column) for column in zip(*rows)]
256 | print()
257 | for row in rows:
258 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
259 | print()
260 | return outputs
261 |
262 | #----------------------------------------------------------------------------
263 |
--------------------------------------------------------------------------------
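`print_module_summary()` above registers forward pre/post hooks, runs the module once on the given inputs, and tabulates unique parameter counts, buffer counts, and output shapes per submodule. A minimal usage sketch, assuming the repository root is on PYTHONPATH; the toy network and shapes are illustrative, not part of this repository:

import torch
from torch_utils import misc

net = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(8 * 16 * 16, 10),
)
x = torch.zeros([4, 3, 16, 16])              # dummy batch
out = misc.print_module_summary(net, [x])    # prints the table, returns net(x)
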
/torch_utils/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <torch/extension.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <c10/cuda/CUDAGuard.h>
12 | #include "bias_act.h"
13 |
14 | //------------------------------------------------------------------------
15 |
16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17 | {
18 | if (x.dim() != y.dim())
19 | return false;
20 | for (int64_t i = 0; i < x.dim(); i++)
21 | {
22 | if (x.size(i) != y.size(i))
23 | return false;
24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25 | return false;
26 | }
27 | return true;
28 | }
29 |
30 | //------------------------------------------------------------------------
31 |
32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33 | {
34 | // Validate arguments.
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44 | TORCH_CHECK(grad >= 0, "grad must be non-negative");
45 |
46 | // Validate layout.
47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52 |
53 | // Create output tensor.
54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55 | torch::Tensor y = torch::empty_like(x);
56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57 |
58 | // Initialize CUDA kernel parameters.
59 | bias_act_kernel_params p;
60 | p.x = x.data_ptr();
61 | p.b = (b.numel()) ? b.data_ptr() : NULL;
62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65 | p.y = y.data_ptr();
66 | p.grad = grad;
67 | p.act = act;
68 | p.alpha = alpha;
69 | p.gain = gain;
70 | p.clamp = clamp;
71 | p.sizeX = (int)x.numel();
72 | p.sizeB = (int)b.numel();
73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74 |
75 | // Choose CUDA kernel.
76 | void* kernel;
77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
78 | {
79 |         kernel = choose_bias_act_kernel<scalar_t>(p);
80 | });
81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82 |
83 | // Launch CUDA kernel.
84 | p.loopX = 4;
85 | int blockSize = 4 * 32;
86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87 | void* args[] = {&p};
88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89 | return y;
90 | }
91 |
92 | //------------------------------------------------------------------------
93 |
94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95 | {
96 | m.def("bias_act", &bias_act);
97 | }
98 |
99 | //------------------------------------------------------------------------
100 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <c10/util/Half.h>
10 | #include "bias_act.h"
11 |
12 | //------------------------------------------------------------------------
13 | // Helpers.
14 |
15 | template <class T> struct InternalType;
16 | template <> struct InternalType<double>     { typedef double scalar_t; };
17 | template <> struct InternalType<float>      { typedef float  scalar_t; };
18 | template <> struct InternalType<c10::Half>  { typedef float  scalar_t; };
19 |
20 | //------------------------------------------------------------------------
21 | // CUDA kernel.
22 |
23 | template <class T, int A>
24 | __global__ void bias_act_kernel(bias_act_kernel_params p)
25 | {
26 | typedef typename InternalType::scalar_t scalar_t;
27 | int G = p.grad;
28 | scalar_t alpha = (scalar_t)p.alpha;
29 | scalar_t gain = (scalar_t)p.gain;
30 | scalar_t clamp = (scalar_t)p.clamp;
31 | scalar_t one = (scalar_t)1;
32 | scalar_t two = (scalar_t)2;
33 | scalar_t expRange = (scalar_t)80;
34 | scalar_t halfExpRange = (scalar_t)40;
35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37 |
38 | // Loop over elements.
39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41 | {
42 | // Load.
43 | scalar_t x = (scalar_t)((const T*)p.x)[xi];
44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48 | scalar_t yy = (gain != 0) ? yref / gain : 0;
49 | scalar_t y = 0;
50 |
51 | // Apply bias.
52 | ((G == 0) ? x : xref) += b;
53 |
54 | // linear
55 | if (A == 1)
56 | {
57 | if (G == 0) y = x;
58 | if (G == 1) y = x;
59 | }
60 |
61 | // relu
62 | if (A == 2)
63 | {
64 | if (G == 0) y = (x > 0) ? x : 0;
65 | if (G == 1) y = (yy > 0) ? x : 0;
66 | }
67 |
68 | // lrelu
69 | if (A == 3)
70 | {
71 | if (G == 0) y = (x > 0) ? x : x * alpha;
72 | if (G == 1) y = (yy > 0) ? x : x * alpha;
73 | }
74 |
75 | // tanh
76 | if (A == 4)
77 | {
78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79 | if (G == 1) y = x * (one - yy * yy);
80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81 | }
82 |
83 | // sigmoid
84 | if (A == 5)
85 | {
86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87 | if (G == 1) y = x * yy * (one - yy);
88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89 | }
90 |
91 | // elu
92 | if (A == 6)
93 | {
94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97 | }
98 |
99 | // selu
100 | if (A == 7)
101 | {
102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105 | }
106 |
107 | // softplus
108 | if (A == 8)
109 | {
110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111 | if (G == 1) y = x * (one - exp(-yy));
112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113 | }
114 |
115 | // swish
116 | if (A == 9)
117 | {
118 | if (G == 0)
119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120 | else
121 | {
122 | scalar_t c = exp(xref);
123 | scalar_t d = c + one;
124 | if (G == 1)
125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126 | else
127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129 | }
130 | }
131 |
132 | // Apply gain.
133 | y *= gain * dy;
134 |
135 | // Clamp.
136 | if (clamp >= 0)
137 | {
138 | if (G == 0)
139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140 | else
141 | y = (yref > -clamp & yref < clamp) ? y : 0;
142 | }
143 |
144 | // Store.
145 | ((T*)p.y)[xi] = (T)y;
146 | }
147 | }
148 |
149 | //------------------------------------------------------------------------
150 | // CUDA kernel selection.
151 |
152 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153 | {
154 |     if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155 |     if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156 |     if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157 |     if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158 |     if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159 |     if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160 |     if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161 |     if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162 |     if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163 | return NULL;
164 | }
165 |
166 | //------------------------------------------------------------------------
167 | // Template specializations.
168 |
169 | template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
170 | template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
171 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172 |
173 | //------------------------------------------------------------------------
174 |
--------------------------------------------------------------------------------
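The `grad == 1` branches above evaluate the first derivative of each activation in terms of the saved reference output `yy` (e.g. `x * (one - yy * yy)` for tanh, where `x` carries the upstream gradient). A small numerical sanity check of that tanh formula using plain PyTorch autograd; this is illustrative only, not code from this repository:

import torch

x = torch.randn(1000, dtype=torch.float64, requires_grad=True)
y = torch.tanh(x)
(autograd_grad,) = torch.autograd.grad(y.sum(), x)
manual_grad = 1 - y.detach() ** 2            # derivative expressed via the output, as in the kernel
assert torch.allclose(autograd_grad, manual_grad)
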
/torch_utils/ops/bias_act.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | //------------------------------------------------------------------------
10 | // CUDA kernel parameters.
11 |
12 | struct bias_act_kernel_params
13 | {
14 | const void* x; // [sizeX]
15 | const void* b; // [sizeB] or NULL
16 | const void* xref; // [sizeX] or NULL
17 | const void* yref; // [sizeX] or NULL
18 | const void* dy; // [sizeX] or NULL
19 | void* y; // [sizeX]
20 |
21 | int grad;
22 | int act;
23 | float alpha;
24 | float gain;
25 | float clamp;
26 |
27 | int sizeX;
28 | int sizeB;
29 | int stepB;
30 | int loopX;
31 | };
32 |
33 | //------------------------------------------------------------------------
34 | // CUDA kernel selection.
35 |
36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37 |
38 | //------------------------------------------------------------------------
39 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom PyTorch ops for efficient bias and activation."""
10 |
11 | import os
12 | import warnings
13 | import numpy as np
14 | import torch
15 | import dnnlib
16 | import traceback
17 |
18 | from .. import custom_ops
19 | from .. import misc
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | activation_funcs = {
24 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
25 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
26 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
27 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
28 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
29 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
30 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
31 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
32 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
33 | }
34 |
35 | #----------------------------------------------------------------------------
36 |
37 | _inited = False
38 | _plugin = None
39 | _null_tensor = torch.empty([0])
40 |
41 | def _init():
42 | global _inited, _plugin
43 | if not _inited:
44 | _inited = True
45 | sources = ['bias_act.cpp', 'bias_act.cu']
46 | sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
47 | try:
48 | _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
49 | except:
50 | warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
51 | return _plugin is not None
52 |
53 | #----------------------------------------------------------------------------
54 |
55 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
56 | r"""Fused bias and activation function.
57 |
58 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
59 | and scales the result by `gain`. Each of the steps is optional. In most cases,
60 | the fused op is considerably more efficient than performing the same calculation
61 | using standard PyTorch ops. It supports first and second order gradients,
62 | but not third order gradients.
63 |
64 | Args:
65 | x: Input activation tensor. Can be of any shape.
66 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
67 | as `x`. The shape must be known, and it must match the dimension of `x`
68 | corresponding to `dim`.
69 | dim: The dimension in `x` corresponding to the elements of `b`.
70 | The value of `dim` is ignored if `b` is not specified.
71 | act: Name of the activation function to evaluate, or `"linear"` to disable.
72 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
73 | See `activation_funcs` for a full list. `None` is not allowed.
74 | alpha: Shape parameter for the activation function, or `None` to use the default.
75 | gain: Scaling factor for the output tensor, or `None` to use default.
76 | See `activation_funcs` for the default scaling of each activation function.
77 | If unsure, consider specifying 1.
78 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
79 | the clamping (default).
80 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
81 |
82 | Returns:
83 | Tensor of the same shape and datatype as `x`.
84 | """
85 | assert isinstance(x, torch.Tensor)
86 | assert impl in ['ref', 'cuda']
87 | if impl == 'cuda' and x.device.type == 'cuda' and _init():
88 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
89 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
90 |
91 | #----------------------------------------------------------------------------
92 |
93 | @misc.profiled_function
94 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
95 | """Slow reference implementation of `bias_act()` using standard PyTorch ops.
96 | """
97 | assert isinstance(x, torch.Tensor)
98 | assert clamp is None or clamp >= 0
99 | spec = activation_funcs[act]
100 | alpha = float(alpha if alpha is not None else spec.def_alpha)
101 | gain = float(gain if gain is not None else spec.def_gain)
102 | clamp = float(clamp if clamp is not None else -1)
103 |
104 | # Add bias.
105 | if b is not None:
106 | assert isinstance(b, torch.Tensor) and b.ndim == 1
107 | assert 0 <= dim < x.ndim
108 | assert b.shape[0] == x.shape[dim]
109 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
110 |
111 | # Evaluate activation function.
112 | alpha = float(alpha)
113 | x = spec.func(x, alpha=alpha)
114 |
115 | # Scale by gain.
116 | gain = float(gain)
117 | if gain != 1:
118 | x = x * gain
119 |
120 | # Clamp.
121 | if clamp >= 0:
122 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
123 | return x
124 |
125 | #----------------------------------------------------------------------------
126 |
127 | _bias_act_cuda_cache = dict()
128 |
129 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
130 | """Fast CUDA implementation of `bias_act()` using custom ops.
131 | """
132 | # Parse arguments.
133 | assert clamp is None or clamp >= 0
134 | spec = activation_funcs[act]
135 | alpha = float(alpha if alpha is not None else spec.def_alpha)
136 | gain = float(gain if gain is not None else spec.def_gain)
137 | clamp = float(clamp if clamp is not None else -1)
138 |
139 | # Lookup from cache.
140 | key = (dim, act, alpha, gain, clamp)
141 | if key in _bias_act_cuda_cache:
142 | return _bias_act_cuda_cache[key]
143 |
144 | # Forward op.
145 | class BiasActCuda(torch.autograd.Function):
146 | @staticmethod
147 | def forward(ctx, x, b): # pylint: disable=arguments-differ
148 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
149 | x = x.contiguous(memory_format=ctx.memory_format)
150 | b = b.contiguous() if b is not None else _null_tensor
151 | y = x
152 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
153 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
154 | ctx.save_for_backward(
155 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
156 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
157 | y if 'y' in spec.ref else _null_tensor)
158 | return y
159 |
160 | @staticmethod
161 | def backward(ctx, dy): # pylint: disable=arguments-differ
162 | dy = dy.contiguous(memory_format=ctx.memory_format)
163 | x, b, y = ctx.saved_tensors
164 | dx = None
165 | db = None
166 |
167 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
168 | dx = dy
169 | if act != 'linear' or gain != 1 or clamp >= 0:
170 | dx = BiasActCudaGrad.apply(dy, x, b, y)
171 |
172 | if ctx.needs_input_grad[1]:
173 | db = dx.sum([i for i in range(dx.ndim) if i != dim])
174 |
175 | return dx, db
176 |
177 | # Backward op.
178 | class BiasActCudaGrad(torch.autograd.Function):
179 | @staticmethod
180 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
181 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
182 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
183 | ctx.save_for_backward(
184 | dy if spec.has_2nd_grad else _null_tensor,
185 | x, b, y)
186 | return dx
187 |
188 | @staticmethod
189 | def backward(ctx, d_dx): # pylint: disable=arguments-differ
190 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
191 | dy, x, b, y = ctx.saved_tensors
192 | d_dy = None
193 | d_x = None
194 | d_b = None
195 | d_y = None
196 |
197 | if ctx.needs_input_grad[0]:
198 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
199 |
200 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
201 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
202 |
203 | if spec.has_2nd_grad and ctx.needs_input_grad[2]:
204 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
205 |
206 | return d_dy, d_x, d_b, d_y
207 |
208 | # Add to cache.
209 | _bias_act_cuda_cache[key] = BiasActCuda
210 | return BiasActCuda
211 |
212 | #----------------------------------------------------------------------------
213 |
--------------------------------------------------------------------------------
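A minimal usage sketch of `bias_act()`; the shapes below are illustrative, and the CUDA line is commented out because it assumes a GPU with a working build of the extension:

import torch
from torch_utils.ops import bias_act

x = torch.randn([4, 64, 32, 32])             # NCHW activations
b = torch.randn([64])                        # per-channel bias along dim=1
y = bias_act.bias_act(x, b, dim=1, act='lrelu', impl='ref')   # pure-PyTorch reference path
# y = bias_act.bias_act(x.cuda(), b.cuda(), dim=1, act='lrelu', impl='cuda')  # fused kernel
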
/torch_utils/ops/conv2d_gradfix.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom replacement for `torch.nn.functional.conv2d` that supports
10 | arbitrarily high order gradients with zero performance penalty."""
11 |
12 | import warnings
13 | import contextlib
14 | import torch
15 |
16 | # pylint: disable=redefined-builtin
17 | # pylint: disable=arguments-differ
18 | # pylint: disable=protected-access
19 |
20 | #----------------------------------------------------------------------------
21 |
22 | enabled = False # Enable the custom op by setting this to true.
23 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
24 |
25 | @contextlib.contextmanager
26 | def no_weight_gradients():
27 | global weight_gradients_disabled
28 | old = weight_gradients_disabled
29 | weight_gradients_disabled = True
30 | yield
31 | weight_gradients_disabled = old
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
36 | if _should_use_custom_op(input):
37 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
38 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
39 |
40 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
41 | if _should_use_custom_op(input):
42 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
43 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
44 |
45 | #----------------------------------------------------------------------------
46 |
47 | def _should_use_custom_op(input):
48 | assert isinstance(input, torch.Tensor)
49 | if (not enabled) or (not torch.backends.cudnn.enabled):
50 | return False
51 | if input.device.type != 'cuda':
52 | return False
53 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
54 | return True
55 | warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
56 | return False
57 |
58 | def _tuple_of_ints(xs, ndim):
59 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
60 | assert len(xs) == ndim
61 | assert all(isinstance(x, int) for x in xs)
62 | return xs
63 |
64 | #----------------------------------------------------------------------------
65 |
66 | _conv2d_gradfix_cache = dict()
67 |
68 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
69 | # Parse arguments.
70 | ndim = 2
71 | weight_shape = tuple(weight_shape)
72 | stride = _tuple_of_ints(stride, ndim)
73 | padding = _tuple_of_ints(padding, ndim)
74 | output_padding = _tuple_of_ints(output_padding, ndim)
75 | dilation = _tuple_of_ints(dilation, ndim)
76 |
77 | # Lookup from cache.
78 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
79 | if key in _conv2d_gradfix_cache:
80 | return _conv2d_gradfix_cache[key]
81 |
82 | # Validate arguments.
83 | assert groups >= 1
84 | assert len(weight_shape) == ndim + 2
85 | assert all(stride[i] >= 1 for i in range(ndim))
86 | assert all(padding[i] >= 0 for i in range(ndim))
87 | assert all(dilation[i] >= 0 for i in range(ndim))
88 | if not transpose:
89 | assert all(output_padding[i] == 0 for i in range(ndim))
90 | else: # transpose
91 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
92 |
93 | # Helpers.
94 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
95 | def calc_output_padding(input_shape, output_shape):
96 | if transpose:
97 | return [0, 0]
98 | return [
99 | input_shape[i + 2]
100 | - (output_shape[i + 2] - 1) * stride[i]
101 | - (1 - 2 * padding[i])
102 | - dilation[i] * (weight_shape[i + 2] - 1)
103 | for i in range(ndim)
104 | ]
105 |
106 | # Forward & backward.
107 | class Conv2d(torch.autograd.Function):
108 | @staticmethod
109 | def forward(ctx, input, weight, bias):
110 | assert weight.shape == weight_shape
111 | if not transpose:
112 | output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
113 | else: # transpose
114 | output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
115 | ctx.save_for_backward(input, weight)
116 | return output
117 |
118 | @staticmethod
119 | def backward(ctx, grad_output):
120 | input, weight = ctx.saved_tensors
121 | grad_input = None
122 | grad_weight = None
123 | grad_bias = None
124 |
125 | if ctx.needs_input_grad[0]:
126 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
127 | grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
128 | assert grad_input.shape == input.shape
129 |
130 | if ctx.needs_input_grad[1] and not weight_gradients_disabled:
131 | grad_weight = Conv2dGradWeight.apply(grad_output, input)
132 | assert grad_weight.shape == weight_shape
133 |
134 | if ctx.needs_input_grad[2]:
135 | grad_bias = grad_output.sum([0, 2, 3])
136 |
137 | return grad_input, grad_weight, grad_bias
138 |
139 | # Gradient with respect to the weights.
140 | class Conv2dGradWeight(torch.autograd.Function):
141 | @staticmethod
142 | def forward(ctx, grad_output, input):
143 | op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
144 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
145 | grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
146 | assert grad_weight.shape == weight_shape
147 | ctx.save_for_backward(grad_output, input)
148 | return grad_weight
149 |
150 | @staticmethod
151 | def backward(ctx, grad2_grad_weight):
152 | grad_output, input = ctx.saved_tensors
153 | grad2_grad_output = None
154 | grad2_input = None
155 |
156 | if ctx.needs_input_grad[0]:
157 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
158 | assert grad2_grad_output.shape == grad_output.shape
159 |
160 | if ctx.needs_input_grad[1]:
161 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
162 | grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
163 | assert grad2_input.shape == input.shape
164 |
165 | return grad2_grad_output, grad2_input
166 |
167 | _conv2d_gradfix_cache[key] = Conv2d
168 | return Conv2d
169 |
170 | #----------------------------------------------------------------------------
171 |
--------------------------------------------------------------------------------
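A short sketch of how `conv2d_gradfix` is typically used (shapes illustrative); on CPU or unsupported PyTorch versions it transparently falls back to `torch.nn.functional.conv2d`:

import torch
from torch_utils.ops import conv2d_gradfix

conv2d_gradfix.enabled = True                        # opt in to the custom op
x = torch.randn([2, 3, 16, 16], requires_grad=True)
w = torch.randn([8, 3, 3, 3], requires_grad=True)
y = conv2d_gradfix.conv2d(x, w, padding=1)
(g,) = torch.autograd.grad(y.sum(), x, create_graph=True)   # higher-order grads stay available
with conv2d_gradfix.no_weight_gradients():
    y2 = conv2d_gradfix.conv2d(x, w, padding=1)      # weight grads skipped when the custom op is active
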
/torch_utils/ops/conv2d_resample.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """2D convolution with optional up/downsampling."""
10 |
11 | import torch
12 |
13 | from .. import misc
14 | from . import conv2d_gradfix
15 | from . import upfirdn2d
16 | from .upfirdn2d import _parse_padding
17 | from .upfirdn2d import _get_filter_size
18 |
19 | #----------------------------------------------------------------------------
20 |
21 | def _get_weight_shape(w):
22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant
23 | shape = [int(sz) for sz in w.shape]
24 | misc.assert_shape(w, shape)
25 | return shape
26 |
27 | #----------------------------------------------------------------------------
28 |
29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
31 | """
32 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
33 |
34 | # Flip weight if requested.
35 | if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
36 | w = w.flip([2, 3])
37 |
38 | # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
39 | # 1x1 kernel + memory_format=channels_last + less than 64 channels.
40 | if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
41 | if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
42 | if out_channels <= 4 and groups == 1:
43 | in_shape = x.shape
44 | x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
45 | x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
46 | else:
47 | x = x.to(memory_format=torch.contiguous_format)
48 | w = w.to(memory_format=torch.contiguous_format)
49 | x = conv2d_gradfix.conv2d(x, w, groups=groups)
50 | return x.to(memory_format=torch.channels_last)
51 |
52 | # Otherwise => execute using conv2d_gradfix.
53 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
54 | return op(x, w, stride=stride, padding=padding, groups=groups)
55 |
56 | #----------------------------------------------------------------------------
57 |
58 | @misc.profiled_function
59 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
60 | r"""2D convolution with optional up/downsampling.
61 |
62 | Padding is performed only once at the beginning, not between the operations.
63 |
64 | Args:
65 | x: Input tensor of shape
66 | `[batch_size, in_channels, in_height, in_width]`.
67 | w: Weight tensor of shape
68 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
69 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by
70 | calling upfirdn2d.setup_filter(). None = identity (default).
71 | up: Integer upsampling factor (default: 1).
72 | down: Integer downsampling factor (default: 1).
73 | padding: Padding with respect to the upsampled image. Can be a single number
74 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
75 | (default: 0).
76 | groups: Split input channels into N groups (default: 1).
77 | flip_weight: False = convolution, True = correlation (default: True).
78 | flip_filter: False = convolution, True = correlation (default: False).
79 |
80 | Returns:
81 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
82 | """
83 | # Validate arguments.
84 | assert isinstance(x, torch.Tensor) and (x.ndim == 4)
85 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
86 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
87 | assert isinstance(up, int) and (up >= 1)
88 | assert isinstance(down, int) and (down >= 1)
89 | assert isinstance(groups, int) and (groups >= 1)
90 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
91 | fw, fh = _get_filter_size(f)
92 | px0, px1, py0, py1 = _parse_padding(padding)
93 |
94 | # Adjust padding to account for up/downsampling.
95 | if up > 1:
96 | px0 += (fw + up - 1) // 2
97 | px1 += (fw - up) // 2
98 | py0 += (fh + up - 1) // 2
99 | py1 += (fh - up) // 2
100 | if down > 1:
101 | px0 += (fw - down + 1) // 2
102 | px1 += (fw - down) // 2
103 | py0 += (fh - down + 1) // 2
104 | py1 += (fh - down) // 2
105 |
106 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
107 | if kw == 1 and kh == 1 and (down > 1 and up == 1):
108 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
109 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
110 | return x
111 |
112 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
113 | if kw == 1 and kh == 1 and (up > 1 and down == 1):
114 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
115 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
116 | return x
117 |
118 | # Fast path: downsampling only => use strided convolution.
119 | if down > 1 and up == 1:
120 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
121 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
122 | return x
123 |
124 | # Fast path: upsampling with optional downsampling => use transpose strided convolution.
125 | if up > 1:
126 | if groups == 1:
127 | w = w.transpose(0, 1)
128 | else:
129 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
130 | w = w.transpose(1, 2)
131 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
132 | px0 -= kw - 1
133 | px1 -= kw - up
134 | py0 -= kh - 1
135 | py1 -= kh - up
136 | pxt = max(min(-px0, -px1), 0)
137 | pyt = max(min(-py0, -py1), 0)
138 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
139 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
140 | if down > 1:
141 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
142 | return x
143 |
144 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
145 | if up == 1 and down == 1:
146 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
147 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
148 |
149 | # Fallback: Generic reference implementation.
150 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
151 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
152 | if down > 1:
153 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
154 | return x
155 |
156 | #----------------------------------------------------------------------------
157 |
--------------------------------------------------------------------------------
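A minimal sketch of a 2x upsampling convolution through `conv2d_resample()`; it relies on `upfirdn2d.setup_filter()` (defined in `upfirdn2d.py`, not shown in this excerpt) and the usual `[1, 3, 3, 1]` separable low-pass filter; shapes are illustrative:

import torch
from torch_utils.ops import conv2d_resample, upfirdn2d

x = torch.randn([2, 16, 32, 32])
w = torch.randn([16, 16, 3, 3])
f = upfirdn2d.setup_filter([1, 3, 3, 1])     # low-pass filter, float32
y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
print(y.shape)                               # expected: [2, 16, 64, 64]
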
/torch_utils/ops/fma.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
10 |
11 | import torch
12 |
13 | #----------------------------------------------------------------------------
14 |
15 | def fma(a, b, c): # => a * b + c
16 | return _FusedMultiplyAdd.apply(a, b, c)
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
21 | @staticmethod
22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ
23 | out = torch.addcmul(c, a, b)
24 | ctx.save_for_backward(a, b)
25 | ctx.c_shape = c.shape
26 | return out
27 |
28 | @staticmethod
29 | def backward(ctx, dout): # pylint: disable=arguments-differ
30 | a, b = ctx.saved_tensors
31 | c_shape = ctx.c_shape
32 | da = None
33 | db = None
34 | dc = None
35 |
36 | if ctx.needs_input_grad[0]:
37 | da = _unbroadcast(dout * b, a.shape)
38 |
39 | if ctx.needs_input_grad[1]:
40 | db = _unbroadcast(dout * a, b.shape)
41 |
42 | if ctx.needs_input_grad[2]:
43 | dc = _unbroadcast(dout, c_shape)
44 |
45 | return da, db, dc
46 |
47 | #----------------------------------------------------------------------------
48 |
49 | def _unbroadcast(x, shape):
50 | extra_dims = x.ndim - len(shape)
51 | assert extra_dims >= 0
52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
53 | if len(dim):
54 | x = x.sum(dim=dim, keepdim=True)
55 | if extra_dims:
56 | x = x.reshape(-1, *x.shape[extra_dims+1:])
57 | assert x.shape == shape
58 | return x
59 |
60 | #----------------------------------------------------------------------------
61 |
--------------------------------------------------------------------------------
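A small sketch of `fma()` with broadcasting (shapes illustrative); the backward pass un-broadcasts each gradient back to its input's shape:

import torch
from torch_utils.ops import fma

a = torch.randn([4, 1, 8], requires_grad=True)
b = torch.randn([1, 3, 8], requires_grad=True)
c = torch.randn([8], requires_grad=True)
out = fma.fma(a, b, c)                       # a * b + c, broadcast to [4, 3, 8]
out.sum().backward()                         # a.grad, b.grad, c.grad match the input shapes
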
/torch_utils/ops/grid_sample_gradfix.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom replacement for `torch.nn.functional.grid_sample` that
10 | supports arbitrarily high order gradients between the input and output.
11 | Only works on 2D images and assumes
12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
13 |
14 | import warnings
15 | import torch
16 |
17 | # pylint: disable=redefined-builtin
18 | # pylint: disable=arguments-differ
19 | # pylint: disable=protected-access
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | enabled = False # Enable the custom op by setting this to true.
24 |
25 | #----------------------------------------------------------------------------
26 |
27 | def grid_sample(input, grid):
28 | if _should_use_custom_op():
29 | return _GridSample2dForward.apply(input, grid)
30 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
31 |
32 | #----------------------------------------------------------------------------
33 |
34 | def _should_use_custom_op():
35 | if not enabled:
36 | return False
37 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
38 | return True
39 | warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
40 | return False
41 |
42 | #----------------------------------------------------------------------------
43 |
44 | class _GridSample2dForward(torch.autograd.Function):
45 | @staticmethod
46 | def forward(ctx, input, grid):
47 | assert input.ndim == 4
48 | assert grid.ndim == 4
49 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
50 | ctx.save_for_backward(input, grid)
51 | return output
52 |
53 | @staticmethod
54 | def backward(ctx, grad_output):
55 | input, grid = ctx.saved_tensors
56 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
57 | return grad_input, grad_grid
58 |
59 | #----------------------------------------------------------------------------
60 |
61 | class _GridSample2dBackward(torch.autograd.Function):
62 | @staticmethod
63 | def forward(ctx, grad_output, input, grid):
64 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
65 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
66 | ctx.save_for_backward(grid)
67 | return grad_input, grad_grid
68 |
69 | @staticmethod
70 | def backward(ctx, grad2_grad_input, grad2_grad_grid):
71 | _ = grad2_grad_grid # unused
72 | grid, = ctx.saved_tensors
73 | grad2_grad_output = None
74 | grad2_input = None
75 | grad2_grid = None
76 |
77 | if ctx.needs_input_grad[0]:
78 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
79 |
80 | assert not ctx.needs_input_grad[2]
81 | return grad2_grad_output, grad2_input, grad2_grid
82 |
83 | #----------------------------------------------------------------------------
84 |
--------------------------------------------------------------------------------
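A minimal sketch of `grid_sample()` with an identity sampling grid (illustrative); with `enabled = False` it is equivalent to `torch.nn.functional.grid_sample` using bilinear sampling, zero padding, and `align_corners=False`:

import torch
from torch_utils.ops import grid_sample_gradfix

x = torch.randn([1, 3, 8, 8], requires_grad=True)
theta = torch.eye(2, 3).unsqueeze(0)                     # identity affine transform
grid = torch.nn.functional.affine_grid(theta, [1, 3, 8, 8], align_corners=False)
y = grid_sample_gradfix.grid_sample(x, grid)             # ~= x for the identity grid
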
/torch_utils/ops/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <torch/extension.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <c10/cuda/CUDAGuard.h>
12 | #include "upfirdn2d.h"
13 |
14 | //------------------------------------------------------------------------
15 |
16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17 | {
18 | // Validate arguments.
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24 | TORCH_CHECK(x.dim() == 4, "x must be rank 4");
25 | TORCH_CHECK(f.dim() == 2, "f must be rank 2");
26 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
27 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
28 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
29 |
30 | // Create output tensor.
31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
32 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
33 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
34 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
35 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
36 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
37 |
38 | // Initialize CUDA kernel parameters.
39 | upfirdn2d_kernel_params p;
40 | p.x = x.data_ptr();
41 | p.f = f.data_ptr();
42 | p.y = y.data_ptr();
43 | p.up = make_int2(upx, upy);
44 | p.down = make_int2(downx, downy);
45 | p.pad0 = make_int2(padx0, pady0);
46 | p.flip = (flip) ? 1 : 0;
47 | p.gain = gain;
48 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
49 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
50 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
51 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
52 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
53 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
54 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
55 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
56 |
57 | // Choose CUDA kernel.
58 | upfirdn2d_kernel_spec spec;
59 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
60 | {
61 |         spec = choose_upfirdn2d_kernel<scalar_t>(p);
62 | });
63 |
64 | // Set looping options.
65 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
66 | p.loopMinor = spec.loopMinor;
67 | p.loopX = spec.loopX;
68 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
69 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
70 |
71 | // Compute grid size.
72 | dim3 blockSize, gridSize;
73 | if (spec.tileOutW < 0) // large
74 | {
75 | blockSize = dim3(4, 32, 1);
76 | gridSize = dim3(
77 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
78 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
79 | p.launchMajor);
80 | }
81 | else // small
82 | {
83 | blockSize = dim3(256, 1, 1);
84 | gridSize = dim3(
85 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
86 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
87 | p.launchMajor);
88 | }
89 |
90 | // Launch CUDA kernel.
91 | void* args[] = {&p};
92 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
93 | return y;
94 | }
95 |
96 | //------------------------------------------------------------------------
97 |
98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
99 | {
100 | m.def("upfirdn2d", &upfirdn2d);
101 | }
102 |
103 | //------------------------------------------------------------------------
104 |
--------------------------------------------------------------------------------
/torch_utils/ops/upfirdn2d.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <cuda_runtime.h>
10 |
11 | //------------------------------------------------------------------------
12 | // CUDA kernel parameters.
13 |
14 | struct upfirdn2d_kernel_params
15 | {
16 | const void* x;
17 | const float* f;
18 | void* y;
19 |
20 | int2 up;
21 | int2 down;
22 | int2 pad0;
23 | int flip;
24 | float gain;
25 |
26 | int4 inSize; // [width, height, channel, batch]
27 | int4 inStride;
28 | int2 filterSize; // [width, height]
29 | int2 filterStride;
30 | int4 outSize; // [width, height, channel, batch]
31 | int4 outStride;
32 | int sizeMinor;
33 | int sizeMajor;
34 |
35 | int loopMinor;
36 | int loopMajor;
37 | int loopX;
38 | int launchMinor;
39 | int launchMajor;
40 | };
41 |
42 | //------------------------------------------------------------------------
43 | // CUDA kernel specialization.
44 |
45 | struct upfirdn2d_kernel_spec
46 | {
47 | void* kernel;
48 | int tileOutW;
49 | int tileOutH;
50 | int loopMinor;
51 | int loopX;
52 | };
53 |
54 | //------------------------------------------------------------------------
55 | // CUDA kernel selection.
56 |
57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58 |
59 | //------------------------------------------------------------------------
60 |
--------------------------------------------------------------------------------
/torch_utils/persistence.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Facilities for pickling Python code alongside other data.
10 |
11 | The pickled code is automatically imported into a separate Python module
12 | during unpickling. This way, any previously exported pickles will remain
13 | usable even if the original code is no longer available, or if the current
14 | version of the code is not consistent with what was originally pickled."""
15 |
16 | import sys
17 | import pickle
18 | import io
19 | import inspect
20 | import copy
21 | import uuid
22 | import types
23 | import dnnlib
24 |
25 | #----------------------------------------------------------------------------
26 |
27 | _version = 6 # internal version number
28 | _decorators = set() # {decorator_class, ...}
29 | _import_hooks = [] # [hook_function, ...]
30 | _module_to_src_dict = dict() # {module: src, ...}
31 | _src_to_module_dict = dict() # {src: module, ...}
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def persistent_class(orig_class):
36 | r"""Class decorator that extends a given class to save its source code
37 | when pickled.
38 |
39 | Example:
40 |
41 | from torch_utils import persistence
42 |
43 | @persistence.persistent_class
44 | class MyNetwork(torch.nn.Module):
45 | def __init__(self, num_inputs, num_outputs):
46 | super().__init__()
47 | self.fc = MyLayer(num_inputs, num_outputs)
48 | ...
49 |
50 | @persistence.persistent_class
51 | class MyLayer(torch.nn.Module):
52 | ...
53 |
54 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its
55 | source code alongside other internal state (e.g., parameters, buffers,
56 | and submodules). This way, any previously exported pickle will remain
57 | usable even if the class definitions have been modified or are no
58 | longer available.
59 |
60 | The decorator saves the source code of the entire Python module
61 | containing the decorated class. It does *not* save the source code of
62 | any imported modules. Thus, the imported modules must be available
63 | during unpickling, also including `torch_utils.persistence` itself.
64 |
65 | It is ok to call functions defined in the same module from the
66 | decorated class. However, if the decorated class depends on other
67 | classes defined in the same module, they must be decorated as well.
68 | This is illustrated in the above example in the case of `MyLayer`.
69 |
70 | It is also possible to employ the decorator just-in-time before
71 | calling the constructor. For example:
72 |
73 | cls = MyLayer
74 | if want_to_make_it_persistent:
75 | cls = persistence.persistent_class(cls)
76 | layer = cls(num_inputs, num_outputs)
77 |
78 | As an additional feature, the decorator also keeps track of the
79 | arguments that were used to construct each instance of the decorated
80 | class. The arguments can be queried via `obj.init_args` and
81 | `obj.init_kwargs`, and they are automatically pickled alongside other
82 | object state. A typical use case is to first unpickle a previous
83 | instance of a persistent class, and then upgrade it to use the latest
84 | version of the source code:
85 |
86 | with open('old_pickle.pkl', 'rb') as f:
87 | old_net = pickle.load(f)
88 |         new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
89 | misc.copy_params_and_buffers(old_net, new_net, require_all=True)
90 | """
91 | assert isinstance(orig_class, type)
92 | if is_persistent(orig_class):
93 | return orig_class
94 |
95 | assert orig_class.__module__ in sys.modules
96 | orig_module = sys.modules[orig_class.__module__]
97 | orig_module_src = _module_to_src(orig_module)
98 |
99 | class Decorator(orig_class):
100 | _orig_module_src = orig_module_src
101 | _orig_class_name = orig_class.__name__
102 |
103 | def __init__(self, *args, **kwargs):
104 | super().__init__(*args, **kwargs)
105 | self._init_args = copy.deepcopy(args)
106 | self._init_kwargs = copy.deepcopy(kwargs)
107 | assert orig_class.__name__ in orig_module.__dict__
108 | _check_pickleable(self.__reduce__())
109 |
110 | @property
111 | def init_args(self):
112 | return copy.deepcopy(self._init_args)
113 |
114 | @property
115 | def init_kwargs(self):
116 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
117 |
118 | def __reduce__(self):
119 | fields = list(super().__reduce__())
120 | fields += [None] * max(3 - len(fields), 0)
121 | if fields[0] is not _reconstruct_persistent_obj:
122 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
123 | fields[0] = _reconstruct_persistent_obj # reconstruct func
124 | fields[1] = (meta,) # reconstruct args
125 | fields[2] = None # state dict
126 | return tuple(fields)
127 |
128 | Decorator.__name__ = orig_class.__name__
129 | _decorators.add(Decorator)
130 | return Decorator
131 |
132 | #----------------------------------------------------------------------------
133 |
134 | def is_persistent(obj):
135 | r"""Test whether the given object or class is persistent, i.e.,
136 | whether it will save its source code when pickled.
137 | """
138 | try:
139 | if obj in _decorators:
140 | return True
141 | except TypeError:
142 | pass
143 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
144 |
145 | #----------------------------------------------------------------------------
146 |
147 | def import_hook(hook):
148 | r"""Register an import hook that is called whenever a persistent object
149 | is being unpickled. A typical use case is to patch the pickled source
150 | code to avoid errors and inconsistencies when the API of some imported
151 | module has changed.
152 |
153 | The hook should have the following signature:
154 |
155 | hook(meta) -> modified meta
156 |
157 | `meta` is an instance of `dnnlib.EasyDict` with the following fields:
158 |
159 | type: Type of the persistent object, e.g. `'class'`.
160 | version: Internal version number of `torch_utils.persistence`.
161 | module_src: Original source code of the Python module.
162 | class_name: Class name in the original Python module.
163 | state: Internal state of the object.
164 |
165 | Example:
166 |
167 | @persistence.import_hook
168 | def wreck_my_network(meta):
169 | if meta.class_name == 'MyNetwork':
170 | print('MyNetwork is being imported. I will wreck it!')
171 | meta.module_src = meta.module_src.replace("True", "False")
172 | return meta
173 | """
174 | assert callable(hook)
175 | _import_hooks.append(hook)
176 |
177 | #----------------------------------------------------------------------------
178 |
179 | def _reconstruct_persistent_obj(meta):
180 | r"""Hook that is called internally by the `pickle` module to unpickle
181 | a persistent object.
182 | """
183 | meta = dnnlib.EasyDict(meta)
184 | meta.state = dnnlib.EasyDict(meta.state)
185 | for hook in _import_hooks:
186 | meta = hook(meta)
187 | assert meta is not None
188 |
189 | assert meta.version == _version
190 | module = _src_to_module(meta.module_src)
191 |
192 | assert meta.type == 'class'
193 | orig_class = module.__dict__[meta.class_name]
194 | decorator_class = persistent_class(orig_class)
195 | obj = decorator_class.__new__(decorator_class)
196 |
197 | setstate = getattr(obj, '__setstate__', None)
198 | if callable(setstate):
199 | setstate(meta.state) # pylint: disable=not-callable
200 | else:
201 | obj.__dict__.update(meta.state)
202 | return obj
203 |
204 | #----------------------------------------------------------------------------
205 |
206 | def _module_to_src(module):
207 | r"""Query the source code of a given Python module.
208 | """
209 | src = _module_to_src_dict.get(module, None)
210 | if src is None:
211 | src = inspect.getsource(module)
212 | _module_to_src_dict[module] = src
213 | _src_to_module_dict[src] = module
214 | return src
215 |
216 | def _src_to_module(src):
217 | r"""Get or create a Python module for the given source code.
218 | """
219 | module = _src_to_module_dict.get(src, None)
220 | if module is None:
221 | module_name = "_imported_module_" + uuid.uuid4().hex
222 | module = types.ModuleType(module_name)
223 | sys.modules[module_name] = module
224 | _module_to_src_dict[module] = src
225 | _src_to_module_dict[src] = module
226 | exec(src, module.__dict__) # pylint: disable=exec-used
227 | return module
228 |
229 | #----------------------------------------------------------------------------
230 |
231 | def _check_pickleable(obj):
232 | r"""Check that the given object is pickleable, raising an exception if
233 | it is not. This function is expected to be considerably more efficient
234 | than actually pickling the object.
235 | """
236 | def recurse(obj):
237 | if isinstance(obj, (list, tuple, set)):
238 | return [recurse(x) for x in obj]
239 | if isinstance(obj, dict):
240 | return [[recurse(x), recurse(y)] for x, y in obj.items()]
241 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
242 | return None # Python primitive types are pickleable.
243 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
244 | return None # NumPy arrays and PyTorch tensors are pickleable.
245 | if is_persistent(obj):
246 | return None # Persistent objects are pickleable, by virtue of the constructor check.
247 | return obj
248 | with io.BytesIO() as f:
249 | pickle.dump(recurse(obj), f)
250 |
251 | #----------------------------------------------------------------------------
252 |
--------------------------------------------------------------------------------
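To illustrate the `init_args`/`init_kwargs` tracking and source-code embedding described in the docstring above, here is a minimal sketch; the class name, its parameters, and the in-process pickle round trip are illustrative and not part of the repository:

import pickle
import torch
from torch_utils import persistence

@persistence.persistent_class
class ToyLayer(torch.nn.Module):   # hypothetical class, defined only for this sketch
    def __init__(self, num_inputs, num_outputs):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn([num_outputs, num_inputs]))
    def forward(self, x):
        return x @ self.weight.t()

layer = ToyLayer(4, 8)
print(layer.init_args)     # (4, 8) -- positional constructor arguments recorded by the decorator
print(layer.init_kwargs)   # {}     -- keyword arguments, returned as a dnnlib.EasyDict

# Pickling stores the source code of the defining module alongside the object state,
# so the pickle can later be loaded even if the module has since changed or moved.
data = pickle.dumps(layer)
restored = pickle.loads(data)
assert persistence.is_persistent(restored)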
/torch_utils/training_stats.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Facilities for reporting and collecting training statistics across
10 | multiple processes and devices. The interface is designed to minimize
11 | synchronization overhead as well as the amount of boilerplate in user
12 | code."""
13 |
14 | import re
15 | import numpy as np
16 | import torch
17 | import dnnlib
18 |
19 | from . import misc
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
24 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
25 | _counter_dtype = torch.float64 # Data type to use for the internal counters.
26 | _rank = 0 # Rank of the current process.
27 | _sync_device = None # Device to use for multiprocess communication. None = single-process.
28 | _sync_called = False # Has _sync() been called yet?
29 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
30 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
31 |
32 | #----------------------------------------------------------------------------
33 |
34 | def init_multiprocessing(rank, sync_device):
35 | r"""Initializes `torch_utils.training_stats` for collecting statistics
36 | across multiple processes.
37 |
38 | This function must be called after
39 | `torch.distributed.init_process_group()` and before `Collector.update()`.
40 | The call is not necessary if multi-process collection is not needed.
41 |
42 | Args:
43 | rank: Rank of the current process.
44 | sync_device: PyTorch device to use for inter-process
45 | communication, or None to disable multi-process
46 | collection. Typically `torch.device('cuda', rank)`.
47 | """
48 | global _rank, _sync_device
49 | assert not _sync_called
50 | _rank = rank
51 | _sync_device = sync_device
52 |
53 | #----------------------------------------------------------------------------
54 |
55 | @misc.profiled_function
56 | def report(name, value):
57 | r"""Broadcasts the given set of scalars to all interested instances of
58 | `Collector`, across device and process boundaries.
59 |
60 | This function is expected to be extremely cheap and can be safely
61 | called from anywhere in the training loop, loss function, or inside a
62 | `torch.nn.Module`.
63 |
64 | Warning: The current implementation expects the set of unique names to
65 | be consistent across processes. Please make sure that `report()` is
66 | called at least once for each unique name by each process, and in the
67 | same order. If a given process has no scalars to broadcast, it can do
68 | `report(name, [])` (empty list).
69 |
70 | Args:
71 | name: Arbitrary string specifying the name of the statistic.
72 | Averages are accumulated separately for each unique name.
73 | value: Arbitrary set of scalars. Can be a list, tuple,
74 | NumPy array, PyTorch tensor, or Python scalar.
75 |
76 | Returns:
77 | The same `value` that was passed in.
78 | """
79 | if name not in _counters:
80 | _counters[name] = dict()
81 |
82 | elems = torch.as_tensor(value)
83 | if elems.numel() == 0:
84 | return value
85 |
86 | elems = elems.detach().flatten().to(_reduce_dtype)
87 | moments = torch.stack([
88 | torch.ones_like(elems).sum(),
89 | elems.sum(),
90 | elems.square().sum(),
91 | ])
92 | assert moments.ndim == 1 and moments.shape[0] == _num_moments
93 | moments = moments.to(_counter_dtype)
94 |
95 | device = moments.device
96 | if device not in _counters[name]:
97 | _counters[name][device] = torch.zeros_like(moments)
98 | _counters[name][device].add_(moments)
99 | return value
100 |
101 | #----------------------------------------------------------------------------
102 |
103 | def report0(name, value):
104 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
105 | but ignores any scalars provided by the other processes.
106 | See `report()` for further details.
107 | """
108 | report(name, value if _rank == 0 else [])
109 | return value
110 |
111 | #----------------------------------------------------------------------------
112 |
113 | class Collector:
114 | r"""Collects the scalars broadcasted by `report()` and `report0()` and
115 | computes their long-term averages (mean and standard deviation) over
116 | user-defined periods of time.
117 |
118 | The averages are first collected into internal counters that are not
119 | directly visible to the user. They are then copied to the user-visible
120 | state as a result of calling `update()` and can then be queried using
121 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
122 | internal counters for the next round, so that the user-visible state
123 | effectively reflects averages collected between the last two calls to
124 | `update()`.
125 |
126 | Args:
127 | regex: Regular expression defining which statistics to
128 | collect. The default is to collect everything.
129 | keep_previous: Whether to retain the previous averages if no
130 | scalars were collected on a given round
131 | (default: True).
132 | """
133 | def __init__(self, regex='.*', keep_previous=True):
134 | self._regex = re.compile(regex)
135 | self._keep_previous = keep_previous
136 | self._cumulative = dict()
137 | self._moments = dict()
138 | self.update()
139 | self._moments.clear()
140 |
141 | def names(self):
142 | r"""Returns the names of all statistics broadcasted so far that
143 | match the regular expression specified at construction time.
144 | """
145 | return [name for name in _counters if self._regex.fullmatch(name)]
146 |
147 | def update(self):
148 | r"""Copies current values of the internal counters to the
149 | user-visible state and resets them for the next round.
150 |
151 | If `keep_previous=True` was specified at construction time, the
152 | operation is skipped for statistics that have received no scalars
153 | since the last update, retaining their previous averages.
154 |
155 | This method performs a number of GPU-to-CPU transfers and one
156 | `torch.distributed.all_reduce()`. It is intended to be called
157 | periodically in the main training loop, typically once every
158 | N training steps.
159 | """
160 | if not self._keep_previous:
161 | self._moments.clear()
162 | for name, cumulative in _sync(self.names()):
163 | if name not in self._cumulative:
164 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
165 | delta = cumulative - self._cumulative[name]
166 | self._cumulative[name].copy_(cumulative)
167 | if float(delta[0]) != 0:
168 | self._moments[name] = delta
169 |
170 | def _get_delta(self, name):
171 | r"""Returns the raw moments that were accumulated for the given
172 | statistic between the last two calls to `update()`, or zero if
173 | no scalars were collected.
174 | """
175 | assert self._regex.fullmatch(name)
176 | if name not in self._moments:
177 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
178 | return self._moments[name]
179 |
180 | def num(self, name):
181 | r"""Returns the number of scalars that were accumulated for the given
182 | statistic between the last two calls to `update()`, or zero if
183 | no scalars were collected.
184 | """
185 | delta = self._get_delta(name)
186 | return int(delta[0])
187 |
188 | def mean(self, name):
189 | r"""Returns the mean of the scalars that were accumulated for the
190 | given statistic between the last two calls to `update()`, or NaN if
191 | no scalars were collected.
192 | """
193 | delta = self._get_delta(name)
194 | if int(delta[0]) == 0:
195 | return float('nan')
196 | return float(delta[1] / delta[0])
197 |
198 | def std(self, name):
199 | r"""Returns the standard deviation of the scalars that were
200 | accumulated for the given statistic between the last two calls to
201 | `update()`, or NaN if no scalars were collected.
202 | """
203 | delta = self._get_delta(name)
204 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
205 | return float('nan')
206 | if int(delta[0]) == 1:
207 | return float(0)
208 | mean = float(delta[1] / delta[0])
209 | raw_var = float(delta[2] / delta[0])
210 | return np.sqrt(max(raw_var - np.square(mean), 0))
211 |
212 | def as_dict(self):
213 | r"""Returns the averages accumulated between the last two calls to
214 | `update()` as a `dnnlib.EasyDict`. The contents are as follows:
215 |
216 | dnnlib.EasyDict(
217 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
218 | ...
219 | )
220 | """
221 | stats = dnnlib.EasyDict()
222 | for name in self.names():
223 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
224 | return stats
225 |
226 | def __getitem__(self, name):
227 | r"""Convenience getter.
228 | `collector[name]` is a synonym for `collector.mean(name)`.
229 | """
230 | return self.mean(name)
231 |
232 | #----------------------------------------------------------------------------
233 |
234 | def _sync(names):
235 | r"""Synchronize the global cumulative counters across devices and
236 | processes. Called internally by `Collector.update()`.
237 | """
238 | if len(names) == 0:
239 | return []
240 | global _sync_called
241 | _sync_called = True
242 |
243 | # Collect deltas within current rank.
244 | deltas = []
245 | device = _sync_device if _sync_device is not None else torch.device('cpu')
246 | for name in names:
247 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
248 | for counter in _counters[name].values():
249 | delta.add_(counter.to(device))
250 | counter.copy_(torch.zeros_like(counter))
251 | deltas.append(delta)
252 | deltas = torch.stack(deltas)
253 |
254 | # Sum deltas across ranks.
255 | if _sync_device is not None:
256 | torch.distributed.all_reduce(deltas)
257 |
258 | # Update cumulative values.
259 | deltas = deltas.cpu()
260 | for idx, name in enumerate(names):
261 | if name not in _cumulative:
262 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
263 | _cumulative[name].add_(deltas[idx])
264 |
265 | # Return name-value pairs.
266 | return [(name, _cumulative[name]) for name in names]
267 |
268 | #----------------------------------------------------------------------------
269 |
--------------------------------------------------------------------------------
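A minimal single-process sketch of the API above (the statistic names and batch sizes are illustrative): scalars are accumulated with `report()`/`report0()`, periodically snapshotted with `Collector.update()`, and then queried with `num()`, `mean()`, and `std()`:

import torch
from torch_utils import training_stats

collector = training_stats.Collector(regex='Loss/.*')   # collect only names matching 'Loss/.*'

for step in range(100):
    fake_loss = torch.randn([8])                         # stand-in for a per-sample loss tensor
    training_stats.report('Loss/G/loss', fake_loss)      # accumulate moments for this name
    training_stats.report0('Progress/step', step)        # only rank 0 contributes this value

    if (step + 1) % 20 == 0:
        collector.update()                               # snapshot counters and reset them
        print(f"step {step + 1}: "
              f"mean={collector.mean('Loss/G/loss'):.3f}, "
              f"std={collector.std('Loss/G/loss'):.3f}, "
              f"n={collector.num('Loss/G/loss')}")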
/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/training/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import numpy as np
11 | import zipfile
12 | import PIL.Image
13 | import json
14 | import torch
15 | import dnnlib
16 |
17 | try:
18 | import pyspng
19 | except ImportError:
20 | pyspng = None
21 |
22 | #----------------------------------------------------------------------------
23 |
24 | class Dataset(torch.utils.data.Dataset):
25 | def __init__(self,
26 | name, # Name of the dataset.
27 | raw_shape, # Shape of the raw image data (NCHW).
28 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
29 | use_labels = False, # Enable conditioning labels? False = label dimension is zero.
30 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
31 | random_seed = 0, # Random seed to use when applying max_size.
32 | ):
33 | self._name = name
34 | self._raw_shape = list(raw_shape)
35 | self._use_labels = use_labels
36 | self._raw_labels = None
37 | self._label_shape = None
38 |
39 | # Apply max_size.
40 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
41 | if (max_size is not None) and (self._raw_idx.size > max_size):
42 | np.random.RandomState(random_seed).shuffle(self._raw_idx)
43 | self._raw_idx = np.sort(self._raw_idx[:max_size])
44 |
45 | # Apply xflip.
46 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
47 | if xflip:
48 | self._raw_idx = np.tile(self._raw_idx, 2)
49 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
50 |
51 | def _get_raw_labels(self):
52 | if self._raw_labels is None:
53 | self._raw_labels = self._load_raw_labels() if self._use_labels else None
54 | if self._raw_labels is None:
55 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
56 | assert isinstance(self._raw_labels, np.ndarray)
57 | assert self._raw_labels.shape[0] == self._raw_shape[0]
58 | assert self._raw_labels.dtype in [np.float32, np.int64]
59 | if self._raw_labels.dtype == np.int64:
60 | assert self._raw_labels.ndim == 1
61 | assert np.all(self._raw_labels >= 0)
62 | return self._raw_labels
63 |
64 | def close(self): # to be overridden by subclass
65 | pass
66 |
67 | def _load_raw_image(self, raw_idx): # to be overridden by subclass
68 | raise NotImplementedError
69 |
70 | def _load_raw_labels(self): # to be overridden by subclass
71 | raise NotImplementedError
72 |
73 | def __getstate__(self):
74 | return dict(self.__dict__, _raw_labels=None)
75 |
76 | def __del__(self):
77 | try:
78 | self.close()
79 | except:
80 | pass
81 |
82 | def __len__(self):
83 | return self._raw_idx.size
84 |
85 | def __getitem__(self, idx):
86 | image = self._load_raw_image(self._raw_idx[idx])
87 | assert isinstance(image, np.ndarray)
88 | assert list(image.shape) == self.image_shape
89 | assert image.dtype == np.uint8
90 | if self._xflip[idx]:
91 | assert image.ndim == 3 # CHW
92 | image = image[:, :, ::-1]
93 | return image.copy(), self.get_label(idx)
94 |
95 | def get_label(self, idx):
96 | label = self._get_raw_labels()[self._raw_idx[idx]]
97 | if label.dtype == np.int64:
98 | onehot = np.zeros(self.label_shape, dtype=np.float32)
99 | onehot[label] = 1
100 | label = onehot
101 | return label.copy()
102 |
103 | def get_details(self, idx):
104 | d = dnnlib.EasyDict()
105 | d.raw_idx = int(self._raw_idx[idx])
106 | d.xflip = (int(self._xflip[idx]) != 0)
107 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
108 | return d
109 |
110 | @property
111 | def name(self):
112 | return self._name
113 |
114 | @property
115 | def image_shape(self):
116 | return list(self._raw_shape[1:])
117 |
118 | @property
119 | def num_channels(self):
120 | assert len(self.image_shape) == 3 # CHW
121 | return self.image_shape[0]
122 |
123 | @property
124 | def resolution(self):
125 | assert len(self.image_shape) == 3 # CHW
126 | assert self.image_shape[1] == self.image_shape[2]
127 | return self.image_shape[1]
128 |
129 | @property
130 | def label_shape(self):
131 | if self._label_shape is None:
132 | raw_labels = self._get_raw_labels()
133 | if raw_labels.dtype == np.int64:
134 | self._label_shape = [int(np.max(raw_labels)) + 1]
135 | else:
136 | self._label_shape = raw_labels.shape[1:]
137 | return list(self._label_shape)
138 |
139 | @property
140 | def label_dim(self):
141 | assert len(self.label_shape) == 1
142 | return self.label_shape[0]
143 |
144 | @property
145 | def has_labels(self):
146 | return any(x != 0 for x in self.label_shape)
147 |
148 | @property
149 | def has_onehot_labels(self):
150 | return self._get_raw_labels().dtype == np.int64
151 |
152 | #----------------------------------------------------------------------------
153 |
154 | class ImageFolderDataset(Dataset):
155 | def __init__(self,
156 | path, # Path to directory or zip.
157 | resolution = None, # Ensure specific resolution, None = highest available.
158 | **super_kwargs, # Additional arguments for the Dataset base class.
159 | ):
160 | self._path = path
161 | self._zipfile = None
162 |
163 | if os.path.isdir(self._path):
164 | self._type = 'dir'
165 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
166 | elif self._file_ext(self._path) == '.zip':
167 | self._type = 'zip'
168 | self._all_fnames = set(self._get_zipfile().namelist())
169 | else:
170 | raise IOError('Path must point to a directory or zip')
171 |
172 | PIL.Image.init()
173 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
174 | if len(self._image_fnames) == 0:
175 | raise IOError('No image files found in the specified path')
176 |
177 | name = os.path.splitext(os.path.basename(self._path))[0]
178 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
179 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
180 | raise IOError('Image files do not match the specified resolution')
181 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
182 |
183 | @staticmethod
184 | def _file_ext(fname):
185 | return os.path.splitext(fname)[1].lower()
186 |
187 | def _get_zipfile(self):
188 | assert self._type == 'zip'
189 | if self._zipfile is None:
190 | self._zipfile = zipfile.ZipFile(self._path)
191 | return self._zipfile
192 |
193 | def _open_file(self, fname):
194 | if self._type == 'dir':
195 | return open(os.path.join(self._path, fname), 'rb')
196 | if self._type == 'zip':
197 | return self._get_zipfile().open(fname, 'r')
198 | return None
199 |
200 | def close(self):
201 | try:
202 | if self._zipfile is not None:
203 | self._zipfile.close()
204 | finally:
205 | self._zipfile = None
206 |
207 | def __getstate__(self):
208 | return dict(super().__getstate__(), _zipfile=None)
209 |
210 | def _load_raw_image(self, raw_idx):
211 | fname = self._image_fnames[raw_idx]
212 | with self._open_file(fname) as f:
213 | if pyspng is not None and self._file_ext(fname) == '.png':
214 | image = pyspng.load(f.read())
215 | else:
216 | image = np.array(PIL.Image.open(f))
217 | if image.ndim == 2:
218 | image = image[:, :, np.newaxis] # HW => HWC
219 | image = image.transpose(2, 0, 1) # HWC => CHW
220 | return image
221 |
222 | def _load_raw_labels(self):
223 | fname = 'dataset.json'
224 | if fname not in self._all_fnames:
225 | return None
226 | with self._open_file(fname) as f:
227 | labels = json.load(f)['labels']
228 | if labels is None:
229 | return None
230 | labels = dict(labels)
231 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
232 | labels = np.array(labels)
233 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
234 | return labels
235 |
236 | #----------------------------------------------------------------------------
237 |
--------------------------------------------------------------------------------
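For reference, `ImageFolderDataset` plugs straight into a standard PyTorch `DataLoader`. The sketch below assumes a hypothetical `datasets/ffhq-256x256.zip` archive of square images; any directory or zip of images would work the same way:

import torch
from training.dataset import ImageFolderDataset

dataset = ImageFolderDataset(
    path='datasets/ffhq-256x256.zip',   # hypothetical path: a directory or a zip of images
    use_labels=False,                   # no dataset.json labels assumed
    xflip=True,                         # double the dataset with horizontal flips
)
print(len(dataset), dataset.resolution, dataset.num_channels)

loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2)
for images, labels in loader:
    # images: uint8 tensor of shape [batch, C, H, W]; labels: float32 of shape [batch, label_dim]
    images = images.to(torch.float32) / 127.5 - 1       # common normalization to [-1, 1]
    break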
/training/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import numpy as np
10 | import torch
11 | from torch_utils import training_stats
12 | from torch_utils import misc
13 | from torch_utils.ops import conv2d_gradfix
14 |
15 | #----------------------------------------------------------------------------
16 |
17 | class Loss:
18 | def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain): # to be overridden by subclass
19 | raise NotImplementedError()
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | class StyleGAN2Loss(Loss):
24 | def __init__(self, device, G_mapping, G_synthesis, D, augment_pipe=None, style_mixing_prob=0.9, r1_gamma=10, pl_batch_shrink=2, pl_decay=0.01, pl_weight=2):
25 | super().__init__()
26 | self.device = device
27 | self.G_mapping = G_mapping
28 | self.G_synthesis = G_synthesis
29 | self.D = D
30 | self.augment_pipe = augment_pipe
31 | self.style_mixing_prob = style_mixing_prob
32 | self.r1_gamma = r1_gamma
33 | self.pl_batch_shrink = pl_batch_shrink
34 | self.pl_decay = pl_decay
35 | self.pl_weight = pl_weight
36 | self.pl_mean = torch.zeros([], device=device)
37 |
38 | def run_G(self, z, c, sync):
39 | with misc.ddp_sync(self.G_mapping, sync):
40 | ws = self.G_mapping(z, c)
41 | if self.style_mixing_prob > 0:
42 | with torch.autograd.profiler.record_function('style_mixing'):
43 | cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
44 | cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
45 | ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True)[:, cutoff:]
46 | with misc.ddp_sync(self.G_synthesis, sync):
47 | img = self.G_synthesis(ws)
48 | return img, ws
49 |
50 | def run_D(self, img, c, sync):
51 | if self.augment_pipe is not None:
52 | img = self.augment_pipe(img)
53 | with misc.ddp_sync(self.D, sync):
54 | logits = self.D(img, c)
55 | return logits
56 |
57 | def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain):
58 | assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
59 | do_Gmain = (phase in ['Gmain', 'Gboth'])
60 | do_Dmain = (phase in ['Dmain', 'Dboth'])
61 | do_Gpl = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0)
62 | do_Dr1 = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)
63 |
64 | # Gmain: Maximize logits for generated images.
65 | if do_Gmain:
66 | with torch.autograd.profiler.record_function('Gmain_forward'):
67 | gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=(sync and not do_Gpl)) # May get synced by Gpl.
68 | gen_logits = self.run_D(gen_img, gen_c, sync=False)
69 | training_stats.report('Loss/scores/fake', gen_logits)
70 | training_stats.report('Loss/signs/fake', gen_logits.sign())
71 | loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
72 | training_stats.report('Loss/G/loss', loss_Gmain)
73 | with torch.autograd.profiler.record_function('Gmain_backward'):
74 | loss_Gmain.mean().mul(gain).backward()
75 |
76 | # Gpl: Apply path length regularization.
77 | if do_Gpl:
78 | with torch.autograd.profiler.record_function('Gpl_forward'):
79 | batch_size = gen_z.shape[0] // self.pl_batch_shrink
80 | gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync=sync)
81 | pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
82 | with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients():
83 | pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
84 | pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
85 | pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
86 | self.pl_mean.copy_(pl_mean.detach())
87 | pl_penalty = (pl_lengths - pl_mean).square()
88 | training_stats.report('Loss/pl_penalty', pl_penalty)
89 | loss_Gpl = pl_penalty * self.pl_weight
90 | training_stats.report('Loss/G/reg', loss_Gpl)
91 | with torch.autograd.profiler.record_function('Gpl_backward'):
92 | (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward()
93 |
94 | # Dmain: Minimize logits for generated images.
95 | loss_Dgen = 0
96 | if do_Dmain:
97 | with torch.autograd.profiler.record_function('Dgen_forward'):
98 | gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False)
99 | gen_logits = self.run_D(gen_img, gen_c, sync=False) # Gets synced by loss_Dreal.
100 | training_stats.report('Loss/scores/fake', gen_logits)
101 | training_stats.report('Loss/signs/fake', gen_logits.sign())
102 | loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
103 | with torch.autograd.profiler.record_function('Dgen_backward'):
104 | loss_Dgen.mean().mul(gain).backward()
105 |
106 | # Dmain: Maximize logits for real images.
107 | # Dr1: Apply R1 regularization.
108 | if do_Dmain or do_Dr1:
109 | name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
110 | with torch.autograd.profiler.record_function(name + '_forward'):
111 | real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
112 | real_logits = self.run_D(real_img_tmp, real_c, sync=sync)
113 | training_stats.report('Loss/scores/real', real_logits)
114 | training_stats.report('Loss/signs/real', real_logits.sign())
115 |
116 | loss_Dreal = 0
117 | if do_Dmain:
118 | loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
119 | training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)
120 |
121 | loss_Dr1 = 0
122 | if do_Dr1:
123 | with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
124 | r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
125 | r1_penalty = r1_grads.square().sum([1,2,3])
126 | loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
127 | training_stats.report('Loss/r1_penalty', r1_penalty)
128 | training_stats.report('Loss/D/reg', loss_Dr1)
129 |
130 | with torch.autograd.profiler.record_function(name + '_backward'):
131 | (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward()
132 |
133 | #----------------------------------------------------------------------------
134 |
--------------------------------------------------------------------------------
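To make the phase logic above concrete, here is a rough sketch of how `StyleGAN2Loss.accumulate_gradients()` is typically driven with lazy regularization. The `loss` instance, optimizers, minibatch tensors, and regularization intervals are assumed placeholders, not the repository's actual training loop:

def training_step(step, loss, opt_G, opt_D, real_img, real_c, gen_z, gen_c):
    """Hypothetical driver for StyleGAN2Loss.accumulate_gradients().

    `loss` is a StyleGAN2Loss instance; opt_G/opt_D are torch optimizers over the
    generator and discriminator parameters; the remaining arguments are the current
    minibatch tensors.
    """
    phases = [
        ('Gmain', opt_G, 1),    # generator adversarial loss, every step
        ('Greg',  opt_G, 4),    # path-length regularization, every 4th step
        ('Dmain', opt_D, 1),    # discriminator adversarial loss, every step
        ('Dreg',  opt_D, 16),   # R1 regularization, every 16th step
    ]
    for name, opt, interval in phases:
        if step % interval != 0:
            continue                                    # lazy regularization: skip most steps
        opt.zero_grad(set_to_none=True)
        loss.accumulate_gradients(phase=name, real_img=real_img, real_c=real_c,
                                  gen_z=gen_z, gen_c=gen_c, sync=True,
                                  gain=interval)        # gain compensates for skipped steps
        opt.step()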