├── LICENSE
├── README.md
├── _config.yaml
├── align.py
├── app.py
├── components
│   ├── all.pt
│   ├── b1024_conv0.pt
│   ├── b1024_conv1.pt
│   ├── b1024_torgb.pt
│   ├── b128_conv0.pt
│   ├── b128_conv1.pt
│   ├── b128_torgb.pt
│   ├── b16_conv0.pt
│   ├── b16_conv1.pt
│   ├── b16_torgb.pt
│   ├── b256_conv0.pt
│   ├── b256_conv1.pt
│   ├── b256_torgb.pt
│   ├── b32_conv0.pt
│   ├── b32_conv1.pt
│   ├── b32_torgb.pt
│   ├── b4_conv1.pt
│   ├── b4_torgb.pt
│   ├── b512_conv0.pt
│   ├── b512_conv1.pt
│   ├── b512_torgb.pt
│   ├── b64_conv0.pt
│   ├── b64_conv1.pt
│   ├── b64_torgb.pt
│   ├── b8_conv0.pt
│   ├── b8_conv1.pt
│   └── b8_torgb.pt
├── config.py
├── directions
│   ├── age.npy
│   ├── beard.npy
│   ├── empirical_glasses.npy
│   ├── eye_distance.npy
│   ├── eye_eyebrow_distance.npy
│   ├── eye_ratio.npy
│   ├── eyes_open.npy
│   ├── gender.npy
│   ├── light.npy
│   ├── lip_ratio.npy
│   ├── mouth_open.npy
│   ├── mouth_ratio.npy
│   ├── nose_mouth_distance.npy
│   ├── nose_ratio.npy
│   ├── nose_tip.npy
│   ├── pitch.npy
│   ├── roll.npy
│   ├── smile_styleganv2.npy
│   └── yaw.npy
├── dnnlib
│   ├── __init__.py
│   └── util.py
├── extensions
│   ├── canvas_to_masks.cpp
│   └── heatmap.cpp
├── mask.py
├── mask_refinement.py
├── qtutil.py
├── requirements.txt
├── resources.py
├── resources
│   └── iconS.png
├── s_presets.py
├── sample_image
│   └── input.jpg
├── scripts
│   ├── __init__.py
│   ├── extract_components.py
│   └── idempotent_blend.py
├── setup_cpp_ext.py
├── styleclip_mapper.py
├── styleclip_presets.py
├── stylegan_legacy.py
├── stylegan_networks.py
├── stylegan_project.py
├── stylegan_tune.py
├── synthesis.py
├── torch_utils
│   ├── __init__.py
│   ├── custom_ops.py
│   ├── misc.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── bias_act.cpp
│   │   ├── bias_act.cu
│   │   ├── bias_act.h
│   │   ├── bias_act.py
│   │   ├── conv2d_gradfix.py
│   │   ├── conv2d_resample.py
│   │   ├── fma.py
│   │   ├── grid_sample_gradfix.py
│   │   ├── upfirdn2d.cpp
│   │   ├── upfirdn2d.cu
│   │   ├── upfirdn2d.h
│   │   └── upfirdn2d.py
│   ├── persistence.py
│   └── training_stats.py
├── w_directions.py
└── widgets
    ├── __init__.py
    ├── editor.py
    ├── mask_painter.py
    └── workspace.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 David Futschik
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Chunkmogrify: Real Image Inversion via Segments
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | A teaser video with live editing sessions can be found here.
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | This code demonstrates the ideas discussed in arXiv submission *Real Image Inversion via Segments*.
20 | http://arxiv.org/abs/2110.06269
21 | (David Futschik¹, Michal Lukáč², Eli Shechtman², Daniel Sýkora¹)
22 |
23 | ¹Czech Technical University in Prague
24 | ²Adobe Research
25 |
26 |
27 | *Abstract:
28 | We present a simple, yet effective approach to editing
29 | real images via generative adversarial networks (GAN). Unlike previous
30 | techniques that treat all editing tasks as an operation affecting pixel
31 | values in the entire image, in our approach we cut up the image into a set of
32 | smaller segments. For those segments, corresponding latent codes of a generative
33 | network can be estimated with greater accuracy due to the lower number of
34 | constraints. When codes are altered by the user, the content in the image is
35 | manipulated locally while the rest of it remains unaffected. Thanks to this
36 | property, the final edited image better retains the original structures and thus
37 | helps to preserve a natural look.*
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 | ## What do I need?
54 | You will need a local machine with a relatively recent GPU - I wouldn't recommend trying
55 | Chunkmogrify with anything older than an RTX 2080. It is technically possible to run it even on a CPU,
56 | but the operations become so slow that the user experience is not enjoyable.
57 |
58 | ## Quick startup guide
59 | Requirements:
60 | Python 3.7 or newer
61 |
62 | Note: If you are using Anaconda, I recommend creating a new environment to run this project.
63 | Packages installed with conda and pip often don't play together very nicely.
64 |
65 | Steps to be able to successfully run the project:
66 | 1) Clone or download the repository and open a terminal / Powershell instance in the directory.
67 | 2) Install the required Python packages by running `pip install -r requirements.txt`. This
68 | might take a while, since it will download several packages totaling a few hundred MB of data.
69 | Some packages might need to compile their extensions (as does this project itself), so a C++
70 | compiler needs to be present. On Linux, this is typically not an issue, but running on Windows might
71 | require Visual Studio and CUDA installations to successfully set up the project.
72 | 3) Run `python app.py`. When running for the first time, it will automatically download required
73 | resources, which are also several hundred megabytes. Progression of the download can be monitored
74 | in the command line window.
75 |
76 | To see if everything installed and configured properly, load up a photo and try running a projection
77 | step. If there are no errors, you are good to go.
78 |
79 |
80 | ### Possible problems:
81 | *Torch not compiled with CUDA enabled.*
82 | Run
83 | ```
84 | pip uninstall torch
85 | pip cache purge
86 | pip install torch -f https://download.pytorch.org/whl/torch_stable.html
87 | ```
88 |
89 | ## Explanation of usage
90 |
91 |
92 | Tutorial video: click below
93 |
94 |
95 |
96 |
97 |
98 | Open an image using `File -> Image from File`. There is a sample image provided to check
99 | functionality.
100 |
101 | Mask painting:
102 | Left click paints, right click unpaints. Mouse wheel controls the size of the brush.
103 |
104 | Projection:
105 | Input a number of steps (100 or 200 is OK; 500 is currently the maximum before the LR decays to 0)
106 | and press `Projection Steps`. Wait until projection finishes; you can observe the global image view
107 | by choosing the output mode `Projection Only` during this process. To fine-tune, you can perform a
108 | small number of `Pivotal Tuning` steps.
109 |
110 | Editing:
111 | To add an edit, click the double arrow down icon in the Attribute Editor on the left side. Choose
112 | the type of edit (W, S, Styleclip), the direction of the edit, and drag the sliders to change the
113 | currently masked region. Usually it's necessary to increase the `multiplier` before noticeable
114 | changes are reflected via the `direction` slider.
115 |
116 | Multiple different edits can be composed on top of each other at the same time. Their order
117 | is largely irrelevant. Currently in the default mode, only one region is being edited, and so
118 | all selected edits apply to the same region. If you would like to change the region, you can
119 | `Freeze` the current image, and perform a new projection, but you will lose the ability to change
120 | existing edits.
121 |
122 | To save the current image, click the `Save Current Image` button. If the `Unalign` checkbox is
123 | active, the program will attempt to compose the aligned face back into the original image. Saved
124 | images can be found in the `SavedImages` directory by default. This can be changed in `_config.yaml`.
125 |
126 | ## Keyboard shortcuts
127 |
128 | Current keyboard shortcuts include:
129 |
130 | Show/Hide mask :: Alt+M
131 | Toggle mask painting :: Alt+N
132 |
133 | ## W-space editing
134 | Source for some of the basic directions:
135 | (https://twitter.com/robertluxemburg/status/1207087801344372736)
136 |
137 | To add your own directions, save them in numpy format with shape (num_ws, 512) or (1, 512)
138 | and specify their path in `w_directions.py`, as shown in the sketch below.
139 |
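A minimal sketch of exporting a custom direction (the file name and the random
vector are illustrative placeholders for a direction you actually computed):

```
import numpy as np

# (1, 512) applies one offset to all W layers; (num_ws, 512) allows
# per-layer control (num_ws is 18 for the 1024x1024 generator).
direction = np.random.randn(1, 512).astype(np.float32)
np.save("directions/my_direction.npy", direction)
```
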
140 | ## Style-space editing (S space edits)
141 | Source:
142 | StyleSpace Analysis: Disentangled Controls for StyleGAN Image Generation
143 | (https://arxiv.org/abs/2011.12799)
144 | (https://github.com/betterze/StyleSpace)
145 |
146 | The presets can be found in `s_presets.py`, some were taken directly from the paper, others
147 | I found by manual exploration. You can perform similar exploration by choosing the `Custom`
148 | preset once you have a projection.
149 |
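Each preset is a `(layer, channel)` pair indexing a single style-space channel.
A quick sanity check of that convention, using the shipped presets (a sketch):

```
from s_presets import known_presets, limits

layer, channel = known_presets()['smile']     # (6, 259)
assert layer < limits()['layer']              # 26 affine layers for x1024
assert channel < limits()['channel'][layer]   # channel fits in that layer
```
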
150 | ## StyleCLIP editing
151 | Source:
152 | StyleCLIP: Text-Driven Manipulation of StyleGAN Imagery
153 | (https://arxiv.org/abs/2103.17249)
154 | (https://github.com/orpatashnik/StyleCLIP)
155 |
156 | Pretrained models were taken from (https://github.com/orpatashnik/StyleCLIP/blob/main/utils.py), with
157 | the decoder manually removed from the state dict, since it is not used and takes up the majority of
158 | the file size.
159 |
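The stripping was roughly equivalent to the following (a sketch; the checkpoint
name and the exact key layout of the original files are assumptions):

```
import torch

ckpt = torch.load("mohawk_full.pt", map_location="cpu")
state = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
# Keep only mapper weights; the decoder is unused here and dominates the size.
state = {k: v for k, v in state.items() if not k.startswith("decoder.")}
torch.save(state, "resources/styleclip/mohawk.pt")
```
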
160 | ## PTI Optimization
161 | Source:
162 | Pivotal Tuning for Latent-based Editing of Real Images
163 | (https://arxiv.org/abs/2106.05744)
164 |
165 | This method allows you to match the target photo very closely, while retaining editing capabilities.
166 |
167 | It's often good to run 30-50 iterations of PTI to get a very close match to the source image,
168 | which won't cause a very noticeable drop in the editing capabilities.
169 |
170 |
171 | ## Attribution
172 | This repository makes use of code provided by the various repositories linked above, plus
173 | additional code from:
174 |
175 | stylegan2-ada-pytorch (https://github.com/NVlabs/stylegan2-ada-pytorch)
176 | poisson-image-editing (https://github.com/PPPW/poisson-image-editing) for optional support
177 | of idempotent blending (a slow implementation of blending that only changes the masked part,
178 | which can be enabled by uncommenting the option in `synthesis.py`)
179 |
180 | ## Citation
181 |
182 | If you find this code useful for your research, please cite the arXiv submission linked above.
183 |
--------------------------------------------------------------------------------
/_config.yaml:
--------------------------------------------------------------------------------
1 |
2 | # Initial w to optimize from. Can be set to ~ for None.
3 | initial_w: ~
4 | # Skip auto alignment of the images. Only use this if you already have aligned images.
5 | skip_alignment: False
6 |
7 | generator_path: resources/ffhq.pkl
8 | # generator_path: pti_resources/model_merkel.pt
9 | # Perform plain torch.load on the pkl, otherwise look for G_ema.
10 | generator_load_raw: False
11 | # Resolution of the trained generator.
12 | generator_native_resolution: [1024, 1024]
13 |
14 | # Default projection mode.
15 | projection_mode: w_projection
16 | # Projection mode arguments.
17 | projection_args:
18 | lr_init: 1.0e-1
19 | l2_loss_weight: 0
20 | l1_loss_weight: 0.
21 | noise_regularize_weight: 0. # 10000.
22 | mean_latent_loss_weight: 10.
23 | percept_downsample: 0.5
24 |
25 | # Set this to roughly 0.5 - 1 if you want faster projection at the cost of less frequent UI updates.
26 | minimum_projection_update_window: 0.1
27 |
28 | # Use this device for torch.
29 | device: cuda:0
30 | # Don't change this unless you want to do multimask descent.
31 | max_segments: 1
32 | # Save exported images here.
33 | export_directory: SavedImages
34 |
35 | # Skip loading some resources for high performance startup.
36 | ui_debug_run: False
37 | show_debug_menu: False
38 |
--------------------------------------------------------------------------------
/components/all.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/all.pt
--------------------------------------------------------------------------------
/components/b1024_conv0.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b1024_conv0.pt
--------------------------------------------------------------------------------
/components/b1024_conv1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b1024_conv1.pt
--------------------------------------------------------------------------------
/components/b1024_torgb.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b1024_torgb.pt
--------------------------------------------------------------------------------
/components/b128_conv0.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b128_conv0.pt
--------------------------------------------------------------------------------
/components/b128_conv1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b128_conv1.pt
--------------------------------------------------------------------------------
/components/b128_torgb.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b128_torgb.pt
--------------------------------------------------------------------------------
/components/b16_conv0.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b16_conv0.pt
--------------------------------------------------------------------------------
/components/b16_conv1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b16_conv1.pt
--------------------------------------------------------------------------------
/components/b16_torgb.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b16_torgb.pt
--------------------------------------------------------------------------------
/components/b256_conv0.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b256_conv0.pt
--------------------------------------------------------------------------------
/components/b256_conv1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b256_conv1.pt
--------------------------------------------------------------------------------
/components/b256_torgb.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b256_torgb.pt
--------------------------------------------------------------------------------
/components/b32_conv0.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b32_conv0.pt
--------------------------------------------------------------------------------
/components/b32_conv1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b32_conv1.pt
--------------------------------------------------------------------------------
/components/b32_torgb.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b32_torgb.pt
--------------------------------------------------------------------------------
/components/b4_conv1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b4_conv1.pt
--------------------------------------------------------------------------------
/components/b4_torgb.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b4_torgb.pt
--------------------------------------------------------------------------------
/components/b512_conv0.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b512_conv0.pt
--------------------------------------------------------------------------------
/components/b512_conv1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b512_conv1.pt
--------------------------------------------------------------------------------
/components/b512_torgb.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b512_torgb.pt
--------------------------------------------------------------------------------
/components/b64_conv0.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b64_conv0.pt
--------------------------------------------------------------------------------
/components/b64_conv1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b64_conv1.pt
--------------------------------------------------------------------------------
/components/b64_torgb.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b64_torgb.pt
--------------------------------------------------------------------------------
/components/b8_conv0.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b8_conv0.pt
--------------------------------------------------------------------------------
/components/b8_conv1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b8_conv1.pt
--------------------------------------------------------------------------------
/components/b8_torgb.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b8_torgb.pt
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | #
2 | # Author: David Futschik
3 | # Provided as part of the Chunkmogrify project, 2021.
4 | #
5 |
6 | import yaml
7 |
8 | class dotdict(dict):
9 | __getattr__ = dict.__getitem__
10 | __setattr__ = dict.__setitem__
11 | __delattr__ = dict.__delitem__
12 |
13 | def has_set(self, attr):
14 | return attr in self and self[attr] is not None
15 |
16 | def _into_deep_dotdict(regular_dict):
17 | new_dict = dotdict(regular_dict)
18 | for k, v in regular_dict.items():
19 | if type(v) == dict:
20 | new_dict[k] = _into_deep_dotdict(v)
21 | return new_dict
22 |
23 | def _load_config(path):
24 | with open(path) as fs:
25 | loaded = yaml.safe_load(fs)
26 | return _into_deep_dotdict(loaded)
27 |
28 | config = None
29 | def global_config():
30 | global config
31 | if config is None:
32 | config = _load_config("_config.yaml")
33 | return config
--------------------------------------------------------------------------------
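How the rest of the app consumes this module (a minimal sketch, assuming the
`_config.yaml` shown above sits in the working directory):

```
from config import global_config

cfg = global_config()                 # parsed once, cached afterwards
print(cfg.device)                     # "cuda:0"
print(cfg.projection_args.lr_init)    # nested maps come back as dotdicts too
if not cfg.has_set("initial_w"):      # "~" in YAML loads as None
    print("initial_w not set")
```
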
/directions/age.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/age.npy
--------------------------------------------------------------------------------
/directions/beard.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/beard.npy
--------------------------------------------------------------------------------
/directions/empirical_glasses.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/empirical_glasses.npy
--------------------------------------------------------------------------------
/directions/eye_distance.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/eye_distance.npy
--------------------------------------------------------------------------------
/directions/eye_eyebrow_distance.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/eye_eyebrow_distance.npy
--------------------------------------------------------------------------------
/directions/eye_ratio.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/eye_ratio.npy
--------------------------------------------------------------------------------
/directions/eyes_open.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/eyes_open.npy
--------------------------------------------------------------------------------
/directions/gender.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/gender.npy
--------------------------------------------------------------------------------
/directions/light.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/light.npy
--------------------------------------------------------------------------------
/directions/lip_ratio.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/lip_ratio.npy
--------------------------------------------------------------------------------
/directions/mouth_open.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/mouth_open.npy
--------------------------------------------------------------------------------
/directions/mouth_ratio.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/mouth_ratio.npy
--------------------------------------------------------------------------------
/directions/nose_mouth_distance.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/nose_mouth_distance.npy
--------------------------------------------------------------------------------
/directions/nose_ratio.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/nose_ratio.npy
--------------------------------------------------------------------------------
/directions/nose_tip.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/nose_tip.npy
--------------------------------------------------------------------------------
/directions/pitch.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/pitch.npy
--------------------------------------------------------------------------------
/directions/roll.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/roll.npy
--------------------------------------------------------------------------------
/directions/smile_styleganv2.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/smile_styleganv2.npy
--------------------------------------------------------------------------------
/directions/yaw.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/yaw.npy
--------------------------------------------------------------------------------
/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | from .util import EasyDict, make_cache_dir_path
10 |
--------------------------------------------------------------------------------
/extensions/canvas_to_masks.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | Author: David Futschik
3 | Provided as part of the Chunkmogrify project, 2021.
4 | */
5 |
6 | #define PY_SSIZE_T_CLEAN
7 | #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
8 |
9 |
10 | #include <Python.h>
11 | #include <numpy/arrayobject.h>
12 | #include <cstdint>
13 | #include <cstdio>
14 | #include <cstring>
15 | #include <cmath>
16 | #include <string>
17 | #include <vector>
18 | #include <iostream>
19 | #include <algorithm>
20 |
21 |
22 | static PyObject* canvas_to_masks(PyObject* self, PyObject* args, PyObject* kwargs) {
23 |
24 | const char* kwarg_names[] = { "canvas", "colors", "output_buffer", NULL };
25 |
26 | PyArrayObject* canvas = nullptr;
27 | PyArrayObject* colors = nullptr;
28 | PyArrayObject* output = nullptr;
29 | if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!|$O!", const_cast<char**>(kwarg_names),
30 | &PyArray_Type, &canvas,
31 | &PyArray_Type, &colors,
32 | &PyArray_Type, &output)) {
33 | return NULL;
34 | }
35 |
36 | if (PyArray_NDIM(canvas) != 3 or PyArray_NDIM(colors) != 2) {
37 | PyErr_SetString(PyExc_ValueError, "a.ndim must be 3 and b.ndim must be 2.");
38 | return NULL;
39 | }
40 |
41 | int h, w, c, num_color;
42 | h = PyArray_DIM(canvas, 0);
43 | w = PyArray_DIM(canvas, 1);
44 | c = PyArray_DIM(canvas, 2);
45 | num_color = PyArray_DIM(colors, 0);
46 |
47 | int dtype_a, dtype_b;
48 | dtype_a = PyArray_TYPE(canvas);
49 | dtype_b = PyArray_TYPE(colors);
50 |
51 | if (c != PyArray_DIM(colors, 1)) {
52 | PyErr_SetString(PyExc_ValueError, "a.shape[2] != b.shape[1]");
53 | return NULL;
54 | }
55 |
56 | if (dtype_a != dtype_b or dtype_a != NPY_UINT8) {
57 | PyErr_SetString(PyExc_ValueError, "dtype of both arrays must be uint8.");
58 | return NULL;
59 | }
60 |
61 | bool output_wrong_array = [output, h, w, num_color]() {
62 | if (not output) return true;
63 | if (PyArray_TYPE(output) != NPY_FLOAT) return true;
64 | if (PyArray_NDIM(output) != 3) return true;
65 | if (PyArray_DIM(output, 0) != h) return true;
66 | if (PyArray_DIM(output, 1) != w) return true;
67 | if (PyArray_DIM(output, 2) != num_color) return true;
68 | return false;
69 | }();
70 |
71 | if (!output or output_wrong_array) {
72 | // alloc new
73 | auto descr = PyArray_DescrFromType(NPY_FLOAT);
74 | npy_intp dims[] = { h, w, num_color };
75 | output = (PyArrayObject*)PyArray_NewFromDescr(&PyArray_Type, descr, /* nd */ 3, dims, NULL, NULL, 0, NULL);
76 | // std::cout << "Sideeffect" << std::endl;
77 | } else {
78 | // incref of output
79 | Py_INCREF(output);
80 | }
81 |
82 | #pragma omp parallel for
83 | for (int y = 0; y < h; y++) {
84 | #pragma omp parallel for
85 | for (int x = 0; x < w; x++) {
86 | uint8_t* px = (uint8_t*)PyArray_GETPTR2(canvas, y, x);
87 | int same_idx = 0;
88 | #pragma unroll(12)
89 | for (int k = 0; k < num_color; k++) {
90 | uint8_t* kolor = (uint8_t*)PyArray_GETPTR1(colors, k);
91 | #pragma unroll(4)
92 | for (int it = 0; it < c; it++) {
93 | if (kolor[it] != px[it]) {
94 | goto cont;
95 | }
96 | }
97 | same_idx = k;
98 | goto brk;
99 | cont: ;
100 | }
101 | brk:
102 | float* o = (float*)PyArray_GETPTR2(output, y, x);
103 | for (int it = 0; it < num_color; it++) {
104 | // set the channel of the correct mask
105 | o[it] = it == same_idx ? 1. : 0.;
106 | }
107 | }
108 | }
109 |
110 | return (PyObject*)output;
111 | }
112 |
113 |
114 | static PyMethodDef python_methods[] = {
115 | { "canvas_to_masks", (PyCFunction)canvas_to_masks, METH_VARARGS|METH_KEYWORDS, "convert canvas colors to mask stack" },
116 | { nullptr, nullptr, 0, nullptr }
117 | };
118 |
119 | static struct PyModuleDef python_module = {
120 | PyModuleDef_HEAD_INIT,
121 | "_C_canvas", // _C_canvas.
122 | nullptr, // documentation
123 | -1,
124 | python_methods
125 | };
126 |
127 | PyMODINIT_FUNC
128 | PyInit__C_canvas(void) {
129 | auto x = PyModule_Create(&python_module);
130 | import_array();
131 | return x;
132 | }
--------------------------------------------------------------------------------
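From Python, the compiled module can be exercised like this (a sketch; shapes
and dtypes follow the checks above, and the extension must first be built via
`setup_cpp_ext.py`):

```
import numpy as np
import _C_canvas

canvas = np.zeros((4, 4, 3), dtype=np.uint8)    # H x W x C painted canvas
canvas[2:, :, :] = (255, 0, 0)                  # bottom half painted red
colors = np.array([[0, 0, 0],                   # K x C palette of mask colors
                   [255, 0, 0]], dtype=np.uint8)

masks = _C_canvas.canvas_to_masks(canvas, colors)
print(masks.shape, masks.dtype)   # (4, 4, 2) float32, one 0/1 plane per color
```
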
/extensions/heatmap.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | Author: David Futschik
3 | Provided as part of the Chunkmogrify project, 2021.
4 | */
5 |
6 | #define PY_SSIZE_T_CLEAN
7 | #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
8 |
9 | #include <Python.h>
10 | #include <numpy/arrayobject.h>
11 | #include <cstdint>
12 | #include <cstdio>
13 | #include <cstring>
14 | #include <cmath>
15 | #include <string>
16 | #include <vector>
17 | #include <iostream>
18 | #include <algorithm>
19 |
20 |
21 | static PyObject* heatmap(PyObject* self, PyObject* args, PyObject* kwargs) {
22 |
23 | const char* kwarg_names[] = { "values", "vmin", "vmax", NULL };
24 |
25 | PyArrayObject* canvas = nullptr;
26 | double vmin = 0;
27 | double vmax = 0;
28 |
29 | if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!dd|", const_cast<char**>(kwarg_names), &PyArray_Type, &canvas, &vmin, &vmax)) {
30 | return NULL;
31 | }
32 |
33 | if (PyArray_NDIM(canvas) != 3 and PyArray_NDIM(canvas) != 2) {
34 | PyErr_SetString(PyExc_ValueError, "a.ndim must be 3 or 2.");
35 | return NULL;
36 | }
37 |
38 | int has_c = PyArray_NDIM(canvas) > 2;
39 |
40 | int h, w, c;
41 | h = PyArray_DIM(canvas, 0);
42 | w = PyArray_DIM(canvas, 1);
43 | if (has_c) {
44 | c = PyArray_DIM(canvas, 2);
45 | if (c != 1) {
46 | PyErr_SetString(PyExc_ValueError, "3rd dimension must be 1 if present.");
47 | return NULL;
48 | }
49 | }
50 | else {
51 | c = 1;
52 | }
53 |
54 | int dtype_a = PyArray_TYPE(canvas);
55 |
56 | if (dtype_a != NPY_FLOAT) {
57 | PyErr_SetString(PyExc_ValueError, "dtype of array must be float.");
58 | return NULL;
59 | }
60 |
61 | auto descr = PyArray_DescrFromType(NPY_UINT8);
62 | npy_intp dims[] = { h, w, 3 };
63 | PyArrayObject* output = (PyArrayObject*)PyArray_NewFromDescr(&PyArray_Type, descr, /* nd */ 3, dims, NULL, NULL, 0, NULL);
64 |
65 | double dv = vmax - vmin;
66 |
67 | #pragma omp parallel for
68 | for (int y = 0; y < h; y++) {
69 | #pragma omp parallel for
70 | for (int x = 0; x < w; x++) {
71 | float* pxval = (float*)PyArray_GETPTR2(canvas, y, x);
72 | float v = *pxval;
73 | float r(1), g(1), b(1);
74 |
75 | if (v < (vmin + .25 * dv)) {
76 | r = 0;
77 | g = 4 * (v - vmin) / dv;
78 | } else if (v < (vmin + .5 * dv)) {
79 | r = 0;
80 | b = 1 + 4 * (vmin + .25 * dv - v) / dv;
81 | } else if (v < (vmin + .75 * dv)) {
82 | r = 4 * (v - vmin - .5 * dv) / dv;
83 | b = 0;
84 | } else {
85 | g = 1 + 4 * (vmin + .75 * dv - v) / dv;
86 | b = 0;
87 | }
88 |
89 | *(uint8_t*)PyArray_GETPTR3(output, y, x, 0) = r * 255;
90 | *(uint8_t*)PyArray_GETPTR3(output, y, x, 1) = g * 255;
91 | *(uint8_t*)PyArray_GETPTR3(output, y, x, 2) = b * 255;
92 | }
93 | }
94 |
95 | return (PyObject*)output;
96 | }
97 |
98 |
99 | static PyMethodDef python_methods[] = {
100 | { "heatmap", (PyCFunction)heatmap, METH_VARARGS|METH_KEYWORDS, "convert data into heatmap" },
101 | { nullptr, nullptr, 0, nullptr }
102 | };
103 |
104 | static struct PyModuleDef python_module = {
105 | PyModuleDef_HEAD_INIT,
106 | "_C_heatmap", // _C_canvas.
107 | nullptr, // documentation
108 | -1,
109 | python_methods
110 | };
111 |
112 | PyMODINIT_FUNC
113 | PyInit__C_heatmap(void) {
114 | auto x = PyModule_Create(&python_module);
115 | import_array();
116 | return x;
117 | }
--------------------------------------------------------------------------------
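The same pattern applies here (a sketch; the input must be float32, per the
dtype check above):

```
import numpy as np
import _C_heatmap

values = np.linspace(0.0, 1.0, 16, dtype=np.float32).reshape(4, 4)
rgb = _C_heatmap.heatmap(values, 0.0, 1.0)   # args: values, vmin, vmax
print(rgb.shape, rgb.dtype)                  # (4, 4, 3) uint8 jet-style colors
```
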
/mask.py:
--------------------------------------------------------------------------------
1 | #
2 | # Author: David Futschik
3 | # Provided as part of the Chunkmogrify project, 2021.
4 | #
5 |
6 | from typing import Any, Callable, List
7 |
8 | import os
9 | import re
10 | import PIL.Image
11 | import numpy as np
12 |
13 | # MASK
14 | # 1 = Optimize in this region
15 | # 0 = Do not optimize in this region
16 | class MaskState:
17 | def __init__(self, height, width, max_segments, output_fns: List[Callable[[np.ndarray], Any]]):
18 | self.output = output_fns
19 | self.h, self.w, self.c = height, width, max_segments
20 | self.np_buffer = np.zeros((height, width, max_segments), dtype=float)
21 | self.update_callbacks = []
22 | self.mask_version = 0
23 |
24 | def set_to(self, new_buffer: np.ndarray, **cb_kwargs):
25 | assert self.np_buffer.min() >= 0. and self.np_buffer.max() <= 1., \
26 | f"Mask has range [{self.np_buffer.min()}, {self.np_buffer.max()}]"
27 | buffer_c = new_buffer.shape[2]
28 | self.np_buffer[..., :min(self.c, buffer_c)] = new_buffer[..., :min(self.c, buffer_c)]
29 | self.mask_version += 1
30 | for c in self.update_callbacks:
31 | c(self.np_buffer, **cb_kwargs)
32 |
33 | def load_masks(self, source_dir):
34 | if not os.path.exists(source_dir):
35 | print(f"Could not load masks from {source_dir}")
36 | return
37 | ls = os.listdir(source_dir)
38 | found_idx = [re.match(r'^\d+', x) for x in ls]
39 | found_idx = [int(f.group()) for f in found_idx if f is not None]
40 | # 0 used to be the entire thing, in that case found_idx.remove(0)
41 | for idx in found_idx:
42 | # load the image
43 | m = np.array(PIL.Image.open(os.path.join(source_dir, f'{idx:02d}.png')).convert('L'))
44 | if idx >= self.np_buffer.shape[2]:
45 | print(f"Skipping mask {idx}")
46 | continue
47 | self.np_buffer[..., idx] = m / 255.
48 | self.mask_version += 1
49 | for c in self.update_callbacks:
50 | c(self.np_buffer)
51 |
52 | def save_masks(self, target_dir, painter, max_segments):
53 | # Get it from painter to include the RGB as it is in the app.
54 | rgb_masks, npy_masks = painter.get_volatile_masks()
55 | PIL.Image.fromarray(rgb_masks).save(os.path.join(target_dir, "rgb.png"))
56 | # max_segments + 1 because "empty" is the 0th mask.
57 | PIL.Image.fromarray((npy_masks[:, :, 0] * 255).astype(np.uint8)).save(os.path.join(target_dir, f"all.png"))
58 | for i in range(1, min(max_segments + 1, npy_masks.shape[2])):
59 | PIL.Image.fromarray((npy_masks[:, :, i] * 255).astype(np.uint8)).save(os.path.join(target_dir, f"{i - 1:02d}.png"))
60 |
61 | def get_mask_version(self):
62 | return self.mask_version
63 |
64 | def numpy_buffer(self):
65 | assert self.np_buffer.min() >= 0. and self.np_buffer.max() <= 1., \
66 | f"Mask has range [{self.np_buffer.min()}, {self.np_buffer.max()}]"
67 | return self.np_buffer
68 |
--------------------------------------------------------------------------------
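A sketch of driving MaskState outside the UI (the square region and the
callback are illustrative):

```
import numpy as np
from mask import MaskState

state = MaskState(height=1024, width=1024, max_segments=1, output_fns=[])
state.update_callbacks.append(lambda buf: print("mask version", state.get_mask_version()))

# 1 = optimize this region, 0 = leave untouched (see the comment in mask.py).
buf = np.zeros((1024, 1024, 1), dtype=float)
buf[256:768, 256:768, 0] = 1.0
state.set_to(buf)
```
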
/mask_refinement.py:
--------------------------------------------------------------------------------
1 | #
2 | # Author: David Futschik
3 | # Provided as part of the Chunkmogrify project, 2021.
4 | #
5 |
6 | import math
7 | import torch
8 | import numpy as np
9 | from torch.nn import functional as F
10 |
11 | def torch_grad(x):
12 | a = torch.tensor([[-1, 0, 1]], dtype=torch.float32, device=x.device, requires_grad=False).view((1, 1, 1, 3))
13 | b = torch.tensor([[-1, 0, 1]], dtype=torch.float32, device=x.device, requires_grad=False).view((1, 1, 3, 1))
14 | G_x = F.conv2d(x, a) / 2
15 | G_y = F.conv2d(x, b) / 2
16 | G_x = F.pad(G_x, (1, 1, 0, 0), 'constant', 0.)
17 | G_y = F.pad(G_y, (0, 0, 1, 1), 'constant', 0.)
18 | return [G_x, G_y]
19 |
20 | def torch_normgrad_curv(x):
21 | G_x, G_y = torch_grad(x)
22 | G = torch.sqrt(torch.pow(G_x,2) + torch.pow(G_y,2))
23 | # div
24 | div = torch_grad(G_x / (G + 1e-8) )[0] + torch_grad(G_y / (G + 1e-8))[1]
25 | return G, div
26 |
27 | def contrast_magnify(x, min=0, max=64, fromval=0., toval=255.):
28 | mul_by = (toval - min) / max
29 | x = ((x - min) * mul_by).clip(fromval, toval)
30 | return x
31 |
32 | def mask_refine(mask, image1, image2, dt_A=0.001, dt_B=0.1, iters=300):
33 |
34 | S = (mask - 0.5)
35 | P = S[:, 0:1, :, :]
36 |
37 | Pg = torch_normgrad_curv(P)[0]
38 |
39 | I_1_mean = ((Pg * image1)).sum(dim=(2,3)) / Pg.sum()
40 | I_2_mean = ((Pg * image2)).sum(dim=(2,3)) / Pg.sum()
41 |
42 | assert I_1_mean.shape == (1, 3), "Mean wrong shape"
43 | I_1 = image1 - I_1_mean[:, :, None, None]
44 | I_2 = image2 - I_2_mean[:, :, None, None]
45 | Fn = I_1 - I_2
46 | Fn = Fn.norm(dim=1) / math.sqrt(12)
47 |
48 | P = P * 255
49 | Fn = Fn * 255
50 |
51 | Fn = contrast_magnify(Fn)
52 |
53 | with torch.no_grad():
54 | for _ in range(iters):
55 | ng, div = torch_normgrad_curv(P)
56 | P -= dt_A * Fn * ng - dt_B * ng * div
57 |
58 | P = torch.where(P < 0, 0., 1.)
59 |
60 | new_mask = P
61 | return new_mask
62 |
63 | if __name__ == "__main__":
64 | import PIL.Image
65 |
66 | i1 = PIL.Image.open("_I1_.png")
67 | i2 = PIL.Image.open("_I2_.png")
68 | s = PIL.Image.open("_P_.png")
69 |
70 | i1 = (torch.tensor( np.array(i1).astype(float), dtype=torch.float32, device='cuda:0' ) / 127.5).permute((2,0,1)).unsqueeze(0)
71 | i2 = (torch.tensor( np.array(i2).astype(float), dtype=torch.float32, device='cuda:0' ) / 127.5).permute((2,0,1)).unsqueeze(0)
72 | s = (torch.tensor( np.array(s).astype(float), dtype=torch.float32, device='cuda:0' ) / 255.).unsqueeze(0).unsqueeze(0)
73 |
74 | m = mask_refine(s, i1 - 1, i2 - 1)
75 |
76 | d = (m.cpu().numpy()[0].transpose((1,2,0)) * 255.).astype(np.uint8)
77 | PIL.Image.fromarray(d[:, :, 0]).save("maskrefine.png")
78 |
--------------------------------------------------------------------------------
/qtutil.py:
--------------------------------------------------------------------------------
1 | #
2 | # Author: David Futschik
3 | # Provided as part of the Chunkmogrify project, 2021.
4 | #
5 |
6 | import os
7 | import re
8 | import numpy as np
9 | import PIL.Image as Image
10 | from enum import Enum
11 | from threading import Thread
12 | from threading import Timer
13 | from time import time, perf_counter
14 |
15 | from PyQt5.QtWidgets import *
16 | from PyQt5.QtGui import *
17 | from PyQt5.QtCore import *
18 |
19 | class dotdict(dict):
20 | __getattr__ = dict.__getitem__
21 | __setattr__ = dict.__setitem__
22 | __delattr__ = dict.__delitem__
23 |
24 | def has_set(self, attr):
25 | return attr in self and self[attr] is not None
26 |
27 | class MouseButton(Enum):
28 | LEFT = 1
29 | RIGHT = 2
30 | MIDDLE = 4
31 | SIDE_1 = 8
32 | SIDE_2 = 16
33 |
34 | def QHLine():
35 | w = QFrame()
36 | w.setFrameShape(QFrame.HLine)
37 | w.setFrameShadow(QFrame.Sunken)
38 | return w
39 |
40 | def QVLine():
41 | w = QFrame()
42 | w.setFrameShape(QFrame.VLine)
43 | w.setFrameShadow(QFrame.Raised)
44 | return w
45 |
46 | def qpixmap_loader(path: str):
47 | if not os.path.exists(path):
48 | print(f"Warning: {path} does not exist, but queued for load as pixmap")
49 | else:
50 | print(f"Loading {path}")
51 | return QPixmap(path)
52 |
53 | def npy_loader(path: str):
54 | if not os.path.exists(path):
55 | raise RuntimeError(f"{path} does not exist")
56 | return np.array(Image.open(path))
57 |
58 | def make_dirs_if_not_exists(path):
59 | if not os.path.exists(path) or not os.path.isdir(path):
60 | os.makedirs(path)
61 |
62 | def export_image(prefix, npy_image):
63 | if not os.path.exists(prefix):
64 | os.mkdir(prefix)
65 | ls = os.listdir(prefix)
66 | prev_imgs = [re.match(r'^\d+', x) for x in ls]
67 | prev_imgs = [int(x.group()) for x in prev_imgs if x is not None]
68 | cur_id = max(prev_imgs, default=-1) + 1
69 | fname = os.path.join(prefix, f'{cur_id:05d}.png')
70 | Image.fromarray(npy_image).save(fname)
71 | return fname
72 |
73 | def image_save(npy_image, path):
74 | Image.fromarray(npy_image).save(path)
75 |
76 | def qim2np(qimage, swapaxes=True):
77 | # assumes BGRA if swapaxes is True
78 | # assert(qimage.format() == QImage.Format_8888)
79 | bytes_per_pxl = 4
80 | if qimage.format() == QImage.Format_Grayscale8:
81 | bytes_per_pxl = 1
82 | qimage.__array_interface__ = {
83 | 'shape': (qimage.height(), qimage.width(), bytes_per_pxl),
84 | 'typestr': "|u1",
85 | 'data': (int(qimage.bits()), False),
86 | 'version': 3
87 | }
88 | npim = np.asarray(qimage)
89 | # npim is now BGRA, cast it into RGBA
90 | if bytes_per_pxl == 4 and swapaxes:
91 | npim = npim[...,(2,1,0,3)]
92 | return npim
93 |
94 | def np2qim(nparr, fmt_select='auto', do_copy=True):
95 | h, w = nparr.shape[0:2]
96 | # the copy is required unless reference is desired!
97 | if do_copy:
98 | nparr = nparr.astype('uint8').copy()
99 | else:
100 | nparr = np.ascontiguousarray(nparr.astype('uint8'))
101 | if nparr.ndim == 2:
102 | # or Alpha8 when it has no 3rd dimension (masks)
103 | fmt = QImage.Format_Grayscale8 #Alpha8
104 | elif fmt_select != 'auto':
105 | fmt = fmt_select
106 | else:
107 | fmt = {3: QImage.Format_RGB888, 4: QImage.Format_RGBA8888}[nparr.shape[2]]
108 | qim = QImage(nparr, w, h, nparr.strides[0], fmt) # _Premultiplied
109 | return qim
110 |
111 | def destroy_layout(layout):
112 | if layout is not None:
113 | while layout.count():
114 | item = layout.takeAt(0)
115 | widget = item.widget()
116 | if widget is not None:
117 | widget.deleteLater()
118 | else:
119 | destroy_layout(item.layout())
120 |
121 | class NotifyWait(QObject):
122 | acquire = pyqtSignal(str)
123 | release = pyqtSignal()
124 |
125 | def __init__(self, parent=None) -> None:
126 | super().__init__(parent=parent)
127 | self.msg = ""
128 | d = QProgressDialog(self.msg, "Cancel", 0, 0, parent=get_default_parent())
129 | d.setCancelButton(None)
130 | d.setWindowTitle("Working")
131 | d.setWindowFlags(d.windowFlags() & ~Qt.WindowCloseButtonHint)
132 | self.d = d
133 | self.d.cancel()
134 |
135 | def _show(self):
136 | self.d.setLabelText(self.msg)
137 | self.d.exec_()
138 |
139 | def _hide(self):
140 | self.d.done(0)
141 |
142 |
143 | class RaiseError(QObject):
144 | raiseme = pyqtSignal(str, str, bool) # is_fatal
145 |
146 | def __init__(self, parent=None) -> None:
147 | super().__init__(parent=parent)
148 |
149 | def _show(self, where, what, is_fatal):
150 | msg = QMessageBox(self.parent())
151 | msg.setText(where)
152 | msg.setInformativeText(what)
153 | msg.setWindowTitle("Error")
154 | msg.setIcon(QMessageBox.Critical)
155 | msgresizer = QSpacerItem(500, 0, QSizePolicy.Minimum, QSizePolicy.Expanding)
156 | msg.layout().addItem(msgresizer, msg.layout().rowCount(), 0, 1, msg.layout().columnCount())
157 | msg.exec_()
158 | if is_fatal:
159 | exit(1)
160 |
161 | _global_parent = None
162 | _notify_wait = None
163 | _global_error = None
164 | def set_default_parent(parent):
165 | global _global_parent
166 | global _notify_wait
167 | global _global_error
168 | _global_parent = parent
169 | _notify_wait = NotifyWait(_global_parent)
170 | _global_error = RaiseError(_global_parent)
171 |
172 | def on_acquire(msg):
173 | _notify_wait.msg = msg
174 | _notify_wait._show()
175 | def on_release():
176 | _notify_wait._hide()
177 | _notify_wait.acquire.connect(on_acquire)
178 | _notify_wait.release.connect(on_release)
179 |
180 | def on_raise(where, what, is_fatal):
181 | _global_error._show(where, what, is_fatal)
182 | _global_error.raiseme.connect(on_raise)
183 |
184 | def get_default_parent():
185 | global _global_parent
186 | return _global_parent
187 | def get_notify_wait():
188 | global _notify_wait
189 | return _notify_wait
190 | def get_global_error():
191 | global _global_error
192 | return _global_error
193 |
194 | def notify_user_wait(message):
195 | d = QProgressDialog(message, "Cancel", 0, 0, parent=get_default_parent())
196 | d.setCancelButton(None)
197 | d.setWindowTitle("Working")
198 | d.setWindowFlags(d.windowFlags() & ~Qt.WindowCloseButtonHint)
199 | return d
200 |
201 |
202 | def notify_user_error(where, what, parent=None):
203 | msg = QMessageBox(parent)
204 | msg.setText(where)
205 | msg.setInformativeText(what)
206 | msg.setWindowTitle("Error")
207 | msg.setIcon(QMessageBox.Critical)
208 | msgresizer = QSpacerItem(500, 0, QSizePolicy.Minimum, QSizePolicy.Expanding)
209 | msg.layout().addItem(msgresizer, msg.layout().rowCount(), 0, 1, msg.layout().columnCount())
210 | msg.exec_()
211 |
212 | # Really only works when called from main thread.
213 | def execute_with_wait(message, fn, *args, **kwargs):
214 | d = notify_user_wait(message)
215 |
216 | def fn_impl():
217 | fn(*args, **kwargs)
218 | d.done(0)
219 | t = Thread(target=fn_impl)
220 | t.start()
221 | d.exec_()
222 |
223 | class NowOrDelayTimer:
224 | def __init__(self, interval):
225 | self.interval = interval
226 | self.last_ok_time = 0
227 | self.timer = None
228 |
229 | def update(self, do_run):
230 | if self.last_ok_time < time() - self.interval:
231 | self.last_ok_time = time()
232 | do_run()
233 | # cancel pending update
234 | if self.timer: self.timer.cancel()
235 | else:
236 | # cancel last pending update
237 | if self.timer: self.timer.cancel()
238 | def closure():
239 | self.update(do_run)
240 | self.timer = Timer(self.last_ok_time + self.interval - time(), closure)
241 | self.timer.start()
242 |
243 | class MeasureTime:
244 | def __init__(self, name, disable=False):
245 | self.name = name
246 | self.start = None
247 | self.disable = disable
248 |
249 | def __enter__(self):
250 | if not self.disable:
251 | self.start = perf_counter()
252 |
253 | def __exit__(self, type, value, traceback):
254 | if not self.disable:
255 | end = perf_counter()
256 | print(f'{self.name}: {end-self.start:.3f}')
--------------------------------------------------------------------------------
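A sketch of NowOrDelayTimer, the throttle used for UI refreshes: the first call
in a burst runs immediately, later ones coalesce into a single trailing update
(note that importing qtutil requires PyQt5 to be installed):

```
from time import sleep
from qtutil import NowOrDelayTimer

throttle = NowOrDelayTimer(interval=0.5)

def refresh():
    print("expensive redraw")

for _ in range(10):      # burst: first call runs now, the rest are coalesced
    throttle.update(refresh)
sleep(1.0)               # allow the single delayed trailing update to fire
```
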
/requirements.txt:
--------------------------------------------------------------------------------
1 | certifi==2021.5.30
2 | charset-normalizer==2.0.6
3 | click==8.0.1
4 | cycler==0.10.0
5 | dlib==19.22.1
6 | idna==3.2
7 | imageio==2.9.0
8 | importlib-metadata==4.8.1
9 | kiwisolver==1.3.2
10 | lpips==0.1.4
11 | matplotlib==3.4.3
12 | networkx==2.6.3
13 | numpy==1.21.2
14 | opencv-python-headless==4.5.3.56
15 | Pillow==8.3.2
16 | pyparsing==2.4.7
17 | PyQt5==5.15.4
18 | PyQt5-Qt5==5.15.2
19 | PyQt5-sip==12.9.0
20 | python-dateutil==2.8.2
21 | PyWavelets==1.1.1
22 | PyYAML==5.4.1
23 | requests==2.26.0
24 | scikit-image==0.18.3
25 | scipy==1.7.1
26 | six==1.16.0
27 | tifffile==2021.8.30
28 | tqdm==4.62.3
29 | typing-extensions==3.10.0.2
30 | urllib3==1.26.7
31 | zipp==3.6.0
32 | ninja
33 | # package location
34 | --find-links https://download.pytorch.org/whl/torch_stable.html
35 | torch==1.9.1+cu102
36 | torchvision==0.10.1
37 |
38 |
39 |
--------------------------------------------------------------------------------
/resources.py:
--------------------------------------------------------------------------------
1 | #
2 | # Author: David Futschik
3 | # Provided as part of the Chunkmogrify project, 2021.
4 | #
5 |
6 |
7 | import sys
8 | import os
9 | import requests
10 |
11 | from qtutil import make_dirs_if_not_exists
12 |
13 | resource_list = {
14 | 'ffhq_styleganv2': 'resources/ffhq.pkl',
15 | 'dlib_align_data': 'resources/shape_predictor_68_face_landmarks.dat',
16 | 'styleclip_afro': 'resources/styleclip/afro.pt',
17 | 'styleclip_bobcut':'resources/styleclip/bobcut.pt',
18 | 'styleclip_curly_hair': 'resources/styleclip/curly_hair.pt',
19 | 'styleclip_bowlcut': 'resources/styleclip/bowlcut.pt',
20 | 'styleclip_mohawk': 'resources/styleclip/mohawk.pt',
21 | }
22 |
23 | resource_sources = {
24 | 'ffhq_styleganv2' : 'https://dcgi.fel.cvut.cz/~futscdav/chunkmogrify/ffhq.pkl',
25 | 'dlib_align_data' : 'https://dcgi.fel.cvut.cz/~futscdav/chunkmogrify/shape_predictor_68_face_landmarks.dat',
26 | 'styleclip_afro' : 'https://dcgi.fel.cvut.cz/~futscdav/chunkmogrify/styleclip/afro.pt',
27 | 'styleclip_bobcut' : 'https://dcgi.fel.cvut.cz/~futscdav/chunkmogrify/styleclip/bobcut.pt',
28 | 'styleclip_curly_hair': 'https://dcgi.fel.cvut.cz/~futscdav/chunkmogrify/styleclip/curly_hair.pt',
29 | 'styleclip_bowlcut' : 'https://dcgi.fel.cvut.cz/~futscdav/chunkmogrify/styleclip/bowlcut.pt',
30 | 'styleclip_mohawk' : 'https://dcgi.fel.cvut.cz/~futscdav/chunkmogrify/styleclip/mohawk.pt',
31 | }
32 |
33 | def download_url(url, store_at):
34 | dir = os.path.dirname(store_at)
35 | make_dirs_if_not_exists(dir)
36 |
37 | # open in binary mode
38 | with open(store_at, "wb") as file:
39 | # get request
40 | response = requests.get(url, stream=True)
41 | total_size = response.headers.get('content-length')
42 | downloaded = 0
43 | total_size = int(total_size)
44 | for data in response.iter_content(chunk_size=8192):
45 | downloaded += len(data)
46 | file.write(data)
47 | done = int(50 * downloaded / total_size)
48 | sys.stdout.write(f"\r[{'=' * done}{' ' * (50-done)}] {100*downloaded/total_size:.02f}%")
49 | sys.stdout.flush()
50 | print()
51 |
52 | def check_and_download_all():
53 | for resource, storage in resource_list.items():
54 | if not os.path.exists(storage):
55 | print(f"{resource} not found at {storage}, downloading..")
56 | download_url(resource_sources[resource], storage)
57 |
--------------------------------------------------------------------------------
/resources/iconS.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/resources/iconS.png
--------------------------------------------------------------------------------
/s_presets.py:
--------------------------------------------------------------------------------
1 | #
2 | # Author: David Futschik
3 | # Provided as part of the Chunkmogrify project, 2021.
4 | #
5 |
6 | def limits():
7 | # Generated from 18 layer stylegan (x1024)
8 | # Was done on the fly previously, but this is easier for UI.
9 | # If the architecture changes, just run any input up to S space and fill in the input sizes.
10 | num_s = [512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 256, 256, 256, 128, 128, 128, 64, 64, 64, 32, 32]
11 | return {
12 | 'layer': len(num_s),
13 | 'channel': num_s
14 | }
15 |
16 | def known_presets():
17 | return {
18 | 'gaze': (9, 409),
19 | 'smile': (6, 259),
20 | 'eyebrows_1': (8, 28),
21 | 'eyebrows_2': (12, 455),
22 | 'eyebrows_3': (6, 179),
23 | 'hair_color': (12, 266),
24 | 'fringe': (6, 285),
25 | 'lipstick': (15, 45),
26 | 'eye_makeup': (12, 414),
27 | 'eye_roll': (14, 239),
28 | 'asian_eyes': (9, 376),
29 | 'gray_hair': (6, 364),
30 | 'eye_size': (12, 110),
31 | 'goatee': (9, 421),
32 | 'fat': (6, 104),
33 | 'gender': (6, 128),
34 | 'chin': (6, 131),
35 | 'double_chin': (6, 144),
36 | 'sideburns': (6, 234),
37 | 'forehead_hair': (6, 322),
38 | 'curly_hair': (6, 364),
39 | 'nose_up_down': (9, 86),
40 | 'eye_wide': (9, 63),
41 | 'gender_2': (9, 6),
42 | 'demon_eyes': (14, 319),
43 | 'sunken_eyes': (14, 380),
44 | 'pupil_1': (14, 414),
45 | 'pupil_2': (14, 419),
46 | }
--------------------------------------------------------------------------------
/sample_image/input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/sample_image/input.jpg
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/extract_components.py:
--------------------------------------------------------------------------------
1 | #
2 | # Author: David Futschik
3 | # Provided as part of the Chunkmogrify project, 2021.
4 | #
5 |
6 | import os
7 | import torch
8 | from synthesis import init_gan
9 | from qtutil import make_dirs_if_not_exists
10 |
11 | init_gan()
12 | from synthesis import gan
13 |
14 | keys = gan.g.synthesis._modules.keys()
15 |
16 | export_to = 'components'
17 | make_dirs_if_not_exists(export_to)
18 | all_w = []
19 | for key in keys:
20 |
21 | for subkey in gan.g.synthesis._modules[key]._modules.keys():
22 | w = gan.g.synthesis._modules[key]._modules[subkey].affine._parameters['weight']
23 | all_w.append(w)
24 | eigs = torch.svd(w).V.cpu()
25 | torch.save(eigs, os.path.join(export_to, f'{key}_{subkey}.pt'))
26 |
27 | all_w = torch.cat(all_w, dim=0)
28 | eigs = torch.svd(all_w).V.cpu()
29 | torch.save(eigs, os.path.join(export_to, 'all.pt'))
30 |
--------------------------------------------------------------------------------
/scripts/idempotent_blend.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import scipy.sparse
4 | from scipy.sparse.linalg import spsolve
5 |
6 |
7 | def laplacian_matrix(n, m):
8 | """Generate the Poisson matrix.
9 | Refer to:
10 | https://en.wikipedia.org/wiki/Discrete_Poisson_equation
11 | Note: it's the transpose of the wiki's matrix
12 | """
13 | mat_D = scipy.sparse.lil_matrix((m, m))
14 | mat_D.setdiag(-1, -1)
15 | mat_D.setdiag(4)
16 | mat_D.setdiag(-1, 1)
17 |
18 | mat_A = scipy.sparse.block_diag([mat_D] * n).tolil()
19 |
20 | mat_A.setdiag(-1, 1*m)
21 | mat_A.setdiag(-1, -1*m)
22 |
23 | return mat_A
24 |
25 | def poisson_edit(source, target, mask, offset):
26 | """The poisson blending function.
27 | Refer to:
28 | Perez et. al., "Poisson Image Editing", 2003.
29 | """
30 |
31 | # Assume:
32 | # target is not smaller than source.
33 | # shape of mask is same as shape of target.
34 | y_max, x_max = target.shape[:-1]
35 | y_min, x_min = 0, 0
36 |
37 | x_range = x_max - x_min
38 | y_range = y_max - y_min
39 |
40 | M = np.float32([[1,0,offset[0]],[0,1,offset[1]]])
41 | source = cv2.warpAffine(source,M,(x_range,y_range))
42 |
43 | mask = mask[y_min:y_max, x_min:x_max]
44 | mask[mask != 0] = 1
45 |
46 | mat_A = laplacian_matrix(y_range, x_range)
47 |
48 | # for \Delta g
49 | laplacian = mat_A.tocsc()
50 |
51 | # set the region outside the mask to identity
52 | for y in range(1, y_range - 1):
53 | for x in range(1, x_range - 1):
54 | if mask[y, x] == 0:
55 | k = x + y * x_range
56 | mat_A[k, k] = 1
57 | mat_A[k, k + 1] = 0
58 | mat_A[k, k - 1] = 0
59 | mat_A[k, k + x_range] = 0
60 | mat_A[k, k - x_range] = 0
61 |
62 | mat_A = mat_A.tocsc()
63 |
64 | mask_flat = mask.flatten()
65 | for channel in range(source.shape[2]):
66 | source_flat = source[y_min:y_max, x_min:x_max, channel].flatten()
67 | target_flat = target[y_min:y_max, x_min:x_max, channel].flatten()
68 |
69 | # inside the mask:
70 | # \Delta f = div v = \Delta g
71 | alpha = 1
72 | mat_b = laplacian.dot(source_flat)*alpha
73 |
74 | # outside the mask:
75 | # f = t
76 | mat_b[mask_flat==0] = target_flat[mask_flat==0]
77 |
78 | x = spsolve(mat_A, mat_b)
79 | x = x.reshape((y_range, x_range))
80 | x[x > 255] = 255
81 | x[x < 0] = 0
82 | x = x.astype('uint8')
83 |
84 | target[y_min:y_max, x_min:x_max, channel] = x
85 |
86 | return target
--------------------------------------------------------------------------------
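
`poisson_edit` solves the discrete Poisson equation once per color channel and writes the result back into `target` (note that both `target` and `mask` are modified in place). A minimal usage sketch, with placeholder file names:

import cv2
from scripts.idempotent_blend import poisson_edit

source = cv2.imread('source.png')
target = cv2.imread('target.png')
mask = cv2.imread('mask.png', cv2.IMREAD_GRAYSCALE)  # nonzero = blend region
result = poisson_edit(source, target, mask, offset=(0, 0))
cv2.imwrite('blended.png', result)
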
/setup_cpp_ext.py:
--------------------------------------------------------------------------------
1 | #
2 | # Author: David Futschik
3 | # Provided as part of the Chunkmogrify project, 2021.
4 | #
5 |
6 | import platform
7 | import setuptools
8 | import numpy as np
9 | from setuptools import sandbox
10 |
11 | platform_specific_flags = []
12 | if platform.system() == "Windows":
13 | platform_specific_flags += ["/permissive-", "/Ox", "/std:c++14"]  # MSVC has no /std:c++11; c++14 is its lowest /std: value
14 | else:
15 | platform_specific_flags += ["-O3", "--std=c++11"]
16 |
17 | ext_modules = [
18 | setuptools.Extension('_C_canvas',
19 | sources=['extensions/canvas_to_masks.cpp'],
20 | include_dirs=[np.get_include()],
21 | extra_compile_args=platform_specific_flags,
22 | language='c++'),
23 | setuptools.Extension('_C_heatmap',
24 | sources=['extensions/heatmap.cpp'],
25 | include_dirs=[np.get_include()],
26 | extra_compile_args=platform_specific_flags,
27 | language='c++')
28 | ]
29 |
30 | def checked_build(force=False):
31 | def do_build():
32 | sandbox.run_setup('setup_cpp_ext.py', ['build_ext', '--inplace'])
33 | try:
34 | import _C_canvas
35 | import _C_heatmap
36 | if force: do_build()
37 | except ImportError:
38 | do_build()
39 |
40 | if __name__ == "__main__":
41 | setuptools.setup(
42 | ext_modules=ext_modules
43 | )
--------------------------------------------------------------------------------
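
`checked_build` only compiles when the modules cannot already be imported (or when forced), so it is cheap to call at startup. A minimal sketch:

from setup_cpp_ext import checked_build

checked_build()      # builds _C_canvas / _C_heatmap in-place on first run
import _C_canvas
import _C_heatmap
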
/styleclip_mapper.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from torch.nn import Module
6 |
7 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
8 | rest_dim = [1] * (input.ndim - bias.ndim - 1)
9 | if input.ndim == 3:
10 | return (
11 | F.leaky_relu(
12 | input + bias.view(1, *rest_dim, bias.shape[0]), negative_slope=negative_slope
13 | )
14 | * scale
15 | )
16 | else:
17 | return (
18 | F.leaky_relu(
19 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
20 | )
21 | * scale
22 | )
23 |
24 | class PixelNorm(nn.Module):
25 | def __init__(self):
26 | super().__init__()
27 |
28 | def forward(self, input):
29 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
30 |
31 |
32 | class EqualLinear(nn.Module):
33 | def __init__(
34 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
35 | ):
36 | super().__init__()
37 |
38 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
39 |
40 | if bias:
41 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
42 |
43 | else:
44 | self.bias = None
45 |
46 | self.activation = activation
47 |
48 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul
49 | self.lr_mul = lr_mul
50 |
51 | def forward(self, input):
52 | if self.activation:
53 | out = F.linear(input, self.weight * self.scale)
54 | out = fused_leaky_relu(out, self.bias * self.lr_mul)
55 |
56 | else:
57 | out = F.linear(
58 | input, self.weight * self.scale, bias=self.bias * self.lr_mul
59 | )
60 |
61 | return out
62 |
63 | def __repr__(self):
64 | return (
65 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
66 | )
67 |
68 | class Mapper(Module):
69 |
70 | def __init__(self, opts):
71 | super(Mapper, self).__init__()
72 |
73 | self.opts = opts
74 | layers = [PixelNorm()]
75 |
76 | for i in range(4):
77 | layers.append(
78 | EqualLinear(
79 | 512, 512, lr_mul=0.01, activation='fused_lrelu'
80 | )
81 | )
82 |
83 | self.mapping = nn.Sequential(*layers)
84 |
85 |
86 | def forward(self, x):
87 | x = self.mapping(x)
88 | return x
89 |
90 |
91 | class SingleMapper(Module):
92 |
93 | def __init__(self, opts):
94 | super(SingleMapper, self).__init__()
95 |
96 | self.opts = opts
97 |
98 | self.mapping = Mapper(opts)
99 |
100 | def forward(self, x):
101 | out = self.mapping(x)
102 | return out
103 |
104 |
105 | class LevelsMapper(Module):
106 |
107 | def __init__(self, opts):
108 | super(LevelsMapper, self).__init__()
109 |
110 | self.opts = opts
111 |
112 | if not opts.no_coarse_mapper:
113 | self.course_mapping = Mapper(opts)  # 'course' (sic) kept to match pretrained checkpoint keys
114 | if not opts.no_medium_mapper:
115 | self.medium_mapping = Mapper(opts)
116 | if not opts.no_fine_mapper:
117 | self.fine_mapping = Mapper(opts)
118 |
119 | def forward(self, x):
120 | x_coarse = x[:, :4, :]
121 | x_medium = x[:, 4:8, :]
122 | x_fine = x[:, 8:, :]
123 |
124 | if not self.opts.no_coarse_mapper:
125 | x_coarse = self.course_mapping(x_coarse)
126 | else:
127 | x_coarse = torch.zeros_like(x_coarse)
128 | if not self.opts.no_medium_mapper:
129 | x_medium = self.medium_mapping(x_medium)
130 | else:
131 | x_medium = torch.zeros_like(x_medium)
132 | if not self.opts.no_fine_mapper:
133 | x_fine = self.fine_mapping(x_fine)
134 | else:
135 | x_fine = torch.zeros_like(x_fine)
136 |
137 |
138 | out = torch.cat([x_coarse, x_medium, x_fine], dim=1)
139 |
140 | return out
141 |
142 | def get_keys(d, name):
143 | if 'state_dict' in d:
144 | d = d['state_dict']
145 | d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
146 | return d_filt
147 |
148 |
149 | class StyleCLIPMapper(nn.Module):
150 |
151 | def __init__(self, opts):
152 | super().__init__()
153 | self.opts = opts
154 | # Define architecture
155 | self.mapper = self.set_mapper()
156 | # Load weights if needed
157 | self.load_weights()
158 |
159 | def set_mapper(self):
160 | if self.opts.mapper_type == 'SingleMapper':
161 | mapper = SingleMapper(self.opts)
162 | elif self.opts.mapper_type == 'LevelsMapper':
163 | mapper = LevelsMapper(self.opts)
164 | else:
165 | raise Exception('{} is not a valid mapper'.format(self.opts.mapper_type))
166 | return mapper
167 |
168 | def load_weights(self):
169 | if self.opts.checkpoint_path is not None:
170 | print('Loading from checkpoint: {}'.format(self.opts.checkpoint_path))
171 | ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
172 | self.mapper.load_state_dict(get_keys(ckpt, 'mapper'), strict=True)
173 |
--------------------------------------------------------------------------------
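
`LevelsMapper` splits an 18-row W+ latent into coarse (rows 0:4), medium (4:8), and fine (8:) groups and maps each with its own 4-layer MLP; StyleCLIP-style edits then add the mapper output back onto the latent, w' = w + alpha * M(w). A minimal sketch with a randomly initialized mapper (the `opts` fields mirror what `StyleCLIPMapper` reads; `alpha` and the random `w_plus` are stand-ins):

import torch
from types import SimpleNamespace
from styleclip_mapper import StyleCLIPMapper

opts = SimpleNamespace(mapper_type='LevelsMapper', checkpoint_path=None,
                       no_coarse_mapper=False, no_medium_mapper=False,
                       no_fine_mapper=False)
net = StyleCLIPMapper(opts).eval()
w_plus = torch.randn(1, 18, 512)
with torch.no_grad():
    w_edited = w_plus + 0.1 * net.mapper(w_plus)   # w' = w + alpha * M(w)
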
/styleclip_presets.py:
--------------------------------------------------------------------------------
1 | #
2 | # Author: David Futschik
3 | # Provided as part of the Chunkmogrify project, 2021.
4 | #
5 |
6 | from resources import resource_list
7 |
8 |
9 | def pretrained_models():
10 | return {
11 | 'afro': resource_list['styleclip_afro'],
12 | 'bobcut': resource_list['styleclip_bobcut'],
13 | 'mohawk': resource_list['styleclip_mohawk'],
14 | 'bowlcut': resource_list['styleclip_bowlcut'],
15 | 'curly_hair': resource_list['styleclip_curly_hair'],
16 | }
--------------------------------------------------------------------------------
/stylegan_legacy.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import click
10 | import pickle
11 | import re
12 | import copy
13 | import numpy as np
14 | import torch
15 | import dnnlib
16 | from torch_utils import misc
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | def load_network_pkl(f, force_fp16=False):
21 | data = _LegacyUnpickler(f).load()
22 |
23 | # Legacy TensorFlow pickle => convert.
24 | if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
25 | tf_G, tf_D, tf_Gs = data
26 | G = convert_tf_generator(tf_G)
27 | D = convert_tf_discriminator(tf_D)
28 | G_ema = convert_tf_generator(tf_Gs)
29 | data = dict(G=G, D=D, G_ema=G_ema)
30 |
31 | # Add missing fields.
32 | if 'training_set_kwargs' not in data:
33 | data['training_set_kwargs'] = None
34 | if 'augment_pipe' not in data:
35 | data['augment_pipe'] = None
36 |
37 | # Validate contents.
38 | assert isinstance(data['G'], torch.nn.Module)
39 | assert isinstance(data['D'], torch.nn.Module)
40 | assert isinstance(data['G_ema'], torch.nn.Module)
41 | assert isinstance(data['training_set_kwargs'], (dict, type(None)))
42 | assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
43 |
44 | # Force FP16.
45 | if force_fp16:
46 | for key in ['G', 'D', 'G_ema']:
47 | old = data[key]
48 | kwargs = copy.deepcopy(old.init_kwargs)
49 | if key.startswith('G'):
50 | kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {}))
51 | kwargs.synthesis_kwargs.num_fp16_res = 4
52 | kwargs.synthesis_kwargs.conv_clamp = 256
53 | if key.startswith('D'):
54 | kwargs.num_fp16_res = 4
55 | kwargs.conv_clamp = 256
56 | if kwargs != old.init_kwargs:
57 | new = type(old)(**kwargs).eval().requires_grad_(False)
58 | misc.copy_params_and_buffers(old, new, require_all=True)
59 | data[key] = new
60 | return data
61 |
62 | #----------------------------------------------------------------------------
63 |
64 | class _TFNetworkStub(dnnlib.EasyDict):
65 | pass
66 |
67 | class _LegacyUnpickler(pickle.Unpickler):
68 | def find_class(self, module, name):
69 | if module == 'dnnlib.tflib.network' and name == 'Network':
70 | return _TFNetworkStub
71 | return super().find_class(module, name)
72 |
73 | #----------------------------------------------------------------------------
74 |
75 | def _collect_tf_params(tf_net):
76 | # pylint: disable=protected-access
77 | tf_params = dict()
78 | def recurse(prefix, tf_net):
79 | for name, value in tf_net.variables:
80 | tf_params[prefix + name] = value
81 | for name, comp in tf_net.components.items():
82 | recurse(prefix + name + '/', comp)
83 | recurse('', tf_net)
84 | return tf_params
85 |
86 | #----------------------------------------------------------------------------
87 |
88 | def _populate_module_params(module, *patterns):
89 | for name, tensor in misc.named_params_and_buffers(module):
90 | found = False
91 | value = None
92 | for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
93 | match = re.fullmatch(pattern, name)
94 | if match:
95 | found = True
96 | if value_fn is not None:
97 | value = value_fn(*match.groups())
98 | break
99 | try:
100 | assert found
101 | if value is not None:
102 | tensor.copy_(torch.from_numpy(np.array(value)))
103 | except:
104 | print(name, list(tensor.shape))
105 | raise
106 |
107 | #----------------------------------------------------------------------------
108 |
109 | def convert_tf_generator(tf_G):
110 | if tf_G.version < 4:
111 | raise ValueError('TensorFlow pickle version too low')
112 |
113 | # Collect kwargs.
114 | tf_kwargs = tf_G.static_kwargs
115 | known_kwargs = set()
116 | def kwarg(tf_name, default=None, none=None):
117 | known_kwargs.add(tf_name)
118 | val = tf_kwargs.get(tf_name, default)
119 | return val if val is not None else none
120 |
121 | # Convert kwargs.
122 | kwargs = dnnlib.EasyDict(
123 | z_dim = kwarg('latent_size', 512),
124 | c_dim = kwarg('label_size', 0),
125 | w_dim = kwarg('dlatent_size', 512),
126 | img_resolution = kwarg('resolution', 1024),
127 | img_channels = kwarg('num_channels', 3),
128 | mapping_kwargs = dnnlib.EasyDict(
129 | num_layers = kwarg('mapping_layers', 8),
130 | embed_features = kwarg('label_fmaps', None),
131 | layer_features = kwarg('mapping_fmaps', None),
132 | activation = kwarg('mapping_nonlinearity', 'lrelu'),
133 | lr_multiplier = kwarg('mapping_lrmul', 0.01),
134 | w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
135 | ),
136 | synthesis_kwargs = dnnlib.EasyDict(
137 | channel_base = kwarg('fmap_base', 16384) * 2,
138 | channel_max = kwarg('fmap_max', 512),
139 | num_fp16_res = kwarg('num_fp16_res', 0),
140 | conv_clamp = kwarg('conv_clamp', None),
141 | architecture = kwarg('architecture', 'skip'),
142 | resample_filter = kwarg('resample_kernel', [1,3,3,1]),
143 | use_noise = kwarg('use_noise', True),
144 | activation = kwarg('nonlinearity', 'lrelu'),
145 | ),
146 | )
147 |
148 | # Check for unknown kwargs.
149 | kwarg('truncation_psi')
150 | kwarg('truncation_cutoff')
151 | kwarg('style_mixing_prob')
152 | kwarg('structure')
153 | unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
154 | if len(unknown_kwargs) > 0:
155 | raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
156 |
157 | # Collect params.
158 | tf_params = _collect_tf_params(tf_G)
159 | for name, value in list(tf_params.items()):
160 | match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
161 | if match:
162 | r = kwargs.img_resolution // (2 ** int(match.group(1)))
163 | tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
164 | kwargs.synthesis_kwargs.architecture = 'orig'
165 | #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
166 |
167 | # Convert params.
168 | import stylegan_networks as networks  # this repo's copy of training/networks.py
169 | G = networks.Generator(**kwargs).eval().requires_grad_(False)
170 | # pylint: disable=unnecessary-lambda
171 | _populate_module_params(G,
172 | r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
173 | r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
174 | r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
175 | r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
176 | r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
177 | r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
178 | r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
179 | r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
180 | r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
181 | r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
182 | r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
183 | r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
184 | r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
185 | r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
186 | r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
187 | r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
188 | r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
189 | r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
190 | r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
191 | r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
192 | r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
193 | r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
194 | r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
195 | r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
196 | r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
197 | r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
198 | r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
199 | r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
200 | r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
201 | r'.*\.resample_filter', None,
202 | )
203 | return G
204 |
205 | #----------------------------------------------------------------------------
206 |
207 | def convert_tf_discriminator(tf_D):
208 | if tf_D.version < 4:
209 | raise ValueError('TensorFlow pickle version too low')
210 |
211 | # Collect kwargs.
212 | tf_kwargs = tf_D.static_kwargs
213 | known_kwargs = set()
214 | def kwarg(tf_name, default=None):
215 | known_kwargs.add(tf_name)
216 | return tf_kwargs.get(tf_name, default)
217 |
218 | # Convert kwargs.
219 | kwargs = dnnlib.EasyDict(
220 | c_dim = kwarg('label_size', 0),
221 | img_resolution = kwarg('resolution', 1024),
222 | img_channels = kwarg('num_channels', 3),
223 | architecture = kwarg('architecture', 'resnet'),
224 | channel_base = kwarg('fmap_base', 16384) * 2,
225 | channel_max = kwarg('fmap_max', 512),
226 | num_fp16_res = kwarg('num_fp16_res', 0),
227 | conv_clamp = kwarg('conv_clamp', None),
228 | cmap_dim = kwarg('mapping_fmaps', None),
229 | block_kwargs = dnnlib.EasyDict(
230 | activation = kwarg('nonlinearity', 'lrelu'),
231 | resample_filter = kwarg('resample_kernel', [1,3,3,1]),
232 | freeze_layers = kwarg('freeze_layers', 0),
233 | ),
234 | mapping_kwargs = dnnlib.EasyDict(
235 | num_layers = kwarg('mapping_layers', 0),
236 | embed_features = kwarg('mapping_fmaps', None),
237 | layer_features = kwarg('mapping_fmaps', None),
238 | activation = kwarg('nonlinearity', 'lrelu'),
239 | lr_multiplier = kwarg('mapping_lrmul', 0.1),
240 | ),
241 | epilogue_kwargs = dnnlib.EasyDict(
242 | mbstd_group_size = kwarg('mbstd_group_size', None),
243 | mbstd_num_channels = kwarg('mbstd_num_features', 1),
244 | activation = kwarg('nonlinearity', 'lrelu'),
245 | ),
246 | )
247 |
248 | # Check for unknown kwargs.
249 | kwarg('structure')
250 | unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
251 | if len(unknown_kwargs) > 0:
252 | raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
253 |
254 | # Collect params.
255 | tf_params = _collect_tf_params(tf_D)
256 | for name, value in list(tf_params.items()):
257 | match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
258 | if match:
259 | r = kwargs.img_resolution // (2 ** int(match.group(1)))
260 | tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
261 | kwargs.architecture = 'orig'
262 | #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
263 |
264 | # Convert params.
265 | import stylegan_networks as networks  # this repo's copy of training/networks.py
266 | D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
267 | # pylint: disable=unnecessary-lambda
268 | _populate_module_params(D,
269 | r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
270 | r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
271 | r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
272 | r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
273 | r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
274 | r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
275 | r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
276 | r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
277 | r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
278 | r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
279 | r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
280 | r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
281 | r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
282 | r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
283 | r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
284 | r'.*\.resample_filter', None,
285 | )
286 | return D
287 |
288 | #----------------------------------------------------------------------------
289 |
290 | @click.command()
291 | @click.option('--source', help='Input pickle', required=True, metavar='PATH')
292 | @click.option('--dest', help='Output pickle', required=True, metavar='PATH')
293 | @click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
294 | def convert_network_pickle(source, dest, force_fp16):
295 | """Convert legacy network pickle into the native PyTorch format.
296 |
297 | The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
298 | It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
299 |
300 | Example:
301 |
302 | \b
303 | python stylegan_legacy.py \\
304 | --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
305 | --dest=stylegan2-cat-config-f.pkl
306 | """
307 | print(f'Loading "{source}"...')
308 | with dnnlib.util.open_url(source) as f:
309 | data = load_network_pkl(f, force_fp16=force_fp16)
310 | print(f'Saving "{dest}"...')
311 | with open(dest, 'wb') as f:
312 | pickle.dump(data, f)
313 | print('Done.')
314 |
315 | #----------------------------------------------------------------------------
316 |
317 | if __name__ == "__main__":
318 | convert_network_pickle() # pylint: disable=no-value-for-parameter
319 |
320 | #----------------------------------------------------------------------------
321 |
--------------------------------------------------------------------------------
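
After conversion, the pickle contains ordinary PyTorch modules, so downstream code no longer needs TensorFlow. A minimal loading sketch, assuming the output path from the example above:

import pickle
import torch

with open('stylegan2-cat-config-f.pkl', 'rb') as f:
    data = pickle.load(f)
G = data['G_ema'].eval()        # torch.nn.Module
z = torch.randn(1, G.z_dim)
img = G(z, None)                # c=None for an unconditional network
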
/stylegan_tune.py:
--------------------------------------------------------------------------------
1 | #
2 | # Author: David Futschik
3 | # Provided as part of the Chunkmogrify project, 2021.
4 | #
5 |
6 | # Straightforward Pivotal Tuning Implementation.
7 | import torch
8 | from lpips import LPIPS
9 | from stylegan_project import dilation
10 |
11 | # Changes the model itself. Make a clone if that is not desired.
12 | # Note that mask does not change dynamically.
13 | class PivotalTuning:
14 | def __init__(self, model, device, pivot_w, target_image, mask=None, alpha=1,
15 | lambda_l2=1, lr=3e-4):
16 | self.model = model
17 | self.device = device
18 | self.w = pivot_w
19 | self.target = target_image
20 | self.mask = mask
21 |
22 | self.alpha = alpha
23 | self.lambda_l2 = lambda_l2
24 | self.initial_lr = lr
25 | self.optimizer = None
26 | self.current_iter = 0
27 |
28 | def step(self):
29 | if self.optimizer is None:
30 | self._init_opt()
31 |
32 | self.optimizer.zero_grad()
33 | current_image = self.model(self.w)
34 | loss = self._calc_loss(current_image)
35 | loss.backward()
36 | self.optimizer.step()
37 |
38 | self.current_iter += 1
39 | return self.model
40 |
41 | def iters_done(self):
42 | return self.current_iter
43 |
44 | def _calc_loss(self, x):
45 | if self.mask is not None:
46 | expanded_mask = dilation(self.mask, torch.ones(7, 7, device=self.device))
47 | x = x * expanded_mask + self.target * (1 - expanded_mask)
48 |
49 | l2_loss = torch.nn.MSELoss()(x, self.target)
50 | lp_loss = self.lpips_loss(x, self.target)
51 | sphere_loss = 0 # Seems to be disabled in PTI anyway.
52 |
53 | loss = self.lambda_l2 * l2_loss + lp_loss
54 | return loss
55 |
56 | def _init_opt(self):
57 | self.model.requires_grad_(True)
58 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.initial_lr)
59 | self.lpips_loss = LPIPS().to(self.device)
60 |
--------------------------------------------------------------------------------
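
The tuner exposes single steps so a UI can interleave optimization with previews. A minimal driving loop, assuming placeholder `model`, `device`, `w_pivot`, and `target` objects supplied by the caller (clone the model first if the original weights must be kept):

from stylegan_tune import PivotalTuning

tuner = PivotalTuning(model, device, w_pivot, target, mask=None)
while tuner.iters_done() < 350:          # iteration budget is a stand-in
    model = tuner.step()                 # one Adam step on the generator weights
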
/torch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/torch_utils/custom_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import glob
11 | import torch
12 | import torch.utils.cpp_extension
13 | import importlib
14 | import hashlib
15 | import shutil
16 | from pathlib import Path
17 |
18 | from torch.utils.file_baton import FileBaton
19 |
20 | #----------------------------------------------------------------------------
21 | # Global options.
22 |
23 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
24 |
25 | #----------------------------------------------------------------------------
26 | # Internal helper funcs.
27 |
28 | def _find_compiler_bindir():
29 | patterns = [
30 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
33 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
34 | ]
35 | for pattern in patterns:
36 | matches = sorted(glob.glob(pattern))
37 | if len(matches):
38 | return matches[-1]
39 | return None
40 |
41 | #----------------------------------------------------------------------------
42 | # Main entry point for compiling and loading C++/CUDA plugins.
43 |
44 | _cached_plugins = dict()
45 |
46 | def get_plugin(module_name, sources, **build_kwargs):
47 | assert verbosity in ['none', 'brief', 'full']
48 |
49 | # Already cached?
50 | if module_name in _cached_plugins:
51 | return _cached_plugins[module_name]
52 |
53 | # Print status.
54 | if verbosity == 'full':
55 | print(f'Setting up PyTorch plugin "{module_name}"...')
56 | elif verbosity == 'brief':
57 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
58 |
59 | try: # pylint: disable=too-many-nested-blocks
60 | # Make sure we can find the necessary compiler binaries.
61 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
62 | compiler_bindir = _find_compiler_bindir()
63 | if compiler_bindir is None:
64 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
65 | os.environ['PATH'] += ';' + compiler_bindir
66 |
67 | # Compile and load.
68 | verbose_build = (verbosity == 'full')
69 |
70 | # Incremental build md5sum trickery. Copies all the input source files
71 | # into a cached build directory under a combined md5 digest of the input
72 | # source files. Copying is done only if the combined digest has changed.
73 | # This keeps input file timestamps and filenames the same as in previous
74 | # extension builds, allowing for fast incremental rebuilds.
75 | #
76 | # This optimization is done only in case all the source files reside in
77 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
78 | # environment variable is set (we take this as a signal that the user
79 | # actually cares about this.)
80 | source_dirs_set = set(os.path.dirname(source) for source in sources)
81 | if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
82 | all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
83 |
84 | # Compute a combined hash digest for all source files in the same
85 | # custom op directory (usually .cu, .cpp, .py and .h files).
86 | hash_md5 = hashlib.md5()
87 | for src in all_source_files:
88 | with open(src, 'rb') as f:
89 | hash_md5.update(f.read())
90 | build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
91 | digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
92 |
93 | if not os.path.isdir(digest_build_dir):
94 | os.makedirs(digest_build_dir, exist_ok=True)
95 | baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
96 | if baton.try_acquire():
97 | try:
98 | for src in all_source_files:
99 | shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
100 | finally:
101 | baton.release()
102 | else:
103 | # Someone else is copying source files under the digest dir,
104 | # wait until done and continue.
105 | baton.wait()
106 | digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
107 | torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
108 | verbose=verbose_build, sources=digest_sources, **build_kwargs)
109 | else:
110 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
111 | module = importlib.import_module(module_name)
112 |
113 | except:
114 | if verbosity == 'brief':
115 | print('Failed!')
116 | raise
117 |
118 | # Print status and add to cache.
119 | if verbosity == 'full':
120 | print(f'Done setting up PyTorch plugin "{module_name}".')
121 | elif verbosity == 'brief':
122 | print('Done.')
123 | _cached_plugins[module_name] = module
124 | return module
125 |
126 | #----------------------------------------------------------------------------
127 |
--------------------------------------------------------------------------------
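
Setting TORCH_EXTENSIONS_DIR opts into the md5-digest incremental build path described in the comments above; `verbosity` controls how much of the build is printed. A minimal sketch, with the cache path as a stand-in:

import os
os.environ['TORCH_EXTENSIONS_DIR'] = os.path.expanduser('~/.cache/torch_extensions')

from torch_utils import custom_ops
custom_ops.verbosity = 'full'    # 'none', 'brief', or 'full'
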
/torch_utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import re
10 | import contextlib
11 | import numpy as np
12 | import torch
13 | import warnings
14 | import dnnlib
15 |
16 | #----------------------------------------------------------------------------
17 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
18 | # same constant is used multiple times.
19 |
20 | _constant_cache = dict()
21 |
22 | def constant(value, shape=None, dtype=None, device=None, memory_format=None):
23 | value = np.asarray(value)
24 | if shape is not None:
25 | shape = tuple(shape)
26 | if dtype is None:
27 | dtype = torch.get_default_dtype()
28 | if device is None:
29 | device = torch.device('cpu')
30 | if memory_format is None:
31 | memory_format = torch.contiguous_format
32 |
33 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
34 | tensor = _constant_cache.get(key, None)
35 | if tensor is None:
36 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
37 | if shape is not None:
38 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
39 | tensor = tensor.contiguous(memory_format=memory_format)
40 | _constant_cache[key] = tensor
41 | return tensor
42 |
43 | #----------------------------------------------------------------------------
44 | # Replace NaN/Inf with specified numerical values.
45 |
46 | try:
47 | nan_to_num = torch.nan_to_num # 1.8.0a0
48 | except AttributeError:
49 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
50 | assert isinstance(input, torch.Tensor)
51 | if posinf is None:
52 | posinf = torch.finfo(input.dtype).max
53 | if neginf is None:
54 | neginf = torch.finfo(input.dtype).min
55 | assert nan == 0
56 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
57 |
58 | #----------------------------------------------------------------------------
59 | # Symbolic assert.
60 |
61 | try:
62 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
63 | except AttributeError:
64 | symbolic_assert = torch.Assert # 1.7.0
65 |
66 | #----------------------------------------------------------------------------
67 | # Context manager to suppress known warnings in torch.jit.trace().
68 |
69 | class suppress_tracer_warnings(warnings.catch_warnings):
70 | def __enter__(self):
71 | super().__enter__()
72 | warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
73 | return self
74 |
75 | #----------------------------------------------------------------------------
76 | # Assert that the shape of a tensor matches the given list of integers.
77 | # None indicates that the size of a dimension is allowed to vary.
78 | # Performs symbolic assertion when used in torch.jit.trace().
79 |
80 | def assert_shape(tensor, ref_shape):
81 | if tensor.ndim != len(ref_shape):
82 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
83 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
84 | if ref_size is None:
85 | pass
86 | elif isinstance(ref_size, torch.Tensor):
87 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
88 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
89 | elif isinstance(size, torch.Tensor):
90 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
91 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
92 | elif size != ref_size:
93 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
94 |
95 | #----------------------------------------------------------------------------
96 | # Function decorator that calls torch.autograd.profiler.record_function().
97 |
98 | def profiled_function(fn):
99 | def decorator(*args, **kwargs):
100 | with torch.autograd.profiler.record_function(fn.__name__):
101 | return fn(*args, **kwargs)
102 | decorator.__name__ = fn.__name__
103 | return decorator
104 |
105 | #----------------------------------------------------------------------------
106 | # Sampler for torch.utils.data.DataLoader that loops over the dataset
107 | # indefinitely, shuffling items as it goes.
108 |
109 | class InfiniteSampler(torch.utils.data.Sampler):
110 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
111 | assert len(dataset) > 0
112 | assert num_replicas > 0
113 | assert 0 <= rank < num_replicas
114 | assert 0 <= window_size <= 1
115 | super().__init__(dataset)
116 | self.dataset = dataset
117 | self.rank = rank
118 | self.num_replicas = num_replicas
119 | self.shuffle = shuffle
120 | self.seed = seed
121 | self.window_size = window_size
122 |
123 | def __iter__(self):
124 | order = np.arange(len(self.dataset))
125 | rnd = None
126 | window = 0
127 | if self.shuffle:
128 | rnd = np.random.RandomState(self.seed)
129 | rnd.shuffle(order)
130 | window = int(np.rint(order.size * self.window_size))
131 |
132 | idx = 0
133 | while True:
134 | i = idx % order.size
135 | if idx % self.num_replicas == self.rank:
136 | yield order[i]
137 | if window >= 2:
138 | j = (i - rnd.randint(window)) % order.size
139 | order[i], order[j] = order[j], order[i]
140 | idx += 1
141 |
142 | #----------------------------------------------------------------------------
143 | # Utilities for operating with torch.nn.Module parameters and buffers.
144 |
145 | def params_and_buffers(module):
146 | assert isinstance(module, torch.nn.Module)
147 | return list(module.parameters()) + list(module.buffers())
148 |
149 | def named_params_and_buffers(module):
150 | assert isinstance(module, torch.nn.Module)
151 | return list(module.named_parameters()) + list(module.named_buffers())
152 |
153 | def copy_params_and_buffers(src_module, dst_module, require_all=False, rand_init_extra_channels=False):
154 | assert isinstance(src_module, torch.nn.Module)
155 | assert isinstance(dst_module, torch.nn.Module)
156 | src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
157 | for name, tensor in named_params_and_buffers(dst_module):
158 | assert (name in src_tensors) or (not require_all)
159 | if name in src_tensors:
160 | # Only implement rand_fill_extra_channels for torgb for now, to catch future mistakes.
161 | if 'torgb' in name and rand_init_extra_channels and tensor.shape[0] != src_tensors[name].shape[0]:
162 | print(f"<{name}>: Initializing first {src_tensors[name].shape[0]} channels with existing weights")
163 | # tensor[:src_tensors[name].shape[0], ...].copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
164 | elif 'fromrgb' in name and rand_init_extra_channels and tensor.size() != src_tensors[name].size():
165 | assert tensor.ndim > 1 and tensor.shape[1] != src_tensors[name].shape[1], (tensor.shape, src_tensors[name].shape)
166 | print(f"<{name}>: Initializing first {src_tensors[name].shape[1]} channels with existing weights")
167 | # tensor[:, :src_tensors[name].shape[1], ...].copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
168 | elif tensor.size() != src_tensors[name].size():
169 | print(f"<{name}>: Tensors differ in size, not initializing")
170 | print(f"{tensor.shape} vs {src_tensors[name].shape}")
171 | else:
172 | # print(name, tensor.shape, src_tensors[name].shape)
173 | tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
174 |
175 | #----------------------------------------------------------------------------
176 | # Context manager for easily enabling/disabling DistributedDataParallel
177 | # synchronization.
178 |
179 | @contextlib.contextmanager
180 | def ddp_sync(module, sync):
181 | assert isinstance(module, torch.nn.Module)
182 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
183 | yield
184 | else:
185 | with module.no_sync():
186 | yield
187 |
188 | #----------------------------------------------------------------------------
189 | # Check DistributedDataParallel consistency across processes.
190 |
191 | def check_ddp_consistency(module, ignore_regex=None):
192 | assert isinstance(module, torch.nn.Module)
193 | for name, tensor in named_params_and_buffers(module):
194 | fullname = type(module).__name__ + '.' + name
195 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
196 | continue
197 | tensor = tensor.detach()
198 | other = tensor.clone()
199 | torch.distributed.broadcast(tensor=other, src=0)
200 | assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
201 |
202 | #----------------------------------------------------------------------------
203 | # Print summary table of module hierarchy.
204 |
205 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
206 | assert isinstance(module, torch.nn.Module)
207 | assert not isinstance(module, torch.jit.ScriptModule)
208 | assert isinstance(inputs, (tuple, list))
209 |
210 | # Register hooks.
211 | entries = []
212 | nesting = [0]
213 | def pre_hook(_mod, _inputs):
214 | nesting[0] += 1
215 | def post_hook(mod, _inputs, outputs):
216 | nesting[0] -= 1
217 | if nesting[0] <= max_nesting:
218 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
219 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
220 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
221 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
222 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
223 |
224 | # Run module.
225 | outputs = module(*inputs)
226 | for hook in hooks:
227 | hook.remove()
228 |
229 | # Identify unique outputs, parameters, and buffers.
230 | tensors_seen = set()
231 | for e in entries:
232 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
233 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
234 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
235 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
236 |
237 | # Filter out redundant entries.
238 | if skip_redundant:
239 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
240 |
241 | # Construct table.
242 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
243 | rows += [['---'] * len(rows[0])]
244 | param_total = 0
245 | buffer_total = 0
246 | submodule_names = {mod: name for name, mod in module.named_modules()}
247 | for e in entries:
248 | name = '' if e.mod is module else submodule_names[e.mod]
249 | param_size = sum(t.numel() for t in e.unique_params)
250 | buffer_size = sum(t.numel() for t in e.unique_buffers)
251 | output_shapes = [str(list(t.shape)) for t in e.outputs]
252 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
253 | rows += [[
254 | name + (':0' if len(e.outputs) >= 2 else ''),
255 | str(param_size) if param_size else '-',
256 | str(buffer_size) if buffer_size else '-',
257 | (output_shapes + ['-'])[0],
258 | (output_dtypes + ['-'])[0],
259 | ]]
260 | for idx in range(1, len(e.outputs)):
261 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
262 | param_total += param_size
263 | buffer_total += buffer_size
264 | rows += [['---'] * len(rows[0])]
265 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
266 |
267 | # Print table.
268 | widths = [max(len(cell) for cell in column) for column in zip(*rows)]
269 | print()
270 | for row in rows:
271 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
272 | print()
273 | return outputs
274 |
275 | #----------------------------------------------------------------------------
276 |
--------------------------------------------------------------------------------
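
`assert_shape` treats `None` as a free dimension, which makes it convenient for batch-size-agnostic checks. A minimal sketch:

import torch
from torch_utils import misc

x = torch.zeros(4, 3, 256, 256)
misc.assert_shape(x, [None, 3, 256, 256])   # batch dimension may vary
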
/torch_utils/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <torch/extension.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <c10/cuda/CUDAGuard.h>
12 | #include "bias_act.h"
13 |
14 | //------------------------------------------------------------------------
15 |
16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17 | {
18 | if (x.dim() != y.dim())
19 | return false;
20 | for (int64_t i = 0; i < x.dim(); i++)
21 | {
22 | if (x.size(i) != y.size(i))
23 | return false;
24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25 | return false;
26 | }
27 | return true;
28 | }
29 |
30 | //------------------------------------------------------------------------
31 |
32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33 | {
34 | // Validate arguments.
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same shape, dtype, and device as x");
40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44 | TORCH_CHECK(grad >= 0, "grad must be non-negative");
45 |
46 | // Validate layout.
47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52 |
53 | // Create output tensor.
54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55 | torch::Tensor y = torch::empty_like(x);
56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57 |
58 | // Initialize CUDA kernel parameters.
59 | bias_act_kernel_params p;
60 | p.x = x.data_ptr();
61 | p.b = (b.numel()) ? b.data_ptr() : NULL;
62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65 | p.y = y.data_ptr();
66 | p.grad = grad;
67 | p.act = act;
68 | p.alpha = alpha;
69 | p.gain = gain;
70 | p.clamp = clamp;
71 | p.sizeX = (int)x.numel();
72 | p.sizeB = (int)b.numel();
73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74 |
75 | // Choose CUDA kernel.
76 | void* kernel;
77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
78 | {
79 | kernel = choose_bias_act_kernel(p);
80 | });
81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82 |
83 | // Launch CUDA kernel.
84 | p.loopX = 4;
85 | int blockSize = 4 * 32;
86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87 | void* args[] = {&p};
88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89 | return y;
90 | }
91 |
92 | //------------------------------------------------------------------------
93 |
94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95 | {
96 | m.def("bias_act", &bias_act);
97 | }
98 |
99 | //------------------------------------------------------------------------
100 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <c10/util/Half.h>
10 | #include "bias_act.h"
11 |
12 | //------------------------------------------------------------------------
13 | // Helpers.
14 |
15 | template <class T> struct InternalType;
16 | template <> struct InternalType<double>    { typedef double scalar_t; };
17 | template <> struct InternalType<float>     { typedef float  scalar_t; };
18 | template <> struct InternalType<c10::Half> { typedef float  scalar_t; };
19 |
20 | //------------------------------------------------------------------------
21 | // CUDA kernel.
22 |
23 | template <class T, int A>
24 | __global__ void bias_act_kernel(bias_act_kernel_params p)
25 | {
26 | typedef typename InternalType::scalar_t scalar_t;
27 | int G = p.grad;
28 | scalar_t alpha = (scalar_t)p.alpha;
29 | scalar_t gain = (scalar_t)p.gain;
30 | scalar_t clamp = (scalar_t)p.clamp;
31 | scalar_t one = (scalar_t)1;
32 | scalar_t two = (scalar_t)2;
33 | scalar_t expRange = (scalar_t)80;
34 | scalar_t halfExpRange = (scalar_t)40;
35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37 |
38 | // Loop over elements.
39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41 | {
42 | // Load.
43 | scalar_t x = (scalar_t)((const T*)p.x)[xi];
44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48 | scalar_t yy = (gain != 0) ? yref / gain : 0;
49 | scalar_t y = 0;
50 |
51 | // Apply bias.
52 | ((G == 0) ? x : xref) += b;
53 |
54 | // linear
55 | if (A == 1)
56 | {
57 | if (G == 0) y = x;
58 | if (G == 1) y = x;
59 | }
60 |
61 | // relu
62 | if (A == 2)
63 | {
64 | if (G == 0) y = (x > 0) ? x : 0;
65 | if (G == 1) y = (yy > 0) ? x : 0;
66 | }
67 |
68 | // lrelu
69 | if (A == 3)
70 | {
71 | if (G == 0) y = (x > 0) ? x : x * alpha;
72 | if (G == 1) y = (yy > 0) ? x : x * alpha;
73 | }
74 |
75 | // tanh
76 | if (A == 4)
77 | {
78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79 | if (G == 1) y = x * (one - yy * yy);
80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81 | }
82 |
83 | // sigmoid
84 | if (A == 5)
85 | {
86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87 | if (G == 1) y = x * yy * (one - yy);
88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89 | }
90 |
91 | // elu
92 | if (A == 6)
93 | {
94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97 | }
98 |
99 | // selu
100 | if (A == 7)
101 | {
102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105 | }
106 |
107 | // softplus
108 | if (A == 8)
109 | {
110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111 | if (G == 1) y = x * (one - exp(-yy));
112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113 | }
114 |
115 | // swish
116 | if (A == 9)
117 | {
118 | if (G == 0)
119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120 | else
121 | {
122 | scalar_t c = exp(xref);
123 | scalar_t d = c + one;
124 | if (G == 1)
125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126 | else
127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129 | }
130 | }
131 |
132 | // Apply gain.
133 | y *= gain * dy;
134 |
135 | // Clamp.
136 | if (clamp >= 0)
137 | {
138 | if (G == 0)
139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140 | else
141 | y = (yref > -clamp & yref < clamp) ? y : 0;
142 | }
143 |
144 | // Store.
145 | ((T*)p.y)[xi] = (T)y;
146 | }
147 | }
148 |
149 | //------------------------------------------------------------------------
150 | // CUDA kernel selection.
151 |
152 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153 | {
154 | if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155 | if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156 | if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157 | if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158 | if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159 | if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160 | if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161 | if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162 | if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163 | return NULL;
164 | }
165 |
166 | //------------------------------------------------------------------------
167 | // Template specializations.
168 |
169 | template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
170 | template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
171 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172 |
173 | //------------------------------------------------------------------------
174 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | //------------------------------------------------------------------------
10 | // CUDA kernel parameters.
11 |
12 | struct bias_act_kernel_params
13 | {
14 | const void* x; // [sizeX]
15 | const void* b; // [sizeB] or NULL
16 | const void* xref; // [sizeX] or NULL
17 | const void* yref; // [sizeX] or NULL
18 | const void* dy; // [sizeX] or NULL
19 | void* y; // [sizeX]
20 |
21 | int grad;
22 | int act;
23 | float alpha;
24 | float gain;
25 | float clamp;
26 |
27 | int sizeX;
28 | int sizeB;
29 | int stepB;
30 | int loopX;
31 | };
32 |
33 | //------------------------------------------------------------------------
34 | // CUDA kernel selection.
35 |
36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37 |
38 | //------------------------------------------------------------------------
39 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom PyTorch ops for efficient bias and activation."""
10 |
11 | import os
12 | import sys
13 | import warnings
14 | import numpy as np
15 | import torch
16 | import dnnlib
17 | import traceback
18 |
19 | from .. import custom_ops
20 | from .. import misc
21 |
22 | #----------------------------------------------------------------------------
23 |
24 | activation_funcs = {
25 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
26 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
27 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
28 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
29 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
30 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
31 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
32 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
33 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
34 | }
35 |
36 | #----------------------------------------------------------------------------
37 |
38 | _inited = False
39 | _plugin = None
40 | _null_tensor = torch.empty([0])
41 |
42 | def _init():
43 | global _inited, _plugin
44 | if not _inited:
45 | _inited = True
46 | sources = ['bias_act.cpp', 'bias_act.cu']
47 | sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
48 | try:
49 | _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
50 | except:
51 | warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
52 | return _plugin is not None
53 |
54 | #----------------------------------------------------------------------------
55 |
56 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
57 | r"""Fused bias and activation function.
58 |
59 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
60 | and scales the result by `gain`. Each of the steps is optional. In most cases,
61 | the fused op is considerably more efficient than performing the same calculation
62 | using standard PyTorch ops. It supports first and second order gradients,
63 | but not third order gradients.
64 |
65 | Args:
66 | x: Input activation tensor. Can be of any shape.
67 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
68 | as `x`. The shape must be known, and it must match the dimension of `x`
69 | corresponding to `dim`.
70 | dim: The dimension in `x` corresponding to the elements of `b`.
71 | The value of `dim` is ignored if `b` is not specified.
72 | act: Name of the activation function to evaluate, or `"linear"` to disable.
73 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
74 | See `activation_funcs` for a full list. `None` is not allowed.
75 | alpha: Shape parameter for the activation function, or `None` to use the default.
76 | gain: Scaling factor for the output tensor, or `None` to use default.
77 | See `activation_funcs` for the default scaling of each activation function.
78 | If unsure, consider specifying 1.
79 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
80 | the clamping (default).
81 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
82 |
83 | Returns:
84 | Tensor of the same shape and datatype as `x`.
85 | """
86 | assert isinstance(x, torch.Tensor)
87 | assert impl in ['ref', 'cuda']
88 | if impl == 'cuda' and x.device.type == 'cuda' and _init():
89 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
90 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
91 |
92 | #----------------------------------------------------------------------------
93 |
94 | @misc.profiled_function
95 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
96 | """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
97 | """
98 | assert isinstance(x, torch.Tensor)
99 | assert clamp is None or clamp >= 0
100 | spec = activation_funcs[act]
101 | alpha = float(alpha if alpha is not None else spec.def_alpha)
102 | gain = float(gain if gain is not None else spec.def_gain)
103 | clamp = float(clamp if clamp is not None else -1)
104 |
105 | # Add bias.
106 | if b is not None:
107 | assert isinstance(b, torch.Tensor) and b.ndim == 1
108 | assert 0 <= dim < x.ndim
109 | assert b.shape[0] == x.shape[dim]
110 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
111 |
112 | # Evaluate activation function.
113 | alpha = float(alpha)
114 | x = spec.func(x, alpha=alpha)
115 |
116 | # Scale by gain.
117 | gain = float(gain)
118 | if gain != 1:
119 | x = x * gain
120 |
121 | # Clamp.
122 | if clamp >= 0:
123 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
124 | return x
125 |
126 | #----------------------------------------------------------------------------
127 |
128 | _bias_act_cuda_cache = dict()
129 |
130 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
131 | """Fast CUDA implementation of `bias_act()` using custom ops.
132 | """
133 | # Parse arguments.
134 | assert clamp is None or clamp >= 0
135 | spec = activation_funcs[act]
136 | alpha = float(alpha if alpha is not None else spec.def_alpha)
137 | gain = float(gain if gain is not None else spec.def_gain)
138 | clamp = float(clamp if clamp is not None else -1)
139 |
140 | # Lookup from cache.
141 | key = (dim, act, alpha, gain, clamp)
142 | if key in _bias_act_cuda_cache:
143 | return _bias_act_cuda_cache[key]
144 |
145 | # Forward op.
146 | class BiasActCuda(torch.autograd.Function):
147 | @staticmethod
148 | def forward(ctx, x, b): # pylint: disable=arguments-differ
149 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
150 | x = x.contiguous(memory_format=ctx.memory_format)
151 | b = b.contiguous() if b is not None else _null_tensor
152 | y = x
153 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
154 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
155 | ctx.save_for_backward(
156 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
157 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
158 | y if 'y' in spec.ref else _null_tensor)
159 | return y
160 |
161 | @staticmethod
162 | def backward(ctx, dy): # pylint: disable=arguments-differ
163 | dy = dy.contiguous(memory_format=ctx.memory_format)
164 | x, b, y = ctx.saved_tensors
165 | dx = None
166 | db = None
167 |
168 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
169 | dx = dy
170 | if act != 'linear' or gain != 1 or clamp >= 0:
171 | dx = BiasActCudaGrad.apply(dy, x, b, y)
172 |
173 | if ctx.needs_input_grad[1]:
174 | db = dx.sum([i for i in range(dx.ndim) if i != dim])
175 |
176 | return dx, db
177 |
178 | # Backward op.
179 | class BiasActCudaGrad(torch.autograd.Function):
180 | @staticmethod
181 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
182 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
183 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
184 | ctx.save_for_backward(
185 | dy if spec.has_2nd_grad else _null_tensor,
186 | x, b, y)
187 | return dx
188 |
189 | @staticmethod
190 | def backward(ctx, d_dx): # pylint: disable=arguments-differ
191 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
192 | dy, x, b, y = ctx.saved_tensors
193 | d_dy = None
194 | d_x = None
195 | d_b = None
196 | d_y = None
197 |
198 | if ctx.needs_input_grad[0]:
199 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
200 |
201 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
202 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
203 |
204 | if spec.has_2nd_grad and ctx.needs_input_grad[2]:
205 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
206 |
207 | return d_dy, d_x, d_b, d_y
208 |
209 | # Add to cache.
210 | _bias_act_cuda_cache[key] = BiasActCuda
211 | return BiasActCuda
212 |
213 | #----------------------------------------------------------------------------
214 |
--------------------------------------------------------------------------------
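Usage sketch for `bias_act()` (illustrative, not part of the repository); `impl='ref'` keeps it runnable without building the CUDA plugin:

```python
import torch
from torch_utils.ops import bias_act

x = torch.randn(4, 512, 16, 16)   # activations, channels on dim 1
b = torch.randn(512)              # one bias per channel
# Fused bias + leaky ReLU with the default sqrt(2) gain.
y = bias_act.bias_act(x, b, dim=1, act='lrelu', impl='ref')
assert y.shape == x.shape
```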
/torch_utils/ops/conv2d_gradfix.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom replacement for `torch.nn.functional.conv2d` that supports
10 | arbitrarily high order gradients with zero performance penalty."""
11 |
12 | import warnings
13 | import contextlib
14 | import torch
15 |
16 | # pylint: disable=redefined-builtin
17 | # pylint: disable=arguments-differ
18 | # pylint: disable=protected-access
19 |
20 | #----------------------------------------------------------------------------
21 |
22 | enabled = False # Enable the custom op by setting this to true.
23 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
24 |
25 | @contextlib.contextmanager
26 | def no_weight_gradients():
27 | global weight_gradients_disabled
28 | old = weight_gradients_disabled
29 | weight_gradients_disabled = True
30 | yield
31 | weight_gradients_disabled = old
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
36 | if _should_use_custom_op(input):
37 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
38 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
39 |
40 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
41 | if _should_use_custom_op(input):
42 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
43 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
44 |
45 | #----------------------------------------------------------------------------
46 |
47 | def _should_use_custom_op(input):
48 | assert isinstance(input, torch.Tensor)
49 | if (not enabled) or (not torch.backends.cudnn.enabled):
50 | return False
51 | if input.device.type != 'cuda':
52 | return False
53 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.']):
54 | return True
55 | warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
56 | return False
57 |
58 | def _tuple_of_ints(xs, ndim):
59 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
60 | assert len(xs) == ndim
61 | assert all(isinstance(x, int) for x in xs)
62 | return xs
63 |
64 | #----------------------------------------------------------------------------
65 |
66 | _conv2d_gradfix_cache = dict()
67 |
68 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
69 | # Parse arguments.
70 | ndim = 2
71 | weight_shape = tuple(weight_shape)
72 | stride = _tuple_of_ints(stride, ndim)
73 | padding = _tuple_of_ints(padding, ndim)
74 | output_padding = _tuple_of_ints(output_padding, ndim)
75 | dilation = _tuple_of_ints(dilation, ndim)
76 |
77 | # Lookup from cache.
78 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
79 | if key in _conv2d_gradfix_cache:
80 | return _conv2d_gradfix_cache[key]
81 |
82 | # Validate arguments.
83 | assert groups >= 1
84 | assert len(weight_shape) == ndim + 2
85 | assert all(stride[i] >= 1 for i in range(ndim))
86 | assert all(padding[i] >= 0 for i in range(ndim))
87 | assert all(dilation[i] >= 0 for i in range(ndim))
88 | if not transpose:
89 | assert all(output_padding[i] == 0 for i in range(ndim))
90 | else: # transpose
91 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
92 |
93 | # Helpers.
94 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
95 | def calc_output_padding(input_shape, output_shape):
96 | if transpose:
97 | return [0, 0]
98 | return [
99 | input_shape[i + 2]
100 | - (output_shape[i + 2] - 1) * stride[i]
101 | - (1 - 2 * padding[i])
102 | - dilation[i] * (weight_shape[i + 2] - 1)
103 | for i in range(ndim)
104 | ]
105 |
106 | # Forward & backward.
107 | class Conv2d(torch.autograd.Function):
108 | @staticmethod
109 | def forward(ctx, input, weight, bias):
110 | assert weight.shape == weight_shape
111 | if not transpose:
112 | output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
113 | else: # transpose
114 | output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
115 | ctx.save_for_backward(input, weight)
116 | return output
117 |
118 | @staticmethod
119 | def backward(ctx, grad_output):
120 | input, weight = ctx.saved_tensors
121 | grad_input = None
122 | grad_weight = None
123 | grad_bias = None
124 |
125 | if ctx.needs_input_grad[0]:
126 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
127 | grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
128 | assert grad_input.shape == input.shape
129 |
130 | if ctx.needs_input_grad[1] and not weight_gradients_disabled:
131 | grad_weight = Conv2dGradWeight.apply(grad_output, input)
132 | assert grad_weight.shape == weight_shape
133 |
134 | if ctx.needs_input_grad[2]:
135 | grad_bias = grad_output.sum([0, 2, 3])
136 |
137 | return grad_input, grad_weight, grad_bias
138 |
139 | # Gradient with respect to the weights.
140 | class Conv2dGradWeight(torch.autograd.Function):
141 | @staticmethod
142 | def forward(ctx, grad_output, input):
143 | op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
144 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
145 | grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
146 | assert grad_weight.shape == weight_shape
147 | ctx.save_for_backward(grad_output, input)
148 | return grad_weight
149 |
150 | @staticmethod
151 | def backward(ctx, grad2_grad_weight):
152 | grad_output, input = ctx.saved_tensors
153 | grad2_grad_output = None
154 | grad2_input = None
155 |
156 | if ctx.needs_input_grad[0]:
157 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
158 | assert grad2_grad_output.shape == grad_output.shape
159 |
160 | if ctx.needs_input_grad[1]:
161 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
162 | grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
163 | assert grad2_input.shape == input.shape
164 |
165 | return grad2_grad_output, grad2_input
166 |
167 | _conv2d_gradfix_cache[key] = Conv2d
168 | return Conv2d
169 |
170 | #----------------------------------------------------------------------------
171 |
--------------------------------------------------------------------------------
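A short sketch of the double-backward use case this module exists for (e.g. gradient penalties); with the module-level `enabled = False` the calls simply forward to `torch.nn.functional`:

```python
import torch
from torch_utils.ops import conv2d_gradfix

x = torch.randn(2, 3, 32, 32, requires_grad=True)
w = torch.randn(8, 3, 3, 3, requires_grad=True)
y = conv2d_gradfix.conv2d(x, w, padding=1)

# Differentiate through the input gradient (second-order gradients).
g = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
g.square().sum().backward()

# Suppress weight gradients inside the custom op (a no-op on the fallback path).
with conv2d_gradfix.no_weight_gradients():
    _ = conv2d_gradfix.conv2d(x, w, padding=1)
```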
/torch_utils/ops/conv2d_resample.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """2D convolution with optional up/downsampling."""
10 |
11 | import torch
12 |
13 | from .. import misc
14 | from . import conv2d_gradfix
15 | from . import upfirdn2d
16 | from .upfirdn2d import _parse_padding
17 | from .upfirdn2d import _get_filter_size
18 |
19 | #----------------------------------------------------------------------------
20 |
21 | def _get_weight_shape(w):
22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant
23 | shape = [int(sz) for sz in w.shape]
24 | misc.assert_shape(w, shape)
25 | return shape
26 |
27 | #----------------------------------------------------------------------------
28 |
29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
31 | """
32 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
33 |
34 | # Flip weight if requested.
35 | if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
36 | w = w.flip([2, 3])
37 |
38 | # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
39 | # 1x1 kernel + memory_format=channels_last + less than 64 channels.
40 | if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
41 | if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
42 | if out_channels <= 4 and groups == 1:
43 | in_shape = x.shape
44 | x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
45 | x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
46 | else:
47 | x = x.to(memory_format=torch.contiguous_format)
48 | w = w.to(memory_format=torch.contiguous_format)
49 | x = conv2d_gradfix.conv2d(x, w, groups=groups)
50 | return x.to(memory_format=torch.channels_last)
51 |
52 | # Otherwise => execute using conv2d_gradfix.
53 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
54 | return op(x, w, stride=stride, padding=padding, groups=groups)
55 |
56 | #----------------------------------------------------------------------------
57 |
58 | @misc.profiled_function
59 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
60 | r"""2D convolution with optional up/downsampling.
61 |
62 | Padding is performed only once at the beginning, not between the operations.
63 |
64 | Args:
65 | x: Input tensor of shape
66 | `[batch_size, in_channels, in_height, in_width]`.
67 | w: Weight tensor of shape
68 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
69 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by
70 | calling upfirdn2d.setup_filter(). None = identity (default).
71 | up: Integer upsampling factor (default: 1).
72 | down: Integer downsampling factor (default: 1).
73 | padding: Padding with respect to the upsampled image. Can be a single number
74 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
75 | (default: 0).
76 | groups: Split input channels into N groups (default: 1).
77 | flip_weight: False = convolution, True = correlation (default: True).
78 | flip_filter: False = convolution, True = correlation (default: False).
79 |
80 | Returns:
81 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
82 | """
83 | # Validate arguments.
84 | assert isinstance(x, torch.Tensor) and (x.ndim == 4)
85 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
86 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
87 | assert isinstance(up, int) and (up >= 1)
88 | assert isinstance(down, int) and (down >= 1)
89 | assert isinstance(groups, int) and (groups >= 1)
90 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
91 | fw, fh = _get_filter_size(f)
92 | px0, px1, py0, py1 = _parse_padding(padding)
93 |
94 | # Adjust padding to account for up/downsampling.
95 | if up > 1:
96 | px0 += (fw + up - 1) // 2
97 | px1 += (fw - up) // 2
98 | py0 += (fh + up - 1) // 2
99 | py1 += (fh - up) // 2
100 | if down > 1:
101 | px0 += (fw - down + 1) // 2
102 | px1 += (fw - down) // 2
103 | py0 += (fh - down + 1) // 2
104 | py1 += (fh - down) // 2
105 |
106 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
107 | if kw == 1 and kh == 1 and (down > 1 and up == 1):
108 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
109 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
110 | return x
111 |
112 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
113 | if kw == 1 and kh == 1 and (up > 1 and down == 1):
114 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
115 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
116 | return x
117 |
118 | # Fast path: downsampling only => use strided convolution.
119 | if down > 1 and up == 1:
120 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
121 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
122 | return x
123 |
124 | # Fast path: upsampling with optional downsampling => use transpose strided convolution.
125 | if up > 1:
126 | if groups == 1:
127 | w = w.transpose(0, 1)
128 | else:
129 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
130 | w = w.transpose(1, 2)
131 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
132 | px0 -= kw - 1
133 | px1 -= kw - up
134 | py0 -= kh - 1
135 | py1 -= kh - up
136 | pxt = max(min(-px0, -px1), 0)
137 | pyt = max(min(-py0, -py1), 0)
138 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
139 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
140 | if down > 1:
141 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
142 | return x
143 |
144 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
145 | if up == 1 and down == 1:
146 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
147 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
148 |
149 | # Fallback: Generic reference implementation.
150 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
151 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
152 | if down > 1:
153 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
154 | return x
155 |
156 | #----------------------------------------------------------------------------
157 |
--------------------------------------------------------------------------------
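Illustrative call (not from the repository) combining a StyleGAN2-style low-pass filter with 2x upsampling and a 3x3 convolution; on CPU both ops fall back to their reference paths:

```python
import torch
from torch_utils.ops import upfirdn2d, conv2d_resample

x = torch.randn(1, 64, 32, 32)
w = torch.randn(32, 64, 3, 3)
f = upfirdn2d.setup_filter([1, 3, 3, 1])   # 4-tap binomial filter
y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
# padding=1 keeps the 3x3 conv size-preserving, so y should be [1, 32, 64, 64].
print(y.shape)
```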
/torch_utils/ops/fma.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
10 |
11 | import torch
12 |
13 | #----------------------------------------------------------------------------
14 |
15 | def fma(a, b, c): # => a * b + c
16 | return _FusedMultiplyAdd.apply(a, b, c)
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
21 | @staticmethod
22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ
23 | out = torch.addcmul(c, a, b)
24 | ctx.save_for_backward(a, b)
25 | ctx.c_shape = c.shape
26 | return out
27 |
28 | @staticmethod
29 | def backward(ctx, dout): # pylint: disable=arguments-differ
30 | a, b = ctx.saved_tensors
31 | c_shape = ctx.c_shape
32 | da = None
33 | db = None
34 | dc = None
35 |
36 | if ctx.needs_input_grad[0]:
37 | da = _unbroadcast(dout * b, a.shape)
38 |
39 | if ctx.needs_input_grad[1]:
40 | db = _unbroadcast(dout * a, b.shape)
41 |
42 | if ctx.needs_input_grad[2]:
43 | dc = _unbroadcast(dout, c_shape)
44 |
45 | return da, db, dc
46 |
47 | #----------------------------------------------------------------------------
48 |
49 | def _unbroadcast(x, shape):
50 | extra_dims = x.ndim - len(shape)
51 | assert extra_dims >= 0
52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
53 | if len(dim):
54 | x = x.sum(dim=dim, keepdim=True)
55 | if extra_dims:
56 | x = x.reshape(-1, *x.shape[extra_dims+1:])
57 | assert x.shape == shape
58 | return x
59 |
60 | #----------------------------------------------------------------------------
61 |
--------------------------------------------------------------------------------
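Minimal usage (illustrative); broadcasting in the backward pass is handled by `_unbroadcast()` above:

```python
import torch
from torch_utils.ops import fma

a = torch.randn(8, 512, requires_grad=True)
b = torch.randn(512, requires_grad=True)   # broadcast across dim 0
c = torch.randn(8, 512, requires_grad=True)

out = fma.fma(a, b, c)                     # a * b + c
assert torch.allclose(out, a * b + c)
out.sum().backward()                       # b.grad comes back with shape [512]
```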
/torch_utils/ops/grid_sample_gradfix.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom replacement for `torch.nn.functional.grid_sample` that
10 | supports arbitrarily high order gradients between the input and output.
11 | Only works on 2D images and assumes
12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
13 |
14 | import warnings
15 | import torch
16 |
17 | # pylint: disable=redefined-builtin
18 | # pylint: disable=arguments-differ
19 | # pylint: disable=protected-access
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | enabled = False # Enable the custom op by setting this to true.
24 |
25 | #----------------------------------------------------------------------------
26 |
27 | def grid_sample(input, grid):
28 | if _should_use_custom_op():
29 | return _GridSample2dForward.apply(input, grid)
30 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
31 |
32 | #----------------------------------------------------------------------------
33 |
34 | def _should_use_custom_op():
35 | if not enabled:
36 | return False
37 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.']):
38 | return True
39 | warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
40 | return False
41 |
42 | #----------------------------------------------------------------------------
43 |
44 | class _GridSample2dForward(torch.autograd.Function):
45 | @staticmethod
46 | def forward(ctx, input, grid):
47 | assert input.ndim == 4
48 | assert grid.ndim == 4
49 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
50 | ctx.save_for_backward(input, grid)
51 | return output
52 |
53 | @staticmethod
54 | def backward(ctx, grad_output):
55 | input, grid = ctx.saved_tensors
56 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
57 | return grad_input, grad_grid
58 |
59 | #----------------------------------------------------------------------------
60 |
61 | class _GridSample2dBackward(torch.autograd.Function):
62 | @staticmethod
63 | def forward(ctx, grad_output, input, grid):
64 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
65 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
66 | ctx.save_for_backward(grid)
67 | return grad_input, grad_grid
68 |
69 | @staticmethod
70 | def backward(ctx, grad2_grad_input, grad2_grad_grid):
71 | _ = grad2_grad_grid # unused
72 | grid, = ctx.saved_tensors
73 | grad2_grad_output = None
74 | grad2_input = None
75 | grad2_grid = None
76 |
77 | if ctx.needs_input_grad[0]:
78 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
79 |
80 | assert not ctx.needs_input_grad[2]
81 | return grad2_grad_output, grad2_input, grad2_grid
82 |
83 | #----------------------------------------------------------------------------
84 |
--------------------------------------------------------------------------------
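Usage sketch (illustrative); with `enabled = False` this is plain `grid_sample` with the fixed `bilinear`/`zeros`/`align_corners=False` settings:

```python
import torch
from torch_utils.ops import grid_sample_gradfix

img = torch.randn(1, 3, 64, 64)
# Approximate identity grid in [-1, 1] coordinates, (x, y) channel order.
ys = torch.linspace(-1, 1, 64).view(1, 64, 1).expand(1, 64, 64)
xs = torch.linspace(-1, 1, 64).view(1, 1, 64).expand(1, 64, 64)
grid = torch.stack([xs, ys], dim=-1)       # [1, 64, 64, 2]
out = grid_sample_gradfix.grid_sample(img, grid)
assert out.shape == img.shape
```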
/torch_utils/ops/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <torch/extension.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <c10/cuda/CUDAGuard.h>
12 | #include "upfirdn2d.h"
13 |
14 | //------------------------------------------------------------------------
15 |
16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17 | {
18 | // Validate arguments.
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24 | TORCH_CHECK(x.dim() == 4, "x must be rank 4");
25 | TORCH_CHECK(f.dim() == 2, "f must be rank 2");
26 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
27 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
28 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
29 |
30 | // Create output tensor.
31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
32 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
33 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
34 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
35 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
36 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
37 |
38 | // Initialize CUDA kernel parameters.
39 | upfirdn2d_kernel_params p;
40 | p.x = x.data_ptr();
41 | p.f = f.data_ptr();
42 | p.y = y.data_ptr();
43 | p.up = make_int2(upx, upy);
44 | p.down = make_int2(downx, downy);
45 | p.pad0 = make_int2(padx0, pady0);
46 | p.flip = (flip) ? 1 : 0;
47 | p.gain = gain;
48 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
49 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
50 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
51 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
52 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
53 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
54 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
55 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
56 |
57 | // Choose CUDA kernel.
58 | upfirdn2d_kernel_spec spec;
59 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
60 | {
61 | spec = choose_upfirdn2d_kernel<scalar_t>(p);
62 | });
63 |
64 | // Set looping options.
65 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
66 | p.loopMinor = spec.loopMinor;
67 | p.loopX = spec.loopX;
68 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
69 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
70 |
71 | // Compute grid size.
72 | dim3 blockSize, gridSize;
73 | if (spec.tileOutW < 0) // large
74 | {
75 | blockSize = dim3(4, 32, 1);
76 | gridSize = dim3(
77 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
78 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
79 | p.launchMajor);
80 | }
81 | else // small
82 | {
83 | blockSize = dim3(256, 1, 1);
84 | gridSize = dim3(
85 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
86 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
87 | p.launchMajor);
88 | }
89 |
90 | // Launch CUDA kernel.
91 | void* args[] = {&p};
92 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
93 | return y;
94 | }
95 |
96 | //------------------------------------------------------------------------
97 |
98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
99 | {
100 | m.def("upfirdn2d", &upfirdn2d);
101 | }
102 |
103 | //------------------------------------------------------------------------
104 |
--------------------------------------------------------------------------------
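The output-size rule from the wrapper above, restated in Python (a hypothetical helper for clarity, not part of the extension):

```python
# outW = (inW * upx + padx0 + padx1 - filterW + downx) // downx, per axis.
def upfirdn2d_out_size(in_size, up, pad0, pad1, f_size, down):
    return (in_size * up + pad0 + pad1 - f_size + down) // down

# 2x upsampling of a 32-pixel axis with a 4-tap filter and pads (2, 1):
assert upfirdn2d_out_size(32, up=2, pad0=2, pad1=1, f_size=4, down=1) == 64
```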
/torch_utils/ops/upfirdn2d.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <cuda_runtime.h>
10 |
11 | //------------------------------------------------------------------------
12 | // CUDA kernel parameters.
13 |
14 | struct upfirdn2d_kernel_params
15 | {
16 | const void* x;
17 | const float* f;
18 | void* y;
19 |
20 | int2 up;
21 | int2 down;
22 | int2 pad0;
23 | int flip;
24 | float gain;
25 |
26 | int4 inSize; // [width, height, channel, batch]
27 | int4 inStride;
28 | int2 filterSize; // [width, height]
29 | int2 filterStride;
30 | int4 outSize; // [width, height, channel, batch]
31 | int4 outStride;
32 | int sizeMinor;
33 | int sizeMajor;
34 |
35 | int loopMinor;
36 | int loopMajor;
37 | int loopX;
38 | int launchMinor;
39 | int launchMajor;
40 | };
41 |
42 | //------------------------------------------------------------------------
43 | // CUDA kernel specialization.
44 |
45 | struct upfirdn2d_kernel_spec
46 | {
47 | void* kernel;
48 | int tileOutW;
49 | int tileOutH;
50 | int loopMinor;
51 | int loopX;
52 | };
53 |
54 | //------------------------------------------------------------------------
55 | // CUDA kernel selection.
56 |
57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58 |
59 | //------------------------------------------------------------------------
60 |
--------------------------------------------------------------------------------
/torch_utils/ops/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom PyTorch ops for efficient resampling of 2D images."""
10 |
11 | import os
12 | import sys
13 | import warnings
14 | import numpy as np
15 | import torch
16 | import traceback
17 |
18 | from .. import custom_ops
19 | from .. import misc
20 | from . import conv2d_gradfix
21 |
22 | #----------------------------------------------------------------------------
23 |
24 | _inited = False
25 | _plugin = None
26 | _init_already_failed = False
27 |
28 | def _init():
29 | global _inited, _plugin, _init_already_failed
30 | if _init_already_failed:
31 | return False
32 | if not _inited:
33 | sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
34 | sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
35 | try:
36 | _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
37 | except:
38 | _init_already_failed = True
39 | warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
40 | return _plugin is not None
41 |
42 | def _parse_scaling(scaling):
43 | if isinstance(scaling, int):
44 | scaling = [scaling, scaling]
45 | assert isinstance(scaling, (list, tuple))
46 | assert all(isinstance(x, int) for x in scaling)
47 | sx, sy = scaling
48 | assert sx >= 1 and sy >= 1
49 | return sx, sy
50 |
51 | def _parse_padding(padding):
52 | if isinstance(padding, int):
53 | padding = [padding, padding]
54 | assert isinstance(padding, (list, tuple))
55 | assert all(isinstance(x, int) for x in padding)
56 | if len(padding) == 2:
57 | padx, pady = padding
58 | padding = [padx, padx, pady, pady]
59 | padx0, padx1, pady0, pady1 = padding
60 | return padx0, padx1, pady0, pady1
61 |
62 | def _get_filter_size(f):
63 | if f is None:
64 | return 1, 1
65 | assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
66 | fw = f.shape[-1]
67 | fh = f.shape[0]
68 | with misc.suppress_tracer_warnings():
69 | fw = int(fw)
70 | fh = int(fh)
71 | misc.assert_shape(f, [fh, fw][:f.ndim])
72 | assert fw >= 1 and fh >= 1
73 | return fw, fh
74 |
75 | #----------------------------------------------------------------------------
76 |
77 | def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
78 | r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
79 |
80 | Args:
81 | f: Torch tensor, numpy array, or python list of the shape
82 | `[filter_height, filter_width]` (non-separable),
83 | `[filter_taps]` (separable),
84 | `[]` (impulse), or
85 | `None` (identity).
86 | device: Result device (default: cpu).
87 | normalize: Normalize the filter so that it retains the magnitude
88 | for constant input signal (DC)? (default: True).
89 | flip_filter: Flip the filter? (default: False).
90 | gain: Overall scaling factor for signal magnitude (default: 1).
91 | separable: Return a separable filter? (default: select automatically).
92 |
93 | Returns:
94 | Float32 tensor of the shape
95 | `[filter_height, filter_width]` (non-separable) or
96 | `[filter_taps]` (separable).
97 | """
98 | # Validate.
99 | if f is None:
100 | f = 1
101 | f = torch.as_tensor(f, dtype=torch.float32)
102 | assert f.ndim in [0, 1, 2]
103 | assert f.numel() > 0
104 | if f.ndim == 0:
105 | f = f[np.newaxis]
106 |
107 | # Separable?
108 | if separable is None:
109 | separable = (f.ndim == 1 and f.numel() >= 8)
110 | if f.ndim == 1 and not separable:
111 | f = f.ger(f)
112 | assert f.ndim == (1 if separable else 2)
113 |
114 | # Apply normalize, flip, gain, and device.
115 | if normalize:
116 | f /= f.sum()
117 | if flip_filter:
118 | f = f.flip(list(range(f.ndim)))
119 | f = f * (gain ** (f.ndim / 2))
120 | f = f.to(device=device)
121 | return f
122 |
123 | #----------------------------------------------------------------------------
124 |
125 | def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
126 | r"""Pad, upsample, filter, and downsample a batch of 2D images.
127 |
128 | Performs the following sequence of operations for each channel:
129 |
130 | 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
131 |
132 | 2. Pad the image with the specified number of zeros on each side (`padding`).
133 | Negative padding corresponds to cropping the image.
134 |
135 | 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
136 | so that the footprint of all output pixels lies within the input image.
137 |
138 | 4. Downsample the image by keeping every Nth pixel (`down`).
139 |
140 | This sequence of operations bears close resemblance to scipy.signal.upfirdn().
141 | The fused op is considerably more efficient than performing the same calculation
142 | using standard PyTorch ops. It supports gradients of arbitrary order.
143 |
144 | Args:
145 | x: Float32/float64/float16 input tensor of the shape
146 | `[batch_size, num_channels, in_height, in_width]`.
147 | f: Float32 FIR filter of the shape
148 | `[filter_height, filter_width]` (non-separable),
149 | `[filter_taps]` (separable), or
150 | `None` (identity).
151 | up: Integer upsampling factor. Can be a single int or a list/tuple
152 | `[x, y]` (default: 1).
153 | down: Integer downsampling factor. Can be a single int or a list/tuple
154 | `[x, y]` (default: 1).
155 | padding: Padding with respect to the upsampled image. Can be a single number
156 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
157 | (default: 0).
158 | flip_filter: False = convolution, True = correlation (default: False).
159 | gain: Overall scaling factor for signal magnitude (default: 1).
160 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
161 |
162 | Returns:
163 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
164 | """
165 | assert isinstance(x, torch.Tensor)
166 | assert impl in ['ref', 'cuda']
167 | if impl == 'cuda' and x.device.type == 'cuda' and _init():
168 | return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
169 | return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
170 |
171 | #----------------------------------------------------------------------------
172 |
173 | @misc.profiled_function
174 | def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
175 | """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
176 | """
177 | # Validate arguments.
178 | assert isinstance(x, torch.Tensor) and x.ndim == 4
179 | if f is None:
180 | f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
181 | assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
182 | assert f.dtype == torch.float32 and not f.requires_grad
183 | batch_size, num_channels, in_height, in_width = x.shape
184 | upx, upy = _parse_scaling(up)
185 | downx, downy = _parse_scaling(down)
186 | padx0, padx1, pady0, pady1 = _parse_padding(padding)
187 |
188 | # Upsample by inserting zeros.
189 | x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
190 | x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
191 | x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
192 |
193 | # Pad or crop.
194 | x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
195 | x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
196 |
197 | # Setup filter.
198 | f = f * (gain ** (f.ndim / 2))
199 | f = f.to(x.dtype)
200 | if not flip_filter:
201 | f = f.flip(list(range(f.ndim)))
202 |
203 | # Convolve with the filter.
204 | f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
205 | if f.ndim == 4:
206 | x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
207 | else:
208 | x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
209 | x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
210 |
211 | # Downsample by throwing away pixels.
212 | x = x[:, :, ::downy, ::downx]
213 | return x
214 |
215 | #----------------------------------------------------------------------------
216 |
217 | _upfirdn2d_cuda_cache = dict()
218 |
219 | def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
220 | """Fast CUDA implementation of `upfirdn2d()` using custom ops.
221 | """
222 | # Parse arguments.
223 | upx, upy = _parse_scaling(up)
224 | downx, downy = _parse_scaling(down)
225 | padx0, padx1, pady0, pady1 = _parse_padding(padding)
226 |
227 | # Lookup from cache.
228 | key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
229 | if key in _upfirdn2d_cuda_cache:
230 | return _upfirdn2d_cuda_cache[key]
231 |
232 | # Forward op.
233 | class Upfirdn2dCuda(torch.autograd.Function):
234 | @staticmethod
235 | def forward(ctx, x, f): # pylint: disable=arguments-differ
236 | assert isinstance(x, torch.Tensor) and x.ndim == 4
237 | if f is None:
238 | f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
239 | assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
240 | y = x
241 | if f.ndim == 2:
242 | y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
243 | else:
244 | y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
245 | y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
246 | ctx.save_for_backward(f)
247 | ctx.x_shape = x.shape
248 | return y
249 |
250 | @staticmethod
251 | def backward(ctx, dy): # pylint: disable=arguments-differ
252 | f, = ctx.saved_tensors
253 | _, _, ih, iw = ctx.x_shape
254 | _, _, oh, ow = dy.shape
255 | fw, fh = _get_filter_size(f)
256 | p = [
257 | fw - padx0 - 1,
258 | iw * upx - ow * downx + padx0 - upx + 1,
259 | fh - pady0 - 1,
260 | ih * upy - oh * downy + pady0 - upy + 1,
261 | ]
262 | dx = None
263 | df = None
264 |
265 | if ctx.needs_input_grad[0]:
266 | dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
267 |
268 | assert not ctx.needs_input_grad[1]
269 | return dx, df
270 |
271 | # Add to cache.
272 | _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
273 | return Upfirdn2dCuda
274 |
275 | #----------------------------------------------------------------------------
276 |
277 | def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
278 | r"""Filter a batch of 2D images using the given 2D FIR filter.
279 |
280 | By default, the result is padded so that its shape matches the input.
281 | User-specified padding is applied on top of that, with negative values
282 | indicating cropping. Pixels outside the image are assumed to be zero.
283 |
284 | Args:
285 | x: Float32/float64/float16 input tensor of the shape
286 | `[batch_size, num_channels, in_height, in_width]`.
287 | f: Float32 FIR filter of the shape
288 | `[filter_height, filter_width]` (non-separable),
289 | `[filter_taps]` (separable), or
290 | `None` (identity).
291 | padding: Padding with respect to the output. Can be a single number or a
292 | list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
293 | (default: 0).
294 | flip_filter: False = convolution, True = correlation (default: False).
295 | gain: Overall scaling factor for signal magnitude (default: 1).
296 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
297 |
298 | Returns:
299 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
300 | """
301 | padx0, padx1, pady0, pady1 = _parse_padding(padding)
302 | fw, fh = _get_filter_size(f)
303 | p = [
304 | padx0 + fw // 2,
305 | padx1 + (fw - 1) // 2,
306 | pady0 + fh // 2,
307 | pady1 + (fh - 1) // 2,
308 | ]
309 | return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
310 |
311 | #----------------------------------------------------------------------------
312 |
313 | def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
314 | r"""Upsample a batch of 2D images using the given 2D FIR filter.
315 |
316 | By default, the result is padded so that its shape is a multiple of the input.
317 | User-specified padding is applied on top of that, with negative values
318 | indicating cropping. Pixels outside the image are assumed to be zero.
319 |
320 | Args:
321 | x: Float32/float64/float16 input tensor of the shape
322 | `[batch_size, num_channels, in_height, in_width]`.
323 | f: Float32 FIR filter of the shape
324 | `[filter_height, filter_width]` (non-separable),
325 | `[filter_taps]` (separable), or
326 | `None` (identity).
327 | up: Integer upsampling factor. Can be a single int or a list/tuple
328 | `[x, y]` (default: 2).
329 | padding: Padding with respect to the output. Can be a single number or a
330 | list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
331 | (default: 0).
332 | flip_filter: False = convolution, True = correlation (default: False).
333 | gain: Overall scaling factor for signal magnitude (default: 1).
334 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
335 |
336 | Returns:
337 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
338 | """
339 | upx, upy = _parse_scaling(up)
340 | padx0, padx1, pady0, pady1 = _parse_padding(padding)
341 | fw, fh = _get_filter_size(f)
342 | p = [
343 | padx0 + (fw + upx - 1) // 2,
344 | padx1 + (fw - upx) // 2,
345 | pady0 + (fh + upy - 1) // 2,
346 | pady1 + (fh - upy) // 2,
347 | ]
348 | return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
349 |
350 | #----------------------------------------------------------------------------
351 |
352 | def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
353 | r"""Downsample a batch of 2D images using the given 2D FIR filter.
354 |
355 | By default, the result is padded so that its shape is a fraction of the input.
356 | User-specified padding is applied on top of that, with negative values
357 | indicating cropping. Pixels outside the image are assumed to be zero.
358 |
359 | Args:
360 | x: Float32/float64/float16 input tensor of the shape
361 | `[batch_size, num_channels, in_height, in_width]`.
362 | f: Float32 FIR filter of the shape
363 | `[filter_height, filter_width]` (non-separable),
364 | `[filter_taps]` (separable), or
365 | `None` (identity).
366 | down: Integer downsampling factor. Can be a single int or a list/tuple
367 | `[x, y]` (default: 2).
368 | padding: Padding with respect to the input. Can be a single number or a
369 | list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
370 | (default: 0).
371 | flip_filter: False = convolution, True = correlation (default: False).
372 | gain: Overall scaling factor for signal magnitude (default: 1).
373 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
374 |
375 | Returns:
376 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
377 | """
378 | downx, downy = _parse_scaling(down)
379 | padx0, padx1, pady0, pady1 = _parse_padding(padding)
380 | fw, fh = _get_filter_size(f)
381 | p = [
382 | padx0 + (fw - downx + 1) // 2,
383 | padx1 + (fw - downx) // 2,
384 | pady0 + (fh - downy + 1) // 2,
385 | pady1 + (fh - downy) // 2,
386 | ]
387 | return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
388 |
389 | #----------------------------------------------------------------------------
390 |
--------------------------------------------------------------------------------
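Round-trip sketch (illustrative) using the convenience wrappers above; on CPU `upfirdn2d()` silently takes the reference path:

```python
import torch
from torch_utils.ops import upfirdn2d

x = torch.randn(1, 3, 32, 32)
f = upfirdn2d.setup_filter([1, 3, 3, 1])   # normalized 4-tap binomial filter
hi = upfirdn2d.upsample2d(x, f)            # 2x up -> [1, 3, 64, 64]
lo = upfirdn2d.downsample2d(hi, f)         # 2x down -> [1, 3, 32, 32]
assert lo.shape == x.shape
```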
/torch_utils/persistence.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Facilities for pickling Python code alongside other data.
10 |
11 | The pickled code is automatically imported into a separate Python module
12 | during unpickling. This way, any previously exported pickles will remain
13 | usable even if the original code is no longer available, or if the current
14 | version of the code is not consistent with what was originally pickled."""
15 |
16 | import sys
17 | import pickle
18 | import io
19 | import inspect
20 | import copy
21 | import uuid
22 | import types
23 | import dnnlib
24 |
25 | #----------------------------------------------------------------------------
26 |
27 | _version = 6 # internal version number
28 | _decorators = set() # {decorator_class, ...}
29 | _import_hooks = [] # [hook_function, ...]
30 | _module_to_src_dict = dict() # {module: src, ...}
31 | _src_to_module_dict = dict() # {src: module, ...}
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def persistent_class(orig_class):
36 | r"""Class decorator that extends a given class to save its source code
37 | when pickled.
38 |
39 | Example:
40 |
41 | from torch_utils import persistence
42 |
43 | @persistence.persistent_class
44 | class MyNetwork(torch.nn.Module):
45 | def __init__(self, num_inputs, num_outputs):
46 | super().__init__()
47 | self.fc = MyLayer(num_inputs, num_outputs)
48 | ...
49 |
50 | @persistence.persistent_class
51 | class MyLayer(torch.nn.Module):
52 | ...
53 |
54 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its
55 | source code alongside other internal state (e.g., parameters, buffers,
56 | and submodules). This way, any previously exported pickle will remain
57 | usable even if the class definitions have been modified or are no
58 | longer available.
59 |
60 | The decorator saves the source code of the entire Python module
61 | containing the decorated class. It does *not* save the source code of
62 | any imported modules. Thus, the imported modules must be available
63 | during unpickling, also including `torch_utils.persistence` itself.
64 |
65 | It is ok to call functions defined in the same module from the
66 | decorated class. However, if the decorated class depends on other
67 | classes defined in the same module, they must be decorated as well.
68 | This is illustrated in the above example in the case of `MyLayer`.
69 |
70 | It is also possible to employ the decorator just-in-time before
71 | calling the constructor. For example:
72 |
73 | cls = MyLayer
74 | if want_to_make_it_persistent:
75 | cls = persistence.persistent_class(cls)
76 | layer = cls(num_inputs, num_outputs)
77 |
78 | As an additional feature, the decorator also keeps track of the
79 | arguments that were used to construct each instance of the decorated
80 | class. The arguments can be queried via `obj.init_args` and
81 | `obj.init_kwargs`, and they are automatically pickled alongside other
82 | object state. A typical use case is to first unpickle a previous
83 | instance of a persistent class, and then upgrade it to use the latest
84 | version of the source code:
85 |
86 | with open('old_pickle.pkl', 'rb') as f:
87 | old_net = pickle.load(f)
88 | new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
89 | misc.copy_params_and_buffers(old_net, new_net, require_all=True)
90 | """
91 | assert isinstance(orig_class, type)
92 | if is_persistent(orig_class):
93 | return orig_class
94 |
95 | assert orig_class.__module__ in sys.modules
96 | orig_module = sys.modules[orig_class.__module__]
97 | orig_module_src = _module_to_src(orig_module)
98 |
99 | class Decorator(orig_class):
100 | _orig_module_src = orig_module_src
101 | _orig_class_name = orig_class.__name__
102 |
103 | def __init__(self, *args, **kwargs):
104 | super().__init__(*args, **kwargs)
105 | self._init_args = copy.deepcopy(args)
106 | self._init_kwargs = copy.deepcopy(kwargs)
107 | assert orig_class.__name__ in orig_module.__dict__
108 | _check_pickleable(self.__reduce__())
109 |
110 | @property
111 | def init_args(self):
112 | return copy.deepcopy(self._init_args)
113 |
114 | @property
115 | def init_kwargs(self):
116 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
117 |
118 | def __reduce__(self):
119 | fields = list(super().__reduce__())
120 | fields += [None] * max(3 - len(fields), 0)
121 | if fields[0] is not _reconstruct_persistent_obj:
122 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
123 | fields[0] = _reconstruct_persistent_obj # reconstruct func
124 | fields[1] = (meta,) # reconstruct args
125 | fields[2] = None # state dict
126 | return tuple(fields)
127 |
128 | Decorator.__name__ = orig_class.__name__
129 | _decorators.add(Decorator)
130 | return Decorator
131 |
132 | #----------------------------------------------------------------------------
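# Minimal round-trip sketch for `persistent_class`. Assumed to live in its own
# module file (so `inspect.getsource()` can recover the source); `Scale` is an
# illustrative class, not part of the project. The `__main__` guard matters:
# without it, unpickling would re-execute the example when the embedded source
# is re-imported.

import io, pickle, torch
from torch_utils import persistence

@persistence.persistent_class
class Scale(torch.nn.Module):
    def __init__(self, gain):
        super().__init__()
        self.gain = gain

if __name__ == '__main__':
    buf = io.BytesIO()
    pickle.dump(Scale(2.0), buf)   # embeds the defining module's source
    buf.seek(0)
    restored = pickle.load(buf)    # usable even if `Scale` is later deleted
    assert restored.init_args == (2.0,)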
133 |
134 | def is_persistent(obj):
135 | r"""Test whether the given object or class is persistent, i.e.,
136 | whether it will save its source code when pickled.
137 | """
138 | try:
139 | if obj in _decorators:
140 | return True
141 | except TypeError:
142 | pass
143 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
144 |
145 | #----------------------------------------------------------------------------
146 |
147 | def import_hook(hook):
148 | r"""Register an import hook that is called whenever a persistent object
149 | is being unpickled. A typical use case is to patch the pickled source
150 | code to avoid errors and inconsistencies when the API of some imported
151 | module has changed.
152 |
153 | The hook should have the following signature:
154 |
155 | hook(meta) -> modified meta
156 |
157 | `meta` is an instance of `dnnlib.EasyDict` with the following fields:
158 |
159 | type: Type of the persistent object, e.g. `'class'`.
160 | version: Internal version number of `torch_utils.persistence`.
161 | module_src: Original source code of the Python module.
162 | class_name: Class name in the original Python module.
163 | state: Internal state of the object.
164 |
165 | Example:
166 |
167 | @persistence.import_hook
168 | def wreck_my_network(meta):
169 | if meta.class_name == 'MyNetwork':
170 | print('MyNetwork is being imported. I will wreck it!')
171 | meta.module_src = meta.module_src.replace("True", "False")
172 | return meta
173 | """
174 | assert callable(hook)
175 | _import_hooks.append(hook)
176 |
177 | #----------------------------------------------------------------------------
178 |
179 | def _reconstruct_persistent_obj(meta):
180 | r"""Hook that is called internally by the `pickle` module to unpickle
181 | a persistent object.
182 | """
183 | meta = dnnlib.EasyDict(meta)
184 | meta.state = dnnlib.EasyDict(meta.state)
185 | for hook in _import_hooks:
186 | meta = hook(meta)
187 | assert meta is not None
188 |
189 | assert meta.version == _version
190 | module = _src_to_module(meta.module_src)
191 |
192 | assert meta.type == 'class'
193 | orig_class = module.__dict__[meta.class_name]
194 | decorator_class = persistent_class(orig_class)
195 | obj = decorator_class.__new__(decorator_class)
196 |
197 | setstate = getattr(obj, '__setstate__', None)
198 | if callable(setstate):
199 | setstate(meta.state) # pylint: disable=not-callable
200 | else:
201 | obj.__dict__.update(meta.state)
202 | return obj
203 |
204 | #----------------------------------------------------------------------------
205 |
206 | def _module_to_src(module):
207 | r"""Query the source code of a given Python module.
208 | """
209 | src = _module_to_src_dict.get(module, None)
210 | if src is None:
211 | src = inspect.getsource(module)
212 | _module_to_src_dict[module] = src
213 | _src_to_module_dict[src] = module
214 | return src
215 |
216 | def _src_to_module(src):
217 | r"""Get or create a Python module for the given source code.
218 | """
219 | module = _src_to_module_dict.get(src, None)
220 | if module is None:
221 | module_name = "_imported_module_" + uuid.uuid4().hex
222 | module = types.ModuleType(module_name)
223 | sys.modules[module_name] = module
224 | _module_to_src_dict[module] = src
225 | _src_to_module_dict[src] = module
226 | exec(src, module.__dict__) # pylint: disable=exec-used
227 | return module
228 |
229 | #----------------------------------------------------------------------------
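# Quick illustration of the source <-> module round trip implemented above
# (the source string is illustrative):

def _example_src_roundtrip():
    src = "VALUE = 41\ndef bump():\n    return VALUE + 1\n"
    mod = _src_to_module(src)            # exec's the source into a fresh module
    assert mod.bump() == 42
    assert _module_to_src(mod) == src    # the mapping is cached both ways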
230 |
231 | def _check_pickleable(obj):
232 | r"""Check that the given object is pickleable, raising an exception if
233 | it is not. This function is expected to be considerably more efficient
234 | than actually pickling the object.
235 | """
236 | def recurse(obj):
237 | if isinstance(obj, (list, tuple, set)):
238 | return [recurse(x) for x in obj]
239 | if isinstance(obj, dict):
240 | return [[recurse(x), recurse(y)] for x, y in obj.items()]
241 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
242 | return None # Python primitive types are pickleable.
243 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
244 | return None # NumPy arrays and PyTorch tensors are pickleable.
245 | if is_persistent(obj):
246 | return None # Persistent objects are pickleable, by virtue of the constructor check.
247 | return obj
248 | with io.BytesIO() as f:
249 | pickle.dump(recurse(obj), f)
250 |
251 | #----------------------------------------------------------------------------
252 |
--------------------------------------------------------------------------------
/torch_utils/training_stats.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Facilities for reporting and collecting training statistics across
10 | multiple processes and devices. The interface is designed to minimize
11 | synchronization overhead as well as the amount of boilerplate in user
12 | code."""
13 |
14 | import re
15 | import numpy as np
16 | import torch
17 | import dnnlib
18 |
19 | from . import misc
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
24 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
25 | _counter_dtype = torch.float64 # Data type to use for the internal counters.
26 | _rank = 0 # Rank of the current process.
27 | _sync_device = None # Device to use for multiprocess communication. None = single-process.
28 | _sync_called = False # Has _sync() been called yet?
29 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
30 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
31 |
32 | #----------------------------------------------------------------------------
33 |
34 | def init_multiprocessing(rank, sync_device):
35 | r"""Initializes `torch_utils.training_stats` for collecting statistics
36 | across multiple processes.
37 |
38 | This function must be called after
39 | `torch.distributed.init_process_group()` and before `Collector.update()`.
40 | The call is not necessary if multi-process collection is not needed.
41 |
42 | Args:
43 | rank: Rank of the current process.
44 | sync_device: PyTorch device to use for inter-process
45 | communication, or None to disable multi-process
46 | collection. Typically `torch.device('cuda', rank)`.
47 | """
48 | global _rank, _sync_device
49 | assert not _sync_called
50 | _rank = rank
51 | _sync_device = sync_device
52 |
53 | #----------------------------------------------------------------------------
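# Hedged sketch of the multi-process setup (assumes the job was launched with
# a torch.distributed launcher such as torchrun, which provides the rendezvous
# environment variables):

def _example_init_distributed():
    import torch
    torch.distributed.init_process_group(backend='nccl')
    rank = torch.distributed.get_rank()
    init_multiprocessing(rank=rank, sync_device=torch.device('cuda', rank))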
54 |
55 | @misc.profiled_function
56 | def report(name, value):
57 | r"""Broadcasts the given set of scalars to all interested instances of
58 | `Collector`, across device and process boundaries.
59 |
60 | This function is expected to be extremely cheap and can be safely
61 | called from anywhere in the training loop, loss function, or inside a
62 | `torch.nn.Module`.
63 |
64 | Warning: The current implementation expects the set of unique names to
65 | be consistent across processes. Please make sure that `report()` is
66 | called at least once for each unique name by each process, and in the
67 | same order. If a given process has no scalars to broadcast, it can do
68 | `report(name, [])` (empty list).
69 |
70 | Args:
71 | name: Arbitrary string specifying the name of the statistic.
72 | Averages are accumulated separately for each unique name.
73 | value: Arbitrary set of scalars. Can be a list, tuple,
74 | NumPy array, PyTorch tensor, or Python scalar.
75 |
76 | Returns:
77 | The same `value` that was passed in.
78 | """
79 | if name not in _counters:
80 | _counters[name] = dict()
81 |
82 | elems = torch.as_tensor(value)
83 | if elems.numel() == 0:
84 | return value
85 |
86 | elems = elems.detach().flatten().to(_reduce_dtype)
87 | moments = torch.stack([
88 | torch.ones_like(elems).sum(),
89 | elems.sum(),
90 | elems.square().sum(),
91 | ])
92 | assert moments.ndim == 1 and moments.shape[0] == _num_moments
93 | moments = moments.to(_counter_dtype)
94 |
95 | device = moments.device
96 | if device not in _counters[name]:
97 | _counters[name][device] = torch.zeros_like(moments)
98 | _counters[name][device].add_(moments)
99 | return value
100 |
101 | #----------------------------------------------------------------------------
102 |
103 | def report0(name, value):
104 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
105 | but ignores any scalars provided by the other processes.
106 | See `report()` for further details.
107 | """
108 | report(name, value if _rank == 0 else [])
109 | return value
110 |
111 | #----------------------------------------------------------------------------
112 |
113 | class Collector:
114 | r"""Collects the scalars broadcasted by `report()` and `report0()` and
115 | computes their long-term averages (mean and standard deviation) over
116 | user-defined periods of time.
117 |
118 | The averages are first collected into internal counters that are not
119 | directly visible to the user. They are then copied to the user-visible
120 | state as a result of calling `update()` and can then be queried using
121 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
122 | internal counters for the next round, so that the user-visible state
123 | effectively reflects averages collected between the last two calls to
124 | `update()`.
125 |
126 | Args:
127 | regex: Regular expression defining which statistics to
128 | collect. The default is to collect everything.
129 | keep_previous: Whether to retain the previous averages if no
130 | scalars were collected on a given round
131 | (default: True).
132 | """
133 | def __init__(self, regex='.*', keep_previous=True):
134 | self._regex = re.compile(regex)
135 | self._keep_previous = keep_previous
136 | self._cumulative = dict()
137 | self._moments = dict()
138 | self.update()
139 | self._moments.clear()
140 |
141 | def names(self):
142 | r"""Returns the names of all statistics broadcasted so far that
143 | match the regular expression specified at construction time.
144 | """
145 | return [name for name in _counters if self._regex.fullmatch(name)]
146 |
147 | def update(self):
148 | r"""Copies current values of the internal counters to the
149 | user-visible state and resets them for the next round.
150 |
151 | If `keep_previous=True` was specified at construction time, the
152 | operation is skipped for statistics that have received no scalars
153 | since the last update, retaining their previous averages.
154 |
155 | This method performs a number of GPU-to-CPU transfers and one
156 | `torch.distributed.all_reduce()`. It is intended to be called
157 | periodically in the main training loop, typically once every
158 | N training steps.
159 | """
160 | if not self._keep_previous:
161 | self._moments.clear()
162 | for name, cumulative in _sync(self.names()):
163 | if name not in self._cumulative:
164 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
165 | delta = cumulative - self._cumulative[name]
166 | self._cumulative[name].copy_(cumulative)
167 | if float(delta[0]) != 0:
168 | self._moments[name] = delta
169 |
170 | def _get_delta(self, name):
171 | r"""Returns the raw moments that were accumulated for the given
172 | statistic between the last two calls to `update()`, or zero if
173 | no scalars were collected.
174 | """
175 | assert self._regex.fullmatch(name)
176 | if name not in self._moments:
177 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
178 | return self._moments[name]
179 |
180 | def num(self, name):
181 | r"""Returns the number of scalars that were accumulated for the given
182 | statistic between the last two calls to `update()`, or zero if
183 | no scalars were collected.
184 | """
185 | delta = self._get_delta(name)
186 | return int(delta[0])
187 |
188 | def mean(self, name):
189 | r"""Returns the mean of the scalars that were accumulated for the
190 | given statistic between the last two calls to `update()`, or NaN if
191 | no scalars were collected.
192 | """
193 | delta = self._get_delta(name)
194 | if int(delta[0]) == 0:
195 | return float('nan')
196 | return float(delta[1] / delta[0])
197 |
198 | def std(self, name):
199 | r"""Returns the standard deviation of the scalars that were
200 | accumulated for the given statistic between the last two calls to
201 | `update()`, or NaN if no scalars were collected.
202 | """
203 | delta = self._get_delta(name)
204 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
205 | return float('nan')
206 | if int(delta[0]) == 1:
207 | return float(0)
208 | mean = float(delta[1] / delta[0])
209 | raw_var = float(delta[2] / delta[0])
210 | return np.sqrt(max(raw_var - np.square(mean), 0))
211 |
212 | def as_dict(self):
213 | r"""Returns the averages accumulated between the last two calls to
214 | `update()` as a `dnnlib.EasyDict`. The contents are as follows:
215 |
216 | dnnlib.EasyDict(
217 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
218 | ...
219 | )
220 | """
221 | stats = dnnlib.EasyDict()
222 | for name in self.names():
223 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
224 | return stats
225 |
226 | def __getitem__(self, name):
227 | r"""Convenience getter.
228 | `collector[name]` is a synonym for `collector.mean(name)`.
229 | """
230 | return self.mean(name)
231 |
232 | #----------------------------------------------------------------------------
233 |
234 | def _sync(names):
235 | r"""Synchronize the global cumulative counters across devices and
236 | processes. Called internally by `Collector.update()`.
237 | """
238 | if len(names) == 0:
239 | return []
240 | global _sync_called
241 | _sync_called = True
242 |
243 | # Collect deltas within current rank.
244 | deltas = []
245 | device = _sync_device if _sync_device is not None else torch.device('cpu')
246 | for name in names:
247 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
248 | for counter in _counters[name].values():
249 | delta.add_(counter.to(device))
250 | counter.copy_(torch.zeros_like(counter))
251 | deltas.append(delta)
252 | deltas = torch.stack(deltas)
253 |
254 | # Sum deltas across ranks.
255 | if _sync_device is not None:
256 | torch.distributed.all_reduce(deltas)
257 |
258 | # Update cumulative values.
259 | deltas = deltas.cpu()
260 | for idx, name in enumerate(names):
261 | if name not in _cumulative:
262 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
263 | _cumulative[name].add_(deltas[idx])
264 |
265 | # Return name-value pairs.
266 | return [(name, _cumulative[name]) for name in names]
267 |
268 | #----------------------------------------------------------------------------
269 |
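# End-to-end sketch: report scalars during training, then poll their running
# averages (single process; the statistic name and values are illustrative):

def _example_collect():
    import torch
    collector = Collector(regex='Loss/.*')
    for _ in range(100):
        report('Loss/total', torch.rand([]))    # cheap; callable from anywhere
    collector.update()                          # one sync per polling interval
    return collector.mean('Loss/total'), collector.std('Loss/total')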
--------------------------------------------------------------------------------
/w_directions.py:
--------------------------------------------------------------------------------
1 | #
2 | # Author: David Futschik
3 | # Provided as part of the Chunkmogrify project, 2021.
4 | #
5 |
6 | import torch
7 |
8 | # Each entry below is (direction, start, end): the direction is applied to W layers
9 | # from `start` up to `end`. Ending at 8 is a good default; a lower end means only
10 | # low-level features are affected. Directions on disk are numpy arrays of shape (18, 512).
11 |
12 |
13 | # "Synthetic" because these were found just by mucking around in Z space.
14 | synthetic_glasses = torch.zeros((1, 512, ))
15 | synthetic_glasses[:, 30] = -1
16 | synthetic_glasses[:, 36] = -1
17 |
18 | # Empirical glasses: I manually labelled about 200 images and took the difference
19 | # in means as the direction.
20 |
21 | # Beard was found through gradient descent modification and subsequent PCA of differences.
22 |
23 | def known_directions():
24 | return {
25 | 'smile': ('directions/smile_styleganv2.npy', 0, 5),
26 | 'age': ('directions/age.npy', 0, 8),
27 | 'gender': ('directions/gender.npy', 0, 5),
28 | 'eyes_open': ('directions/eyes_open.npy', 0, 12),
29 | 'nose_ratio': ('directions/nose_ratio.npy', 0, 8),
30 | 'nose_tip': ('directions/nose_tip.npy', 0, 12),
31 | 'lip_ratio': ('directions/lip_ratio.npy', 0, 12),
32 | 'eye_distance': ('directions/eye_distance.npy', 0, 8),
33 | 'eye_to_eyebrow_distance': ('directions/eye_eyebrow_distance.npy', 0, 5),
34 | 'eye_ratio': ('directions/eye_ratio.npy', 0, 8),
35 | 'mouth_open': ('directions/mouth_open.npy', 0, 8),
36 | 'pitch': ('directions/pitch.npy', 0, 5),
37 | # 'roll': ('directions/roll.npy', 5),
38 | 'synthetic_glasses': (synthetic_glasses, 2, 4),
39 | 'empirical_glasses': ('directions/empirical_glasses.npy', 2, 4),
40 | 'beard': ('directions/beard.npy', 6, 9),
41 | 'yaw': ('directions/yaw.npy', 0, 5),
42 | 'light': ('directions/light.npy', 5, 11),
43 | 'eye_color': ('directions/light.npy', 14, 15),
44 | # 'eig_all_0': (torch.load('components/all.pt')[:, 0].view(1, 512), 0, 18),
45 | # 'eig_all_1': (torch.load('components/all.pt')[:, 1].view(1, 512), 0, 18),
46 | # 'eig_all_2': (torch.load('components/all.pt')[:, 2].view(1, 512), 0, 18),
47 | # 'eig_all_3': (torch.load('components/all.pt')[:, 3].view(1, 512), 0, 18),
48 | # 'eig_all_4': (torch.load('components/all.pt')[:, 4].view(1, 512), 0, 18),
49 | # 'eig_all_5': (torch.load('components/all.pt')[:, 5].view(1, 512), 0, 18),
50 | # 'eig_all_6': (torch.load('components/all.pt')[:, 6].view(1, 512), 0, 18),
51 | # 'eig_all_7': (torch.load('components/all.pt')[:, 7].view(1, 512), 0, 18),
52 | # 'eig_all_8': (torch.load('components/all.pt')[:, 8].view(1, 512), 0, 18),
53 | # 'eig_all_9': (torch.load('components/all.pt')[:, 9].view(1, 512), 0, 18),
54 | # 'eig_all_10': (torch.load('components/all.pt')[:, 10].view(1, 512), 0, 18),
55 |
56 | # 'eig_64_0': (torch.load('components/b64_conv0.pt')[:, 0].view(1, 512), 0, 18),
57 | # 'eig_64_1': (torch.load('components/b64_conv0.pt')[:, 1].view(1, 512), 0, 18),
58 | # 'eig_64_2': (torch.load('components/b64_conv0.pt')[:, 2].view(1, 512), 0, 18),
59 | # 'eig_64_3': (torch.load('components/b64_conv0.pt')[:, 3].view(1, 512), 0, 18),
60 | # 'eig_64_4': (torch.load('components/b64_conv0.pt')[:, 4].view(1, 512), 0, 18),
61 | # 'eig_64_5': (torch.load('components/b64_conv0.pt')[:, 5].view(1, 512), 0, 18),
62 | # 'eig_64_6': (torch.load('components/b64_conv0.pt')[:, 6].view(1, 512), 0, 18),
63 | # 'eig_64_7': (torch.load('components/b64_conv0.pt')[:, 7].view(1, 512), 0, 18),
64 | # 'eig_64_8': (torch.load('components/b64_conv0.pt')[:, 8].view(1, 512), 0, 18),
65 | # 'eig_64_9': (torch.load('components/b64_conv0.pt')[:, 9].view(1, 512), 0, 18),
66 | # 'eig_64_10': (torch.load('components/b64_conv0.pt')[:, 10].view(1, 512), 0, 18),
67 |
68 | # 'eig_16_0': (torch.load('components/b16_conv0.pt')[:, 0].view(1, 512), 0, 18), # Hair ?
69 | # 'eig_16_1': (torch.load('components/b16_conv0.pt')[:, 1].view(1, 512), 0, 18), # Gender ?
70 | # 'eig_16_2': (torch.load('components/b16_conv0.pt')[:, 2].view(1, 512), 0, 18),
71 | # 'eig_16_3': (torch.load('components/b16_conv0.pt')[:, 3].view(1, 512), 0, 18), # Roll
72 | # 'eig_16_4': (torch.load('components/b16_conv0.pt')[:, 4].view(1, 512), 0, 11), # Age
73 | # 'eig_16_5': (torch.load('components/b16_conv0.pt')[:, 5].view(1, 512), 0, 18),
74 | # 'eig_16_6': (torch.load('components/b16_conv0.pt')[:, 6].view(1, 512), 0, 18),
75 | # 'eig_16_7': (torch.load('components/b16_conv0.pt')[:, 7].view(1, 512), 0, 18),
76 | # 'eig_16_8': (torch.load('components/b16_conv0.pt')[:, 8].view(1, 512), 0, 18),
77 | # 'eig_16_9': (torch.load('components/b16_conv0.pt')[:, 9].view(1, 512), 0, 18), # Weird age
78 | # 'eig_16_10': (torch.load('components/b16_conv0.pt')[:, 10].view(1, 512), 0, 18), # Pitch
79 | }
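# Hedged sketch of applying a named direction to a W+ code of shape (18, 512);
# `apply_direction`, `strength`, and the broadcasting of (1, 512) rows are
# illustrative choices, not project API:

def apply_direction(w_plus, name, strength=2.0):
    import numpy as np
    src, lo, hi = known_directions()[name]
    d = torch.from_numpy(np.load(src)).float() if isinstance(src, str) else src.float()
    if d.shape[0] == 1:
        d = d.expand(w_plus.shape[0], -1)   # reuse a single row across all layers
    w_edit = w_plus.clone()
    w_edit[lo:hi] += strength * d[lo:hi]    # shift only the selected W layers
    return w_edit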
--------------------------------------------------------------------------------
/widgets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/widgets/__init__.py
--------------------------------------------------------------------------------
/widgets/mask_painter.py:
--------------------------------------------------------------------------------
1 | #
2 | # Author: David Futschik
3 | # Provided as part of the Chunkmogrify project, 2021.
4 | #
5 |
6 | import numpy as np
7 |
8 | from qtutil import *
9 |
10 | from PyQt5.QtWidgets import *
11 | from PyQt5.QtGui import *
12 | from PyQt5.QtCore import *
13 |
14 | from time import perf_counter as time
15 | from _C_canvas import canvas_to_masks
16 |
17 |
18 | # This is inefficient, but it was the quickest solution to the problem.
19 | # Ideally we would just paint along a path.
20 | # Alternative: use QCursor and scaled Pixmap, but it has its own set of problems.
21 | class EllipseBrush:
22 | def __init__(self, height, width, visual_color=QColor.fromRgb(25, 127, 255)):
23 | self.image = QImage(width, height, QImage.Format_RGBA8888)
24 | self.image.fill(Qt.transparent)
25 | self.visual_color = visual_color
26 | self.pen = QPen(self.visual_color)
27 | self.pen.setWidth(3)
28 | self.pen.setCapStyle(Qt.RoundCap)
29 |
30 | self.clear_pen = QPen(Qt.transparent)
31 | self.clear_pen.setWidth(6)
32 | self.clear_pen.setCapStyle(Qt.RoundCap)
33 | self.last_visual_arguments = None
34 |
35 | def clear_last_visual(self, into_painter=None):
36 | # very blunt method, could just draw a big transparent ellipse
37 | # self.image.fill(Qt.transparent)
38 | if self.last_visual_arguments is not None:
39 | if into_painter is None:
40 | painter = QPainter(self.image)
41 | else:
42 | painter = into_painter
43 | painter.setCompositionMode(QPainter.CompositionMode_Clear)
44 | painter.setRenderHints(QPainter.Antialiasing)  # AA for the brush layer
45 | painter.setPen(self.clear_pen)
46 |
47 | args = list(self.last_visual_arguments)  # copy, so repeated clears don't keep growing the ellipse
48 | args[2] += 3; args[3] += 3
49 | painter.setBrush(Qt.transparent)
50 | painter.drawEllipse(*args)
51 | if into_painter is None:
52 | painter.end()
53 |
54 | def update_visual(self, mx, my, size):
55 | painter = QPainter(self.image)
56 | self.clear_last_visual(painter)
57 | painter.setCompositionMode(QPainter.CompositionMode_Source)
58 |
59 | painter.setRenderHints(QPainter.Antialiasing)  # AA for the brush layer
60 | painter.setPen(self.pen)
61 | half_radius = float(size) / 2
62 | self.last_visual_arguments = [mx * self.image.width() - half_radius, my * self.image.height() - half_radius, size, size]
63 | painter.drawEllipse(*self.last_visual_arguments)
64 | painter.end()
65 |
66 | def draw(self, qpainter, x, y, color, size):
67 | half_radius = float(size) / 2
68 | qpainter.setBrush(color)
69 | qpainter.setPen(color)
70 | qpainter.drawEllipse(x - half_radius, y - half_radius, size, size)
71 |
72 | # Be careful with the choice of colors: Qt has a bug where the alpha is NOT entirely
73 | # ignored even with the correct CompositionMode.
74 | COLORS = [
75 | (0, 0, 0, 0),
76 | (128, 128, 128, 90), # 1
77 | (128, 255, 128, 90), # 2
78 | (255, 128, 128, 90), # 3
79 | (128, 128, 255, 90), # 4
80 | (150, 51 , 201, 90), # 5
81 | (201, 150, 51 , 90), # 6
82 | (60 , 119, 150, 90), # 7
83 | (77 , 255, 238, 90), # 8
84 | (255, 244, 20 , 90), # 9
85 | (6 , 68 , 17 , 90), # 10
86 | (196, 11 , 0 , 90), # 11
87 | (99 , 48 , 11 , 90), # 12
88 | ]
89 |
90 | # Don't store the buf!
91 | mask_buf = None
92 | def canvas_to_numpy_mask(qt_image, colors):
93 | global mask_buf
94 | asnp = qim2np(qt_image, swapaxes=False)
95 |
96 | # This is the numpy way:
97 | # color_cmps = np_colors[None, None, :, :]
98 | # # (1024, 1024, 4, 1) == (1, 1, 4, 8)
99 | # cmps = np.equal(asnp[:, :, :, None], color_cmps, out=cmp_buf)
100 | # logical_masks = cmps.all(axis=2)[:, :, 1:]
101 | # masks = logical_masks
102 |
103 | # C++ optimized way:
104 |
105 | if mask_buf is None:
106 | masks = canvas_to_masks(asnp, colors)
107 | mask_buf = masks
108 | else:
109 | masks = canvas_to_masks(asnp, colors, output_buffer=mask_buf)
110 | return masks[:, :, 1:]
111 |
112 | # Don't store the buffers anywhere!
113 | canvas_buf = None
114 | def numpy_mask_to_numpy_canvas(numpy_mask):
115 | global canvas_buf
116 | if canvas_buf is None: canvas_buf = np.zeros((numpy_mask.shape[0], numpy_mask.shape[1], 4), dtype=np.uint8)
117 | # clearing canvas_buf here is unnecessary: np.choose below overwrites every pixel
118 | # Use COLORS once there are multiple masks
119 |
120 | # The mask needs to have the index of the resulting color & be flattened
121 | numpy_mask = numpy_mask * np.arange(1, numpy_mask.shape[2] + 1)[None, None, :]
122 | numpy_mask = numpy_mask.sum(axis=2)
123 |
124 | np.choose(
125 | numpy_mask.astype(np.uint8)[:, :, None],
126 | COLORS,
127 | mode='clip',
128 | out=canvas_buf
129 | )
130 | return canvas_buf
131 |
132 | class PainterWidget:
133 | def __init__(self, height, width, max_segments, reporting_fn):
134 | self.show_brush = False
135 | self.brush_size = 15
136 | self.brush = EllipseBrush(height, width)
137 | self.canvas = QImage(width, height, QImage.Format_RGBA8888)
138 | self.canvas.fill(Qt.transparent)
139 | self.update_callbacks = []
140 | self.color_index = 1
141 | self.last_known_paint_pos = None
142 | self.reporting_fn = reporting_fn
143 | self.max_segments = max_segments
144 | self.np_colors = np.array(COLORS).astype(np.uint8)[:self.max_segments+1]
145 | if self.max_segments > len(COLORS) - 2:
146 | print(f'Requested {self.max_segments} masks, but only {len(COLORS) - 2} colors are defined. Setting max segments to {len(COLORS) - 2}')
147 | self.max_segments = len(COLORS) - 2
148 |
149 | self._hide = False
150 | self._enabled = True
151 | self._active_buttons = {}
152 |
153 | def get_volatile_masks(self):
154 | asnp = qim2np(self.canvas, swapaxes=False)
155 | masks = canvas_to_masks(asnp, self.np_colors)
156 | return asnp, masks
157 |
158 | def draw(self, qpainter, qevent):
159 | if self._hide: return
160 | # this actually only draws the overlay with brush
161 | qpainter.setOpacity(1)
162 | # This overlays current context with the actual painting content
163 | qpainter.drawImage(qevent.rect(), self.canvas)
164 | # This overlays the current context with the brush visual
165 | qpainter.drawImage(qevent.rect(), self.brush.image)
166 |
167 | def change_brush_size(self, amount):
168 | self.brush_size += amount
169 | self.brush_size = 1 if self.brush_size < 1 else self.brush_size
170 | self.update_visuals()
171 |
172 | def change_active_color(self, dst_idx):
173 | self.color_index = dst_idx
174 | # Active color is between 1 and N
175 | if self.color_index < 1: self.color_index = 1
176 | if self.color_index >= min(self.max_segments + 1, len(COLORS)): self.color_index = min(self.max_segments + 1, len(COLORS)) - 1
177 | self.reporting_fn(f"Current color index: {self.color_index} - {{{COLORS[self.color_index]}}}")
178 |
179 | def active_color(self):
180 | return QColor.fromRgb(*COLORS[self.color_index])
181 |
182 | def rewrite_with_numpy_mask(self, target, writeback=False, actor=None):
183 | # The change was initiated by the GUI; this is just a notification of that change.
184 | if actor == 'gui': return
185 |
186 | old_canvas = qim2np(self.canvas)
187 | old_canvas[:] = numpy_mask_to_numpy_canvas(target)[:]
188 | self.canvas = np2qim(old_canvas, do_copy=False)
189 | if writeback:
190 | for c in self.update_callbacks:
191 | c(target, actor='gui')
192 |
193 | # dy, dx in [0, 1] (normalized widget coordinates)
194 | # called on each mouse event, even when button is not down
195 | def paint_at(self, button, dy, dx):
196 | self.last_known_paint_pos = (dx, dy)
197 | self.update_visuals()
198 |
199 | csx, csy = dx * self.canvas.width(), dy * self.canvas.height()
200 | if button == MouseButton.LEFT:
201 | self._paint_at_pos(csx, csy, self.active_color())
202 | if button == MouseButton.RIGHT:
203 | self._paint_at_pos(csx, csy, QColor.fromRgb(*COLORS[0]))
204 | if button == MouseButton.MIDDLE:
205 | self._pick_color(csx, csy)
206 |
207 | def update_mouse_position(self, dy, dx):
208 | self.last_known_paint_pos = (dx, dy)
209 | if self.enabled():
210 | self.update_visuals()
211 |
212 | def update_visuals(self):
213 | if self.last_known_paint_pos is None: return
214 | self.brush.update_visual(self.last_known_paint_pos[0], self.last_known_paint_pos[1], self.brush_size)
215 |
216 | def clear_visuals(self):
217 | self.brush.clear_last_visual()
218 |
219 | def _paint_at_pos(self, x, y, color):
220 | qpainter = QPainter(self.canvas)
221 | qpainter.setCompositionMode(QPainter.CompositionMode_Source) # composition mode: Source (to prevent alpha multiplication)
222 | self.brush.draw(qpainter, x, y, color, self.brush_size)
223 | qpainter.end()
224 | # duplicate changes into the writeback buffer
225 | # this makes a copy after each paint call !
226 | # Notify all interested objects.
227 | numpy_mask = canvas_to_numpy_mask(self.canvas, self.np_colors)
228 | for c in self.update_callbacks:
229 | c(numpy_mask, actor='gui')
230 |
231 | def _pick_color(self, x, y):
232 | c = self.canvas.pixelColor(x, y)
233 | idx = None
234 | for i, ref_c in enumerate(COLORS):
235 | if c.getRgb() == ref_c:
236 | idx = i
237 | if idx is None: raise RuntimeError(f"Color index not found (unexpected color value {c.getRgb()})")
238 | self.change_active_color(idx)
239 |
240 | # Enabled will be queried for event forwarding.
241 | def enabled(self):
242 | return self._enabled
243 |
244 | def toggle(self, to):
245 | self._enabled = to
246 | if (not to):
247 | # If enabled = False, draw() will no longer be called and the visual might stick around.
248 | self.clear_visuals()
249 | else:
250 | # While disabled, the visual was neither drawn nor updated, so refresh it on re-enable.
251 | self.update_visuals()
252 |
253 | def hide(self):
254 | self._hide = True
255 |
256 | def show(self):
257 | self._hide = False
258 |
259 | def mouse_enter(self):
260 | pass
261 |
262 | def mouse_leave(self):
263 | self.clear_visuals()
264 |
265 | def mouse_zoom(self, angle):
266 | modifiers = QApplication.keyboardModifiers()
267 | if modifiers == Qt.ShiftModifier:
268 | self.change_active_color(self.color_index + 1 if angle > 0 else self.color_index - 1)
269 | else:
270 | self.change_brush_size(angle)
271 |
272 | def mouse_down(self, y, x, btn):
273 | self._active_buttons[btn] = True
274 | self.paint_at(btn, y, x)
275 |
276 | def mouse_up(self, y, x, btn):
277 | self._active_buttons[btn] = None
278 |
279 | def mouse_moved(self, y, x):
280 | if self.enabled():
281 | for (btn, active) in self._active_buttons.items():
282 | if active: self.paint_at(btn, y, x)
283 | self.update_mouse_position(y, x)
284 |
285 | def key_down(self, key):
286 | if key == '+':
287 | self.change_active_color(self.color_index + 1)
288 | if key == '-':
289 | self.change_active_color(self.color_index - 1)
290 |
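# Pure-NumPy reference for the C++ `canvas_to_masks` used above, mirroring the
# commented-out "numpy way" (for clarity only; the extension exists for speed,
# and its exact semantics are assumed here):

def canvas_to_masks_ref(canvas, colors):
    # canvas: (H, W, 4) uint8 RGBA image; colors: (K, 4) uint8 palette rows.
    # Returns (H, W, K) bool, True where a pixel equals palette color k exactly.
    return (canvas[:, :, None, :] == colors[None, None, :, :]).all(axis=3)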
--------------------------------------------------------------------------------
/widgets/workspace.py:
--------------------------------------------------------------------------------
1 | #
2 | # Author: David Futschik
3 | # Provided as part of the Chunkmogrify project, 2021.
4 | #
5 |
6 | import numpy as np
7 |
8 | from PyQt5.QtWidgets import *
9 | from PyQt5.QtGui import *
10 | from PyQt5.QtCore import *
11 |
12 | from qtutil import *
13 |
14 | class KeepArWidget(QWidget):
15 | def __init__(self, parent=None):
16 | super().__init__(parent)
17 | self.ar = 1. # 16 / 9
18 |
19 | def set_widget(self, widget):
20 | self.setLayout(QBoxLayout(QBoxLayout.LeftToRight, self))
21 | self.layout().addItem(QSpacerItem(0, 0))
22 | self.layout().addWidget(widget)
23 | self.layout().addItem(QSpacerItem(0, 0))
24 |
25 | def set_ar(self, ar):
26 | self.ar = ar
27 |
28 | def resizeEvent(self, event):
29 | w, h = event.size().width(), event.size().height()
30 |
31 | if w / h > self.ar:
32 | self.layout().setDirection(QBoxLayout.LeftToRight)
33 | widget_stretch = int(h * self.ar)
34 | outer_stretch = int((w - widget_stretch) / 2 + 0.5)  # round; setStretch expects an int
35 | else:
36 | self.layout().setDirection(QBoxLayout.TopToBottom)
37 | widget_stretch = int(w / self.ar)
38 | outer_stretch = int((h - widget_stretch) / 2 + 0.5)  # round; setStretch expects an int
39 |
40 | self.layout().setStretch(0, outer_stretch)
41 | self.layout().setStretch(1, widget_stretch)
42 | self.layout().setStretch(2, outer_stretch)
43 |
44 |
45 | class WorkspaceWidget(QWidget):
46 | def __init__(self, app, initial_image, forward_widgets=[], parent=None):
47 | super().__init__(parent)
48 | self.pixmap = None
49 | self.resize_hint = 2048
50 | self.aspect_ratio = 1. # 16. / 9.
51 | self.resize_to_aspect = True
52 | self.painting_enabled = True
53 | self.setLayout(QBoxLayout(QBoxLayout.LeftToRight, self))
54 |
55 | self.empty_pixmap_label = "Start by opening an image file. Go to "
56 | if initial_image:
57 | self.content = npy_loader(initial_image)
58 | self.has_content = True
59 | else:
60 | self.content = np.zeros((1024, 1024, 4))
61 | hbox = QHBoxLayout()
62 | hbox.addStretch()
63 | hbox.addWidget(QLabel(self.empty_pixmap_label))
64 | hbox.addStretch()
65 | self.layout().addLayout(hbox)
66 | self.placeholder_content = hbox
67 | self.has_content = False
68 |
69 | self.overlay = None
70 | self.overlay_enabled = True
71 | self.forward_widgets = forward_widgets
72 |
73 | self.inset_border_params = None
74 | self.scale_factor = 1
75 |
76 | policy = QSizePolicy(QSizePolicy.Preferred, QSizePolicy.Preferred)
77 | policy.setHeightForWidth(True)
78 | self.setSizePolicy(policy)
79 |
80 | self._update_image()
81 |
82 | def set_content(self, image: np.ndarray, update=True):
83 | assert image.dtype == np.uint8, "Workspace widget expects content to be a uint8 numpy array."
84 | self.content = image
85 | if not self.has_content: # Did not previously have content and showed placeholder.
86 | self.layout().removeItem(self.placeholder_content)
87 | destroy_layout(self.placeholder_content)
88 | self.has_content = True
89 | if update: self._update_image()
90 |
91 | def set_overlay(self, overlay: np.ndarray, update=True):
92 | assert overlay.dtype == np.uint8, "Workspace widget expects overlay to be a uint8 numpy array."
93 | assert overlay.shape[2] == 4, "Workspace overlay must be a 4-channel (RGBA) image."
94 | self.overlay = overlay
95 | if update: self._update_image()
96 |
97 | def _update_image(self):
98 | if self.overlay is not None and self.overlay_enabled:
99 | alpha = self.overlay[:, :, 3:4].astype(np.float32) / 255.0  # 3:4 keeps the channel axis (3:3 is an empty slice)
100 | with_overlay = (self.overlay[:, :, 0:3] * alpha + (1 - alpha) * self.content[:, :, 0:3]).astype(np.uint8)
101 | image = with_overlay
102 | else:
103 | image = self.content
104 | self.pixmap = QPixmap(np2qim(image))
105 | self.update()
106 |
107 | def toggle_overlay(self, val):
108 | self.overlay_enabled = val
109 | self._update_image()
110 |
111 | def current_image_as_numpy(self):
112 | return qim2np(self.pixmap.toImage()).copy()
113 |
114 | def get_current(self, with_overlay):
115 | npy = self.content[:, :, 0:3]
116 | if with_overlay and self.overlay is not None:
117 | alpha = self.overlay[:, :, 3:4].astype(np.float32) / 255.0
118 | npy = (self.overlay[:, :, 0:3] * alpha + (1 - alpha) * self.content[:, :, 0:3]).astype(np.uint8)
119 | return npy
120 |
121 | def inset_border(self, pxs, color):
122 | self.inset_border_params = {
123 | 'px': pxs,
124 | 'color': color
125 | }
126 | self.update()
127 |
128 | def set_scale_factor(self, factor):
129 | self.scale_factor = factor
130 | self.update()
131 |
132 | def paintEvent(self, event):
133 | pixmap = self.pixmap
134 |
135 | qpainter = QPainter(self)
136 | qpainter.drawPixmap(event.rect(), pixmap)
137 | if self.inset_border_params:
138 | rect = event.rect()
139 | rect.adjust(0, 0, -1, -1)
140 | qpainter.setPen(
141 | QPen(
142 | QColor(*self.inset_border_params['color'], 255),
143 | self.inset_border_params['px'])
144 | )
145 | qpainter.drawRect(rect)
146 |
147 | for w in self.forward_widgets:
148 | w.draw(qpainter, event)
149 |
150 | if self.resize_to_aspect:
151 | self.resize_to_aspect = False
152 | w, h = self.resize_to()
153 | qpainter.end()
154 |
155 | def is_pos_outside_bounds(self, pos):
156 | x, y = pos.x(), pos.y()
157 | mx, my = self.width(), self.height()
158 | if x < 0 or y < 0: return True
159 | if x >= mx or y >= my: return True
160 | return False
161 |
162 | def resize_to(self):
163 | # figure out what the limiting dimension is
164 | w, h = self.width(), self.height()
165 | w = (int(1.0 / self.height_to_width_ratio() * self.height()))
166 | w = int(w * self.scale_factor)
167 | h = int(h * self.scale_factor)
168 | return (w, h)
169 |
170 | def height_to_width_ratio(self):
171 | if self.pixmap is None:
172 | return self.aspect_ratio
173 | return float(self.pixmap.height()) / self.pixmap.width()
174 |
175 | def sizeHint(self):
176 | base_width = int(self.resize_hint * self.scale_factor)
177 | return QSize(base_width, self.heightForWidth(base_width))
178 |
179 | def resizeEvent(self, event):
180 | # Any resize may leave the widget off-aspect, so flag it unconditionally and
181 | # let the next paintEvent consume the flag.
182 | self.resize_to_aspect = True
183 |
184 | def heightForWidth(self, width):
185 | return int(self.height_to_width_ratio() * width)
186 |
187 | def enterEvent(self, event):
188 | self.setMouseTracking(True)
189 | self.setFocus()
190 | for w in self.forward_widgets:
191 | if w.enabled():
192 | w.mouse_enter()
193 | self.update()
194 |
195 | def leaveEvent(self, event):
196 | self.setMouseTracking(False)
197 | for w in self.forward_widgets:
198 | if w.enabled():
199 | w.mouse_leave()
200 | self.update()
201 |
202 | def mouseMoveEvent(self, event):
203 | if self.is_pos_outside_bounds(event.localPos()):
204 | # call this if you don't want drawing to continue when the cursor returns into the widget
205 | # self.stop_painting()
206 | return
207 | sx = (event.pos().x() / self.width())
208 | sy = (event.pos().y() / self.height())
209 |
210 | for w in self.forward_widgets:
211 | # if w.enabled():
212 | # Right now, let's propagate mouse events even when disabled. The widget should decide.
213 | # This is a workaround so that if the forwarded widget gets enabled after being disabled,
214 | # it will still have a chance to know where the cursor is currently.
215 | w.mouse_moved(sy, sx)
216 | self.update()
217 |
218 | def wheelEvent(self, event):
219 | for w in self.forward_widgets:
220 | if w.enabled():
221 | w.mouse_zoom(event.angleDelta().y() * 1/8)
222 | self.update()
223 |
224 | def mousePressEvent(self, event):
225 | btn = MouseButton(event.button())
226 |
227 | sx = event.localPos().x() / self.width()
228 | sy = event.localPos().y() / self.height()
229 |
230 | for w in self.forward_widgets:
231 | if w.enabled():
232 | w.mouse_down(sy, sx, btn)
233 | self.update()
234 |
235 | def mouseReleaseEvent(self, event):
236 | sx = event.localPos().x() / self.width()
237 | sy = event.localPos().y() / self.height()
238 | btn = MouseButton(event.button())
239 | for w in self.forward_widgets:
240 | if w.enabled():
241 | w.mouse_up(sy, sx, btn)
242 | self.update()
243 |
244 | def keyPressEvent(self, event):
245 | # This doesn't always work, and some keys are not represented, but overall it's enough
246 | key = event.text()
247 | for w in self.forward_widgets:
248 | if w.enabled():
249 | w.key_down(key)
250 | self.update()
251 |
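# Hedged helper consolidating the straight-alpha composite that appears in both
# _update_image() and get_current() (the helper name is illustrative):

def composite_rgba_over_rgb(overlay_rgba, base_rgb):
    # overlay_rgba: (H, W, 4) uint8; base_rgb: (H, W, >=3) uint8.
    alpha = overlay_rgba[:, :, 3:4].astype(np.float32) / 255.0
    out = overlay_rgba[:, :, 0:3] * alpha + base_rgb[:, :, 0:3] * (1.0 - alpha)
    return out.astype(np.uint8)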
--------------------------------------------------------------------------------