├── LICENSE
├── README.md
├── _config.yaml
├── align.py
├── app.py
├── components
│   ├── all.pt
│   ├── b1024_conv0.pt
│   ├── b1024_conv1.pt
│   ├── b1024_torgb.pt
│   ├── b128_conv0.pt
│   ├── b128_conv1.pt
│   ├── b128_torgb.pt
│   ├── b16_conv0.pt
│   ├── b16_conv1.pt
│   ├── b16_torgb.pt
│   ├── b256_conv0.pt
│   ├── b256_conv1.pt
│   ├── b256_torgb.pt
│   ├── b32_conv0.pt
│   ├── b32_conv1.pt
│   ├── b32_torgb.pt
│   ├── b4_conv1.pt
│   ├── b4_torgb.pt
│   ├── b512_conv0.pt
│   ├── b512_conv1.pt
│   ├── b512_torgb.pt
│   ├── b64_conv0.pt
│   ├── b64_conv1.pt
│   ├── b64_torgb.pt
│   ├── b8_conv0.pt
│   ├── b8_conv1.pt
│   └── b8_torgb.pt
├── config.py
├── directions
│   ├── age.npy
│   ├── beard.npy
│   ├── empirical_glasses.npy
│   ├── eye_distance.npy
│   ├── eye_eyebrow_distance.npy
│   ├── eye_ratio.npy
│   ├── eyes_open.npy
│   ├── gender.npy
│   ├── light.npy
│   ├── lip_ratio.npy
│   ├── mouth_open.npy
│   ├── mouth_ratio.npy
│   ├── nose_mouth_distance.npy
│   ├── nose_ratio.npy
│   ├── nose_tip.npy
│   ├── pitch.npy
│   ├── roll.npy
│   ├── smile_styleganv2.npy
│   └── yaw.npy
├── dnnlib
│   ├── __init__.py
│   └── util.py
├── extensions
│   ├── canvas_to_masks.cpp
│   └── heatmap.cpp
├── mask.py
├── mask_refinement.py
├── qtutil.py
├── requirements.txt
├── resources.py
├── resources
│   └── iconS.png
├── s_presets.py
├── sample_image
│   └── input.jpg
├── scripts
│   ├── __init__.py
│   ├── extract_components.py
│   └── idempotent_blend.py
├── setup_cpp_ext.py
├── styleclip_mapper.py
├── styleclip_presets.py
├── stylegan_legacy.py
├── stylegan_networks.py
├── stylegan_project.py
├── stylegan_tune.py
├── synthesis.py
├── torch_utils
│   ├── __init__.py
│   ├── custom_ops.py
│   ├── misc.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── bias_act.cpp
│   │   ├── bias_act.cu
│   │   ├── bias_act.h
│   │   ├── bias_act.py
│   │   ├── conv2d_gradfix.py
│   │   ├── conv2d_resample.py
│   │   ├── fma.py
│   │   ├── grid_sample_gradfix.py
│   │   ├── upfirdn2d.cpp
│   │   ├── upfirdn2d.cu
│   │   ├── upfirdn2d.h
│   │   └── upfirdn2d.py
│   ├── persistence.py
│   └── training_stats.py
├── w_directions.py
└── widgets
    ├── __init__.py
    ├── editor.py
    ├── mask_painter.py
    └── workspace.py

/LICENSE: -------------------------------------------------------------------------------- 
1 | MIT License 2 | 3 | Copyright (c) 2021 David Futschik 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Chunkmogrify: Real image inversion via Segments 2 | 3 |
11 | Teaser video with live editing sessions can be found here
16 | 17 | 18 | 19 | This code demonstrates the ideas discussed in arXiv submission *Real Image Inversion via Segments*. 20 | http://arxiv.org/abs/2110.06269 21 | (David Futschik¹, Michal Lukáč², Eli Shechtman², Daniel Sýkora¹) 22 | 23 | ¹Czech Technical University in Prague 24 | ²Adobe Research 25 | 26 | 27 | *Abstract: 28 | We present a simple, yet effective approach to editing 29 | real images via generative adversarial networks (GAN). Unlike previous 30 | techniques, that treat all editing tasks as an operation that affects pixel 31 | values in the entire image in our approach we cut up the image into a set of 32 | smaller segments. For those segments corresponding latent codes of a generative 33 | network can be estimated with greater accuracy due to the lower number of 34 | constraints. When codes are altered by the user the content in the image is 35 | manipulated locally while the rest of it remains unaffected. Thanks to this 36 | property the final edited image better retains the original structures and thus 37 | helps to preserve natural look.* 38 | 39 |
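The locality property described in the abstract (edits change the masked segment while the rest of the image stays untouched) can be illustrated with a plain alpha composite. Pixels inside the segment mask come from the edited generator output; all other pixels are copied from the original photo. This is a minimal NumPy sketch under that assumption. It is not the blending code in `synthesis.py`, and `composite_edit` is a hypothetical helper name.

```python
import numpy as np


def composite_edit(original: np.ndarray, edited: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Blend an edited image into the original only where the mask is set.

    original, edited: float arrays of shape (H, W, 3)
    mask: float array of shape (H, W), 1 = take the edited pixel, 0 = keep the original
    """
    if original.shape != edited.shape:
        raise ValueError("original and edited must have the same shape")
    m = mask[..., None]  # (H, W, 1) so the mask broadcasts over the color channels
    return m * edited + (1.0 - m) * original


# Toy 2x2 example: only the top-left pixel lies inside the mask.
original = np.zeros((2, 2, 3))
edited = np.ones((2, 2, 3))
mask = np.array([[1.0, 0.0], [0.0, 0.0]])
out = composite_edit(original, edited, mask)
print(out[0, 0], out[1, 1])  # [1. 1. 1.] [0. 0. 0.]
```

Fractional mask values (for example, a feathered brush edge) give a linear cross-fade between the two images.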
52 | 53 | ## What do I need? 54 | You will need a local machine with a relatively recent GPU; I wouldn't recommend trying 55 | Chunkmogrify with anything older than an RTX 2080. It is technically possible to run it on a CPU, 56 | but the operations become so slow that the user experience is not enjoyable. 57 | 58 | ## Quick startup guide 59 | Requirements: 60 | Python 3.7 or newer 61 | 62 | Note: If you are using Anaconda, I recommend creating a new environment to run this project. 63 | Packages installed with conda and pip often don't work well together. 64 | 65 | Steps to run the project: 66 | 1) Clone or download the repository and open a terminal / PowerShell instance in the directory. 67 | 2) Install the required Python packages by running `pip install -r requirements.txt`. This 68 | might take a while, since some of the packages amount to several hundred MB of data. 69 | Some packages (as well as this project itself) might need to compile extensions, so a C++ 70 | compiler must be present. On Linux, this is typically not an issue, but on Windows you might 71 | need Visual Studio and CUDA installations to set up the project successfully. 72 | 3) Run `python app.py`. On the first run, it automatically downloads the required 73 | resources, which are also several hundred megabytes. Download progress can be monitored 74 | in the command line window. 75 | 76 | To check that everything is installed and configured properly, load a photo and run a projection 77 | step. If there are no errors, you are good to go. 78 | 79 | 80 | ### Possible problems: 81 | *Torch not compiled with CUDA enabled.* 82 | Run 83 | ``` 84 | pip uninstall torch 85 | pip cache purge 86 | pip install torch -f https://download.pytorch.org/whl/torch_stable.html 87 | ``` 88 | 89 | ## Explanation of usage 90 | 91 |
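Before launching `app.py`, a quick way to tell whether the installed PyTorch build can actually use the GPU (and therefore whether the reinstall steps under *Possible problems* apply) is a small check like the one below. `cuda_status` is a hypothetical helper, not part of the repository. It relies only on the standard `torch.version.cuda`, `torch.cuda.is_available()` and `torch.cuda.get_device_name()` calls.

```python
import importlib.util


def cuda_status() -> str:
    """Return a short description of whether PyTorch can use a CUDA GPU."""
    if importlib.util.find_spec("torch") is None:
        return "torch is not installed"
    import torch

    if torch.version.cuda is None:
        # A CPU-only wheel: this is the build that raises "Torch not compiled with CUDA enabled".
        return f"torch {torch.__version__} is a CPU-only build"
    if not torch.cuda.is_available():
        # Built with CUDA support, but no usable GPU or driver was found at runtime.
        return f"torch {torch.__version__} (CUDA {torch.version.cuda}) found no usable GPU"
    return f"torch {torch.__version__} can use {torch.cuda.get_device_name(0)}"


if __name__ == "__main__":
    print(cuda_status())
```

If the output reports a CPU-only build, the uninstall/reinstall commands above are the relevant fix. A missing GPU or driver points to a system-level issue instead.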

92 | Tutorial video: click below
97 | 98 | Open an image using `File -> Image from File`. A sample image (`sample_image/input.jpg`) is provided to check 99 | functionality. 100 | 101 | Mask painting: 102 | Left click paints, right click erases. The mouse wheel controls the size of the brush. 103 | 104 | Projection: 105 | Enter a number of steps (100 or 200 is fine; with the current schedule, 500 is the maximum before the learning rate reaches 0) and press 106 | `Projection Steps`. Wait until the projection finishes; during this process, you can observe the global image view by choosing 107 | the output mode `Projection Only`. To fine-tune, you can perform a small number of 108 | `Pivotal Tuning` steps. 109 | 110 | Editing: 111 | To add an edit, click the double down-arrow icon in the Attribute Editor on the left side. Choose 112 | the type of edit (W, S, Styleclip), the direction of the edit, and drag the sliders to change the 113 | currently masked region. Usually the `multiplier` has to be increased before the `direction` slider 114 | produces noticeable changes. 115 | 116 | Multiple different edits can be composed on top of each other at the same time, and their order 117 | is largely irrelevant. In the default mode, only one region is edited, so 118 | all selected edits apply to the same region. If you would like to change the region, you can 119 | `Freeze` the current image and perform a new projection, but you will lose the ability to change 120 | existing edits. 121 | 122 | To save the current image, click the `Save Current Image` button. If the `Unalign` checkbox is 123 | active, the program will attempt to compose the aligned face back into the original image. Saved 124 | images are stored in the `SavedImages` directory by default; this can be changed via `export_directory` in `_config.yaml`.
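The painted mask marks where projection should optimize (1) and where it should not (0). One way to see why this helps is a reconstruction loss that only counts masked pixels, so the latent code is fitted to the segment alone. The sketch below is a minimal NumPy illustration under that assumption. The actual objective in this repository is configured through `projection_args` in `_config.yaml` (weights such as `l2_loss_weight`, `l1_loss_weight` and `percept_downsample`) and appears to include a perceptual term, so `masked_l2_loss` is a hypothetical stand-in, not the project's loss.

```python
import numpy as np


def masked_l2_loss(generated: np.ndarray, target: np.ndarray, mask: np.ndarray,
                   eps: float = 1e-8) -> float:
    """Mean squared error over masked pixels only.

    generated, target: (H, W, C) arrays
    mask: (H, W) array, 1 = optimize in this region, 0 = ignore it
    """
    m = mask[..., None]  # broadcast the mask over channels
    sq_err = m * (generated - target) ** 2
    # Normalize by the number of masked values so the loss scale does not depend on segment size.
    return float(sq_err.sum() / (m.sum() * generated.shape[-1] + eps))


target = np.ones((2, 2, 3))
generated = np.zeros((2, 2, 3))
generated[1, 1] = 5.0                      # large error, but outside the mask
mask = np.array([[1.0, 1.0], [0.0, 0.0]])  # only the top row is optimized
loss = masked_l2_loss(generated, target, mask)
print(round(loss, 6))  # 1.0
```

The unmasked pixel with a large error contributes nothing, which mirrors the idea that fewer constraints per segment make the latent code easier to estimate.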
125 | 126 | ## Keyboard shortcuts 127 | 128 | Current keyboard shortcuts include: 129 | 130 | Show/Hide mask :: Alt+M 131 | Toggle mask painting :: Alt+N 132 | 133 | ## W-space editing 134 | Source for some of the basic directions: 135 | (https://twitter.com/robertluxemburg/status/1207087801344372736) 136 | 137 | To add your own directions, save them as NumPy `.npy` arrays of shape (num_ws, 512) or (1, 512) 138 | and specify their path in `w_directions.py`. 139 | 140 | ## Style-space editing (S space edits) 141 | Source: 142 | StyleSpace Analysis: Disentangled Controls for StyleGAN Image Generation 143 | (https://arxiv.org/abs/2011.12799) 144 | (https://github.com/betterze/StyleSpace) 145 | 146 | The presets can be found in `s_presets.py`; some were taken directly from the paper, others 147 | I found by manual exploration. You can perform similar exploration by choosing the `Custom` 148 | preset once you have a projection. 149 | 150 | ## StyleCLIP editing 151 | Source: 152 | StyleCLIP: Text-Driven Manipulation of StyleGAN Imagery 153 | (https://arxiv.org/abs/2103.17249) 154 | (https://github.com/orpatashnik/StyleCLIP) 155 | 156 | The pretrained models were taken from (https://github.com/orpatashnik/StyleCLIP/blob/main/utils.py); I 157 | manually removed the decoder from the state dict, since it is not used and takes up the majority of the 158 | file size. 159 | 160 | ## PTI Optimization 161 | Source: 162 | Pivotal Tuning for Latent-based Editing of Real Images 163 | (https://arxiv.org/abs/2106.05744) 164 | 165 | This method allows you to match the target photo very closely while retaining editing capabilities. 166 | 167 | Running 30-50 iterations of PTI is often enough to match the source image very closely 168 | without a very noticeable drop in editing capabilities.
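A custom W-space direction is just an array saved to disk. The sketch below prepares one and applies it as a linear shift `w + strength * direction`. It assumes 18 latent vectors of width 512 (the layout of a 1024×1024 FFHQ StyleGAN2 generator) and a linear edit model. How `w_directions.py` registers the file and how the UI `multiplier`/`direction` sliders scale it are not shown here. `save_direction`, `apply_direction` and the file name `my_direction.npy` are hypothetical.

```python
import os
import tempfile

import numpy as np

NUM_WS, W_DIM = 18, 512  # assumption: 1024x1024 StyleGAN2 FFHQ generator


def save_direction(path: str, direction: np.ndarray) -> None:
    """Store a W-space direction as a .npy array of shape (num_ws, 512) or (1, 512)."""
    direction = np.asarray(direction, dtype=np.float32)
    if direction.ndim != 2 or direction.shape[1] != W_DIM:
        raise ValueError(f"expected shape (num_ws, {W_DIM}) or (1, {W_DIM}), got {direction.shape}")
    np.save(path, direction)


def apply_direction(w: np.ndarray, direction: np.ndarray, strength: float) -> np.ndarray:
    """Shift latents w of shape (num_ws, 512); a (1, 512) direction broadcasts to every layer."""
    return w + strength * direction


rng = np.random.default_rng(0)
d = rng.standard_normal((1, W_DIM)).astype(np.float32)
d /= np.linalg.norm(d)  # unit length, so `strength` is how far each w vector moves
path = os.path.join(tempfile.mkdtemp(), "my_direction.npy")  # hypothetical file name
save_direction(path, d)
loaded = np.load(path)
w = np.zeros((NUM_WS, W_DIM), dtype=np.float32)
w_edit = apply_direction(w, loaded, strength=2.0)
print(w_edit.shape)  # (18, 512)
```

A (num_ws, 512) direction allows a different offset per layer. A (1, 512) direction applies the same offset to all layers through broadcasting.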
169 | 170 | 171 | ## Attribution 172 | This repository makes use of code provided by the various repositories linked above, as well as 173 | code from: 174 | 175 | stylegan2-ada-pytorch (https://github.com/NVlabs/stylegan2-ada-pytorch) 176 | poisson-image-editing (https://github.com/PPPW/poisson-image-editing), for optional support 177 | of idempotent blending (a slow blending implementation that only changes the masked part; it 178 | can be enabled by uncommenting the corresponding option in `synthesis.py`) 179 | 180 | ## Citation 181 | 182 | If you find this code useful for your research, please cite the arXiv submission linked above. 183 | -------------------------------------------------------------------------------- /_config.yaml: -------------------------------------------------------------------------------- 1 | 2 | # Initial w to optimize from. Can be set to ~ for None. 3 | initial_w: ~ 4 | # Skip auto alignment of the images. Only use this if you already have aligned images. 5 | skip_alignment: False 6 | 7 | generator_path: resources/ffhq.pkl 8 | # generator_path: pti_resources/model_merkel.pt 9 | # Perform plain torch.load on the pkl, otherwise look for G_ema. 10 | generator_load_raw: False 11 | # Resolution of the trained generator. 12 | generator_native_resolution: [1024, 1024] 13 | 14 | # Default projection mode. 15 | projection_mode: w_projection 16 | # Projection mode arguments. 17 | projection_args: 18 | lr_init: 1.0e-1 19 | l2_loss_weight: 0 20 | l1_loss_weight: 0. 21 | noise_regularize_weight: 0. # 10000. 22 | mean_latent_loss_weight: 10. 23 | percept_downsample: 0.5 24 | 25 | # Set this to roughly 0.5-1 for faster projection at the cost of fewer UI updates. 26 | minimum_projection_update_window: 0.1 27 | 28 | # Use this device for torch. 29 | device: cuda:0 30 | # Don't change this unless you want to do multimask descent. 31 | max_segments: 1 32 | # Save exported images here.
33 | export_directory: SavedImages 34 | 35 | # Skip loading some resources for high performance startup. 36 | ui_debug_run: False 37 | show_debug_menu: False 38 | -------------------------------------------------------------------------------- /components/all.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/all.pt -------------------------------------------------------------------------------- /components/b1024_conv0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b1024_conv0.pt -------------------------------------------------------------------------------- /components/b1024_conv1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b1024_conv1.pt -------------------------------------------------------------------------------- /components/b1024_torgb.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b1024_torgb.pt -------------------------------------------------------------------------------- /components/b128_conv0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b128_conv0.pt -------------------------------------------------------------------------------- /components/b128_conv1.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b128_conv1.pt -------------------------------------------------------------------------------- /components/b128_torgb.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b128_torgb.pt -------------------------------------------------------------------------------- /components/b16_conv0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b16_conv0.pt -------------------------------------------------------------------------------- /components/b16_conv1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b16_conv1.pt -------------------------------------------------------------------------------- /components/b16_torgb.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b16_torgb.pt -------------------------------------------------------------------------------- /components/b256_conv0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b256_conv0.pt -------------------------------------------------------------------------------- /components/b256_conv1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b256_conv1.pt 
-------------------------------------------------------------------------------- /components/b256_torgb.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b256_torgb.pt -------------------------------------------------------------------------------- /components/b32_conv0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b32_conv0.pt -------------------------------------------------------------------------------- /components/b32_conv1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b32_conv1.pt -------------------------------------------------------------------------------- /components/b32_torgb.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b32_torgb.pt -------------------------------------------------------------------------------- /components/b4_conv1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b4_conv1.pt -------------------------------------------------------------------------------- /components/b4_torgb.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b4_torgb.pt -------------------------------------------------------------------------------- /components/b512_conv0.pt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b512_conv0.pt -------------------------------------------------------------------------------- /components/b512_conv1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b512_conv1.pt -------------------------------------------------------------------------------- /components/b512_torgb.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b512_torgb.pt -------------------------------------------------------------------------------- /components/b64_conv0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b64_conv0.pt -------------------------------------------------------------------------------- /components/b64_conv1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b64_conv1.pt -------------------------------------------------------------------------------- /components/b64_torgb.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b64_torgb.pt -------------------------------------------------------------------------------- /components/b8_conv0.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b8_conv0.pt -------------------------------------------------------------------------------- /components/b8_conv1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b8_conv1.pt -------------------------------------------------------------------------------- /components/b8_torgb.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/components/b8_torgb.pt -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # 2 | # Author: David Futschik 3 | # Provided as part of the Chunkmogrify project, 2021. 4 | # 5 | 6 | import yaml 7 | 8 | class dotdict(dict): 9 | __getattr__ = dict.__getitem__ 10 | __setattr__ = dict.__setitem__ 11 | __delattr__ = dict.__delitem__ 12 | 13 | def has_set(self, attr): 14 | return attr in self and self[attr] is not None 15 | 16 | def _into_deep_dotdict(regular_dict): 17 | new_dict = dotdict(regular_dict) 18 | for k, v in regular_dict.items(): 19 | if type(v) == dict: 20 | new_dict[k] = _into_deep_dotdict(v) 21 | return new_dict 22 | 23 | def _load_config(path): 24 | with open(path) as fs: 25 | loaded = yaml.safe_load(fs) 26 | return _into_deep_dotdict(loaded) 27 | 28 | config = None 29 | def global_config(): 30 | global config 31 | if config is None: 32 | config = _load_config("_config.yaml") 33 | return config -------------------------------------------------------------------------------- /directions/age.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/age.npy -------------------------------------------------------------------------------- /directions/beard.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/beard.npy -------------------------------------------------------------------------------- /directions/empirical_glasses.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/empirical_glasses.npy -------------------------------------------------------------------------------- /directions/eye_distance.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/eye_distance.npy -------------------------------------------------------------------------------- /directions/eye_eyebrow_distance.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/eye_eyebrow_distance.npy -------------------------------------------------------------------------------- /directions/eye_ratio.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/eye_ratio.npy -------------------------------------------------------------------------------- /directions/eyes_open.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/eyes_open.npy -------------------------------------------------------------------------------- /directions/gender.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/gender.npy -------------------------------------------------------------------------------- /directions/light.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/light.npy -------------------------------------------------------------------------------- /directions/lip_ratio.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/lip_ratio.npy -------------------------------------------------------------------------------- /directions/mouth_open.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/mouth_open.npy -------------------------------------------------------------------------------- /directions/mouth_ratio.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/mouth_ratio.npy -------------------------------------------------------------------------------- /directions/nose_mouth_distance.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/nose_mouth_distance.npy 
-------------------------------------------------------------------------------- /directions/nose_ratio.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/nose_ratio.npy -------------------------------------------------------------------------------- /directions/nose_tip.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/nose_tip.npy -------------------------------------------------------------------------------- /directions/pitch.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/pitch.npy -------------------------------------------------------------------------------- /directions/roll.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/roll.npy -------------------------------------------------------------------------------- /directions/smile_styleganv2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/smile_styleganv2.npy -------------------------------------------------------------------------------- /directions/yaw.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/directions/yaw.npy -------------------------------------------------------------------------------- /dnnlib/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /extensions/canvas_to_masks.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Author: David Futschik 3 | Provided as part of the Chunkmogrify project, 2021. 4 | */ 5 | 6 | #define PY_SSIZE_T_CLEAN 7 | #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION 8 | 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | 22 | static PyObject* canvas_to_masks(PyObject* self, PyObject* args, PyObject* kwargs) { 23 | 24 | const char* kwarg_names[] = { "canvas", "colors", "output_buffer", NULL }; 25 | 26 | PyArrayObject* canvas = nullptr; 27 | PyArrayObject* colors = nullptr; 28 | PyArrayObject* output = nullptr; 29 | if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!|$O!", const_cast<char**>(kwarg_names), 30 | &PyArray_Type, &canvas, 31 | &PyArray_Type, &colors, 32 | &PyArray_Type, &output)) { 33 | return NULL; 34 | } 35 | 36 | if (PyArray_NDIM(canvas) != 3 or PyArray_NDIM(colors) != 2) { 37 | PyErr_SetString(PyExc_ValueError, "a.ndim must be 3 and b.ndim must be 2."); 38 | return NULL; 39 | } 40 | 41 | int h, w, c, num_color; 42 | h = PyArray_DIM(canvas, 0); 43 | w = PyArray_DIM(canvas, 1); 44 | c = PyArray_DIM(canvas, 2); 45 | num_color = PyArray_DIM(colors, 0); 46 | 47 | int
dtype_a, dtype_b; 48 | dtype_a = PyArray_TYPE(canvas); 49 | dtype_b = PyArray_TYPE(colors); 50 | 51 | if (c != PyArray_DIM(colors, 1)) { 52 | PyErr_SetString(PyExc_ValueError, "a.shape[2] != b.shape[1]"); 53 | return NULL; 54 | } 55 | 56 | if (dtype_a != dtype_b or dtype_a != NPY_UINT8) { 57 | PyErr_SetString(PyExc_ValueError, "dtype of both arrays must be uint8."); 58 | return NULL; 59 | } 60 | 61 | bool output_wrong_array = [output, h, w, num_color]() { 62 | if (not output) return true; 63 | if (PyArray_TYPE(output) != NPY_FLOAT) return true; 64 | if (PyArray_NDIM(output) != 3) return true; 65 | if (PyArray_DIM(output, 0) != h) return true; 66 | if (PyArray_DIM(output, 1) != w) return true; 67 | if (PyArray_DIM(output, 2) != num_color) return true; 68 | return false; 69 | }(); 70 | 71 | if (!output or output_wrong_array) { 72 | // alloc new 73 | auto descr = PyArray_DescrFromType(NPY_FLOAT); 74 | npy_intp dims[] = { h, w, num_color }; 75 | output = (PyArrayObject*)PyArray_NewFromDescr(&PyArray_Type, descr, /* nd */ 3, dims, NULL, NULL, 0, NULL); 76 | // std::cout << "Sideeffect" << std::endl; 77 | } else { 78 | // incref of output 79 | Py_INCREF(output); 80 | } 81 | 82 | #pragma omp parallel for 83 | for (int y = 0; y < h; y++) { 84 | #pragma omp parallel for 85 | for (int x = 0; x < w; x++) { 86 | uint8_t* px = (uint8_t*)PyArray_GETPTR2(canvas, y, x); 87 | int same_idx = 0; 88 | #pragma unroll(12) 89 | for (int k = 0; k < num_color; k++) { 90 | uint8_t* kolor = (uint8_t*)PyArray_GETPTR1(colors, k); 91 | #pragma unroll(4) 92 | for (int it = 0; it < c; it++) { 93 | if (kolor[it] != px[it]) { 94 | goto cont; 95 | } 96 | } 97 | same_idx = k; 98 | goto brk; 99 | cont: ; 100 | } 101 | brk: 102 | float* o = (float*)PyArray_GETPTR2(output, y, x); 103 | for (int it = 0; it < num_color; it++) { 104 | // set the channel of the correct mask 105 | o[it] = it == same_idx ? 1. 
: 0.; 106 | } 107 | } 108 | } 109 | 110 | return (PyObject*)output; 111 | } 112 | 113 | 114 | static PyMethodDef python_methods[] = { 115 | { "canvas_to_masks", (PyCFunction)canvas_to_masks, METH_VARARGS|METH_KEYWORDS, "convert canvas colors to mask stack" }, 116 | { nullptr, nullptr, 0, nullptr } 117 | }; 118 | 119 | static struct PyModuleDef python_module = { 120 | PyModuleDef_HEAD_INIT, 121 | "_C_canvas", // _C_canvas. 122 | nullptr, // documentation 123 | -1, 124 | python_methods 125 | }; 126 | 127 | PyMODINIT_FUNC 128 | PyInit__C_canvas(void) { 129 | auto x = PyModule_Create(&python_module); 130 | import_array(); 131 | return x; 132 | } -------------------------------------------------------------------------------- /extensions/heatmap.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Author: David Futschik 3 | Provided as part of the Chunkmogrify project, 2021. 4 | */ 5 | 6 | #define PY_SSIZE_T_CLEAN 7 | #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | 21 | static PyObject* heatmap(PyObject* self, PyObject* args, PyObject* kwargs) { 22 | 23 | const char* kwarg_names[] = { "values", "vmin", "vmax", NULL }; 24 | 25 | PyArrayObject* canvas = nullptr; 26 | double vmin = 0; 27 | double vmax = 0; 28 | 29 | if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!dd|", const_cast<char**>(kwarg_names), &PyArray_Type, &canvas, &vmin, &vmax)) { 30 | return NULL; 31 | } 32 | 33 | if (PyArray_NDIM(canvas) != 3 and PyArray_NDIM(canvas) != 2) { 34 | PyErr_SetString(PyExc_ValueError, "a.ndim must be 3 or 2."); 35 | return NULL; 36 | } 37 | 38 | int has_c = PyArray_NDIM(canvas) > 2; 39 | 40 | int h, w, c; 41 | h = PyArray_DIM(canvas, 0); 42 | w = PyArray_DIM(canvas, 1); 43 | if (has_c) { 44 | c = PyArray_DIM(canvas, 2); 45 | if (c != 1) { 46 | PyErr_SetString(PyExc_ValueError, "3rd dimension
must be 1 or None"); 47 | return NULL; 48 | } 49 | } 50 | else { 51 | c = 1; 52 | } 53 | 54 | int dtype_a = PyArray_TYPE(canvas); 55 | 56 | if (dtype_a != NPY_FLOAT) { 57 | PyErr_SetString(PyExc_ValueError, "dtype of array must be float."); 58 | return NULL; 59 | } 60 | 61 | auto descr = PyArray_DescrFromType(NPY_UINT8); 62 | npy_intp dims[] = { h, w, 3 }; 63 | PyArrayObject* output = (PyArrayObject*)PyArray_NewFromDescr(&PyArray_Type, descr, /* nd */ 3, dims, NULL, NULL, 0, NULL); 64 | 65 | double dv = vmax - vmin; 66 | 67 | #pragma omp parallel for 68 | for (int y = 0; y < h; y++) { 69 | #pragma omp parallel for 70 | for (int x = 0; x < w; x++) { 71 | float* pxval = (float*)PyArray_GETPTR2(canvas, y, x); 72 | float v = *pxval; 73 | float r(1), g(1), b(1); 74 | 75 | if (v < (vmin + .25 * dv)) { 76 | r = 0; 77 | g = 4 * (v - vmin) / dv; 78 | } else if (v < (vmin + .5 * dv)) { 79 | r = 0; 80 | b = 1 + 4 * (vmin + .25 * dv - v) / dv; 81 | } else if (v < (vmin + .75 * dv)) { 82 | r = 4 * (v - vmin - .5 * dv) / dv; 83 | b = 0; 84 | } else { 85 | g = 1 + 4 * (vmin + .75 * dv - v) / dv; 86 | b = 0; 87 | } 88 | 89 | *(uint8_t*)PyArray_GETPTR3(output, y, x, 0) = r * 255; 90 | *(uint8_t*)PyArray_GETPTR3(output, y, x, 1) = g * 255; 91 | *(uint8_t*)PyArray_GETPTR3(output, y, x, 2) = b * 255; 92 | } 93 | } 94 | 95 | return (PyObject*)output; 96 | } 97 | 98 | 99 | static PyMethodDef python_methods[] = { 100 | { "heatmap", (PyCFunction)heatmap, METH_VARARGS|METH_KEYWORDS, "convert data into heatmap" }, 101 | { nullptr, nullptr, 0, nullptr } 102 | }; 103 | 104 | static struct PyModuleDef python_module = { 105 | PyModuleDef_HEAD_INIT, 106 | "_C_heatmap", // _C_canvas. 
    nullptr, // documentation
    -1,
    python_methods
};

PyMODINIT_FUNC
PyInit__C_heatmap(void) {
    auto x = PyModule_Create(&python_module);
    import_array();
    return x;
}
--------------------------------------------------------------------------------
/mask.py:
--------------------------------------------------------------------------------
#
# Author: David Futschik
# Provided as part of the Chunkmogrify project, 2021.
#

from typing import Any, Callable, List

import os
import re
import PIL.Image
import numpy as np

# MASK
# 1 = Optimize in this region
# 0 = Do not optimize in this region
class MaskState:
    def __init__(self, height, width, max_segments, output_fns: List[Callable[[np.ndarray], Any]]):
        self.output = output_fns
        self.h, self.w, self.c = height, width, max_segments
        # np.float is a deprecated alias of the builtin float (float64)
        self.np_buffer = np.zeros((height, width, max_segments), dtype=np.float64)
        self.update_callbacks = []
        self.mask_version = 0

    def set_to(self, new_buffer: np.ndarray, **cb_kwargs):
        assert self.np_buffer.min() >= 0. and self.np_buffer.max() <= 1., \
            f"Mask has range [{self.np_buffer.min()}, {self.np_buffer.max()}]"
        buffer_c = new_buffer.shape[2]
        self.np_buffer[..., :min(self.c, buffer_c)] = new_buffer[..., :min(self.c, buffer_c)]
        self.mask_version += 1
        for c in self.update_callbacks:
            c(self.np_buffer, **cb_kwargs)

    def load_masks(self, source_dir):
        if not os.path.exists(source_dir):
            print(f"Could not load masks from {source_dir}")
            return
        ls = os.listdir(source_dir)
        found_idx = [re.match(r'^\d+', x) for x in ls]
        found_idx = [int(f.group()) for f in found_idx if f is not None]
        # 0 used to be the entire thing, in that case found_idx.remove(0)
        for idx in found_idx:
            # load the image
            m = np.array(PIL.Image.open(os.path.join(source_dir, f'{idx:02d}.png')).convert('L'))
            if idx >= self.np_buffer.shape[2]:
                print(f"Skipping mask {idx}")
                continue
            self.np_buffer[..., idx] = m / 255.
        self.mask_version += 1
        for c in self.update_callbacks:
            c(self.np_buffer)

    def save_masks(self, target_dir, painter, max_segments):
        # Get it from painter to include the RGB as it is in the app.
        rgb_masks, npy_masks = painter.get_volatile_masks()
        PIL.Image.fromarray(rgb_masks).save(os.path.join(target_dir, "rgb.png"))
        # max_segments + 1 because "empty" is the 0th mask.
        PIL.Image.fromarray((npy_masks[:, :, 0] * 255).astype(np.uint8)).save(os.path.join(target_dir, "all.png"))
        for i in range(1, min(max_segments + 1, npy_masks.shape[2])):
            PIL.Image.fromarray((npy_masks[:, :, i] * 255).astype(np.uint8)).save(os.path.join(target_dir, f"{i - 1:02d}.png"))

    def get_mask_version(self):
        return self.mask_version

    def numpy_buffer(self):
        assert self.np_buffer.min() >= 0. and self.np_buffer.max() <= 1., \
            f"Mask has range [{self.np_buffer.min()}, {self.np_buffer.max()}]"
        return self.np_buffer

--------------------------------------------------------------------------------
/mask_refinement.py:
--------------------------------------------------------------------------------
#
# Author: David Futschik
# Provided as part of the Chunkmogrify project, 2021.
#

import math
import torch
import numpy as np
from torch.nn import functional as F

def torch_grad(x):
    a = torch.tensor([[-1, 0, 1]], dtype=torch.float32, device=x.device, requires_grad=False).view((1, 1, 1, 3))
    b = torch.tensor([[-1, 0, 1]], dtype=torch.float32, device=x.device, requires_grad=False).view((1, 1, 3, 1))
    G_x = F.conv2d(x, a) / 2
    G_y = F.conv2d(x, b) / 2
    G_x = F.pad(G_x, (1, 1, 0, 0), 'constant', 0.)
    G_y = F.pad(G_y, (0, 0, 1, 1), 'constant', 0.)
    return [G_x, G_y]

def torch_normgrad_curv(x):
    G_x, G_y = torch_grad(x)
    G = torch.sqrt(torch.pow(G_x, 2) + torch.pow(G_y, 2))
    # div
    div = torch_grad(G_x / (G + 1e-8))[0] + torch_grad(G_y / (G + 1e-8))[1]
    return G, div

def contrast_magnify(x, min=0, max=64, fromval=0., toval=255.):
    mul_by = (toval - min) / max
    x = ((x - min) * mul_by).clip(fromval, toval)
    return x

def mask_refine(mask, image1, image2, dt_A=0.001, dt_B=0.1, iters=300):

    S = (mask - 0.5)
    P = S[:, 0:1, :, :]

    Pg = torch_normgrad_curv(P)[0]

    I_1_mean = ((Pg * image1)).sum(dim=(2,3)) / Pg.sum()
    I_2_mean = ((Pg * image2)).sum(dim=(2,3)) / Pg.sum()

    assert I_1_mean.shape == (1, 3), "Mean wrong shape"
    I_1 = image1 - I_1_mean[:, :, None, None]
    I_2 = image2 - I_2_mean[:, :, None, None]
    Fn = I_1 - I_2
    Fn = Fn.norm(dim=1) / math.sqrt(12)

    P = P * 255
    Fn = Fn * 255

    Fn = contrast_magnify(Fn)

    with torch.no_grad():
        for _ in range(iters):
            ng, div = torch_normgrad_curv(P)
            P -= dt_A * Fn * ng - dt_B * ng * div

    P = torch.where(P < 0, 0., 1.)

    new_mask = P
    return new_mask

if __name__ == "__main__":
    import PIL.Image

    i1 = PIL.Image.open("_I1_.png")
    i2 = PIL.Image.open("_I2_.png")
    s = PIL.Image.open("_P_.png")

    # np.float is a deprecated alias of the builtin float (float64)
    i1 = (torch.tensor(np.array(i1).astype(np.float64), dtype=torch.float32, device='cuda:0') / 127.5).permute((2,0,1)).unsqueeze(0)
    i2 = (torch.tensor(np.array(i2).astype(np.float64), dtype=torch.float32, device='cuda:0') / 127.5).permute((2,0,1)).unsqueeze(0)
    s = (torch.tensor(np.array(s).astype(np.float64), dtype=torch.float32, device='cuda:0') / 255.).unsqueeze(0).unsqueeze(0)

    m = mask_refine(s, i1 - 1, i2 - 1)

    d = m.cpu().numpy()[0].transpose((1,2,0)).__mul__(255.).astype(np.uint8)
    PIL.Image.fromarray(d[:, :, 0]).save("maskrefine.png")

--------------------------------------------------------------------------------
/qtutil.py:
--------------------------------------------------------------------------------
#
# Author: David Futschik
# Provided as part of the Chunkmogrify project, 2021.
#

import os
import re
import numpy as np
import PIL.Image as Image
from enum import Enum
from threading import Thread
from threading import Timer
from time import time, perf_counter

from PyQt5.QtWidgets import *
from PyQt5.QtGui import *
from PyQt5.QtCore import *

class dotdict(dict):
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def has_set(self, attr):
        return attr in self and self[attr] is not None

class MouseButton(Enum):
    LEFT = 1
    RIGHT = 2
    MIDDLE = 4
    SIDE_1 = 8
    SIDE_2 = 16

def QHLine():
    w = QFrame()
    w.setFrameShape(QFrame.HLine)
    w.setFrameShadow(QFrame.Sunken)
    return w

def QVLine():
    w = QFrame()
    w.setFrameShape(QFrame.VLine)
    w.setFrameShadow(QFrame.Raised)
    return w

def qpixmap_loader(path: str):
    if not os.path.exists(path):
        print(f"Warning: {path} does not exist, but queued for load as pixmap")
    else:
        print(f"Loading {path}")
    return QPixmap(path)

def npy_loader(path: str):
    if not os.path.exists(path):
        raise RuntimeError(f"{path} does not exist")
    return np.array(Image.open(path))

def make_dirs_if_not_exists(path):
    if not os.path.exists(path) or not os.path.isdir(path):
        os.makedirs(path)

def export_image(prefix, npy_image):
    if not os.path.exists(prefix):
        os.mkdir(prefix)
    ls = os.listdir(prefix)
    prev_imgs = [re.match(r'^\d+', x) for x in ls]
    prev_imgs = [int(x.group()) for x in prev_imgs if x is not None]
    cur_id = max(prev_imgs, default=-1) + 1
    fname = os.path.join(prefix, f'{cur_id:05d}.png')
    Image.fromarray(npy_image).save(fname)
    return fname

def image_save(npy_image, path):
    Image.fromarray(npy_image).save(path)

def qim2np(qimage, swapaxes=True):
    # assumes BGRA if swapaxes is True
    # assert(qimage.format() == QImage.Format_8888)
    bytes_per_pxl = 4
    if qimage.format() == QImage.Format_Grayscale8:
        bytes_per_pxl = 1
    qimage.__array_interface__ = {
        'shape': (qimage.height(), qimage.width(), bytes_per_pxl),
        'typestr': "|u1",
        'data': (int(qimage.bits()), False),
        'version': 3
    }
    npim = np.asarray(qimage)
    # npim is now BGRA, cast it into RGBA
    if bytes_per_pxl == 4 and swapaxes:
        npim = npim[..., (2,1,0,3)]
    return npim

def np2qim(nparr, fmt_select='auto', do_copy=True):
    h, w = nparr.shape[0:2]
    # the copy is required unless reference is desired!
    if do_copy:
        nparr = nparr.astype('uint8').copy()
    else:
        nparr = np.ascontiguousarray(nparr.astype('uint8'))
    if nparr.ndim == 2:
        # or Alpha8 when it has no 3rd dimension (masks)
        fmt = QImage.Format_Grayscale8 #Alpha8
    elif fmt_select != 'auto':
        fmt = fmt_select
    else:
        fmt = {3: QImage.Format_RGB888, 4: QImage.Format_RGBA8888}[nparr.shape[2]]
    qim = QImage(nparr, w, h, nparr.strides[0], fmt) # _Premultiplied
    return qim

def destroy_layout(layout):
    if layout is not None:
        while layout.count():
            item = layout.takeAt(0)
            widget = item.widget()
            if widget is not None:
                widget.deleteLater()
            else:
                destroy_layout(item.layout())

class NotifyWait(QObject):
    acquire = pyqtSignal(str)
    release = pyqtSignal()

    def __init__(self, parent=None) -> None:
        super().__init__(parent=parent)
        self.msg = ""
        d = QProgressDialog(self.msg, "Cancel", 0, 0, parent=get_default_parent())
        d.setCancelButton(None)
        d.setWindowTitle("Working")
        d.setWindowFlags(d.windowFlags() & ~Qt.WindowCloseButtonHint)
        self.d = d
        self.d.cancel()

    def _show(self):
        self.d.setLabelText(self.msg)
        self.d.exec_()

    def _hide(self):
        self.d.done(0)


class RaiseError(QObject):
    raiseme = pyqtSignal(str, str, bool) # is_fatal

    def __init__(self, parent=None) -> None:
        super().__init__(parent=parent)

    def _show(self, where, what, is_fatal):
        msg = QMessageBox(self.parent())
        msg.setText(where)
        msg.setInformativeText(what)
        msg.setWindowTitle("Error")
        msg.setIcon(QMessageBox.Critical)
        msgresizer = QSpacerItem(500, 0, QSizePolicy.Minimum, QSizePolicy.Expanding)
        msg.layout().addItem(msgresizer, msg.layout().rowCount(), 0, 1, msg.layout().columnCount())
        msg.exec_()
        if is_fatal:
            exit(1)

_global_parent = None
_notify_wait = None
_global_error = None
def set_default_parent(parent):
    global _global_parent
    global _notify_wait
    global _global_error
    _global_parent = parent
    _notify_wait = NotifyWait(_global_parent)
    _global_error = RaiseError(_global_parent)

    def s(msg):
        _notify_wait.msg = msg
        _notify_wait._show()
    def e():
        _notify_wait._hide()
    _notify_wait.acquire.connect(s)
    _notify_wait.release.connect(e)

    def e(where, what, is_fatal):
        _global_error._show(where, what, is_fatal)
    _global_error.raiseme.connect(e)

def get_default_parent():
    global _global_parent
    return _global_parent
def get_notify_wait():
    global _notify_wait
    return _notify_wait
def get_global_error():
    global _global_error
    return _global_error

def notify_user_wait(message):
    d = QProgressDialog(message, "Cancel", 0, 0, parent=get_default_parent())
    d.setCancelButton(None)
    d.setWindowTitle("Working")
    d.setWindowFlags(d.windowFlags() & ~Qt.WindowCloseButtonHint)
    return d


def notify_user_error(where, what, parent=None):
    msg = QMessageBox(parent)
    msg.setText(where)
    msg.setInformativeText(what)
    msg.setWindowTitle("Error")
    msg.setIcon(QMessageBox.Critical)
    msgresizer = QSpacerItem(500, 0, QSizePolicy.Minimum, QSizePolicy.Expanding)
    msg.layout().addItem(msgresizer, msg.layout().rowCount(), 0, 1, msg.layout().columnCount())
    msg.exec_()

# Really only works when called from main thread.
def execute_with_wait(message, fn, *args, **kwargs):
    d = notify_user_wait(message)

    def fn_impl():
        fn(*args, **kwargs)
        d.done(0)
    t = Thread(target=fn_impl)
    t.start()
    d.exec_()

class NowOrDelayTimer:
    def __init__(self, interval):
        self.interval = interval
        self.last_ok_time = 0
        self.timer = None

    def update(self, do_run):
        if self.last_ok_time < time() - self.interval:
            self.last_ok_time = time()
            do_run()
            # cancel pending update
            if self.timer: self.timer.cancel()
        else:
            # cancel last pending update
            if self.timer: self.timer.cancel()
            def closure():
                self.update(do_run)
            self.timer = Timer(self.last_ok_time + self.interval - time(), closure)
            self.timer.start()

class MeasureTime:
    def __init__(self, name, disable=False):
        self.name = name
        self.start = None
        self.disable = disable

    def __enter__(self):
        if not self.disable:
            self.start = perf_counter()

    def __exit__(self, type, value, traceback):
        if not self.disable:
            end = perf_counter()
            print(f'{self.name}: {end-self.start:.3f}')
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
certifi==2021.5.30
charset-normalizer==2.0.6
click==8.0.1
cycler==0.10.0
dlib==19.22.1
idna==3.2
imageio==2.9.0
importlib-metadata==4.8.1
kiwisolver==1.3.2
lpips==0.1.4
matplotlib==3.4.3
networkx==2.6.3
numpy==1.21.2
opencv-python-headless==4.5.3.56
Pillow==8.3.2
pyparsing==2.4.7
PyQt5==5.15.4
PyQt5-Qt5==5.15.2
PyQt5-sip==12.9.0
python-dateutil==2.8.2
PyWavelets==1.1.1
PyYAML==5.4.1
requests==2.26.0
scikit-image==0.18.3
scipy==1.7.1
six==1.16.0
tifffile==2021.8.30
tqdm==4.62.3
typing-extensions==3.10.0.2
urllib3==1.26.7
zipp==3.6.0
ninja
# package location
--find-links https://download.pytorch.org/whl/torch_stable.html
torch==1.9.1+cu102
torchvision==0.10.1

--------------------------------------------------------------------------------
/resources.py:
--------------------------------------------------------------------------------
#
# Author: David Futschik
# Provided as part of the Chunkmogrify project, 2021.
#


import sys
import os
import requests

from qtutil import make_dirs_if_not_exists

resource_list = {
    'ffhq_styleganv2': 'resources/ffhq.pkl',
    'dlib_align_data': 'resources/shape_predictor_68_face_landmarks.dat',
    'styleclip_afro': 'resources/styleclip/afro.pt',
    'styleclip_bobcut': 'resources/styleclip/bobcut.pt',
    'styleclip_curly_hair': 'resources/styleclip/curly_hair.pt',
    'styleclip_bowlcut': 'resources/styleclip/bowlcut.pt',
    'styleclip_mohawk': 'resources/styleclip/mohawk.pt',
}

resource_sources = {
    'ffhq_styleganv2'     : 'https://dcgi.fel.cvut.cz/~futscdav/chunkmogrify/ffhq.pkl',
    'dlib_align_data'     : 'https://dcgi.fel.cvut.cz/~futscdav/chunkmogrify/shape_predictor_68_face_landmarks.dat',
    'styleclip_afro'      : 'https://dcgi.fel.cvut.cz/~futscdav/chunkmogrify/styleclip/afro.pt',
    'styleclip_bobcut'    : 'https://dcgi.fel.cvut.cz/~futscdav/chunkmogrify/styleclip/bobcut.pt',
    'styleclip_curly_hair': 'https://dcgi.fel.cvut.cz/~futscdav/chunkmogrify/styleclip/curly_hair.pt',
    'styleclip_bowlcut'   : 'https://dcgi.fel.cvut.cz/~futscdav/chunkmogrify/styleclip/bowlcut.pt',
    'styleclip_mohawk'    : 'https://dcgi.fel.cvut.cz/~futscdav/chunkmogrify/styleclip/mohawk.pt',
}

def download_url(url, store_at):
    dir = os.path.dirname(store_at)
    make_dirs_if_not_exists(dir)

    # get request; fail before creating the file so that an HTTP error does not
    # leave behind a file that check_and_download_all() would treat as present
    response = requests.get(url, stream=True)
    response.raise_for_status()
    # content-length is optional (e.g. chunked responses); 0 disables the progress bar
    total_size = int(response.headers.get('content-length', 0))
    downloaded = 0
    # open in binary mode
    with open(store_at, "wb") as file:
        for data in response.iter_content(chunk_size=8192):
            downloaded += len(data)
            file.write(data)
            if total_size > 0:
                done = int(50 * downloaded / total_size)
                sys.stdout.write(f"\r[{'=' * done}{' ' * (50-done)}] {100*downloaded/total_size:.02f}%")
                sys.stdout.flush()
    print()

def check_and_download_all():
    for resource, storage in resource_list.items():
        if not os.path.exists(storage):
            print(f"{resource} not found at {storage}, downloading...")
            download_url(resource_sources[resource], storage)

--------------------------------------------------------------------------------
/resources/iconS.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/resources/iconS.png
--------------------------------------------------------------------------------
/s_presets.py:
--------------------------------------------------------------------------------
#
# Author: David Futschik
# Provided as part of the Chunkmogrify project, 2021.
#

def limits():
    # Generated from 18 layer stylegan (x1024)
    # Was done on the fly previously, but this is easier for UI.
    # If the architecture changes, just run any input up to S space and fill in the input sizes.
    num_s = [512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 256, 256, 256, 128, 128, 128, 64, 64, 64, 32, 32]
    return {
        'layer': len(num_s),
        'channel': num_s
    }

def known_presets():
    return {
        'gaze': (9, 409),
        'smile': (6, 259),
        'eyebrows_1': (8, 28),
        'eyebrows_2': (12, 455),
        'eyebrows_3': (6, 179),
        'hair_color': (12, 266),
        'fringe': (6, 285),
        'lipstick': (15, 45),
        'eye_makeup': (12, 414),
        'eye_roll': (14, 239),
        'asian_eyes': (9, 376),
        'gray_hair': (6, 364),
        'eye_size': (12, 110),
        'goatee': (9, 421),
        'fat': (6, 104),
        'gender': (6, 128),
        'chin': (6, 131),
        'double_chin': (6, 144),
        'sideburns': (6, 234),
        'forehead_hair': (6, 322),
        'curly_hair': (6, 364),
        'nose_up_down': (9, 86),
        'eye_wide': (9, 63),
        'gender_2': (9, 6),
        'demon_eyes': (14, 319),
        'sunken_eyes': (14, 380),
        'pupil_1': (14, 414),
        'pupil_2': (14, 419),
    }
--------------------------------------------------------------------------------
/sample_image/input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/sample_image/input.jpg
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/extract_components.py:
--------------------------------------------------------------------------------
#
# Author: David Futschik
# Provided as part of the Chunkmogrify project, 2021.
#

import os
import torch
from synthesis import init_gan
from qtutil import make_dirs_if_not_exists

init_gan()
from synthesis import gan

keys = gan.g.synthesis._modules.keys()

export_to = 'components'
make_dirs_if_not_exists(export_to)
all_w = []
for key in keys:

    for subkey in gan.g.synthesis._modules[key]._modules.keys():
        w = gan.g.synthesis._modules[key]._modules[subkey].affine._parameters['weight']
        all_w.append(w)
        eigs = torch.svd(w).V.cpu()
        torch.save(eigs, os.path.join(export_to, f'{key}_{subkey}.pt'))

all_w = torch.cat(all_w, dim=0)
eigs = torch.svd(all_w).V.cpu()
torch.save(eigs, os.path.join(export_to, 'all.pt'))

--------------------------------------------------------------------------------
/scripts/idempotent_blend.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2
import scipy.sparse
from scipy.sparse.linalg import spsolve


def laplacian_matrix(n, m):
    """Generate the Poisson matrix.
    Refer to:
    https://en.wikipedia.org/wiki/Discrete_Poisson_equation
    Note: it's the transpose of the wiki's matrix
    """
    mat_D = scipy.sparse.lil_matrix((m, m))
    mat_D.setdiag(-1, -1)
    mat_D.setdiag(4)
    mat_D.setdiag(-1, 1)

    mat_A = scipy.sparse.block_diag([mat_D] * n).tolil()

    mat_A.setdiag(-1, 1*m)
    mat_A.setdiag(-1, -1*m)

    return mat_A

def poisson_edit(source, target, mask, offset):
    """The Poisson blending function.
    Refer to:
    Perez et al., "Poisson Image Editing", 2003.
    """

    # Assume:
    # target is not smaller than source.
    # shape of mask is same as shape of target.
    y_max, x_max = target.shape[:-1]
    y_min, x_min = 0, 0

    x_range = x_max - x_min
    y_range = y_max - y_min

    M = np.float32([[1,0,offset[0]],[0,1,offset[1]]])
    source = cv2.warpAffine(source, M, (x_range, y_range))

    mask = mask[y_min:y_max, x_min:x_max]
    mask[mask != 0] = 1

    mat_A = laplacian_matrix(y_range, x_range)

    # for \Delta g
    laplacian = mat_A.tocsc()

    # set the region outside the mask to identity
    for y in range(1, y_range - 1):
        for x in range(1, x_range - 1):
            if mask[y, x] == 0:
                k = x + y * x_range
                mat_A[k, k] = 1
                mat_A[k, k + 1] = 0
                mat_A[k, k - 1] = 0
                mat_A[k, k + x_range] = 0
                mat_A[k, k - x_range] = 0

    mat_A = mat_A.tocsc()

    mask_flat = mask.flatten()
    for channel in range(source.shape[2]):
        source_flat = source[y_min:y_max, x_min:x_max, channel].flatten()
        target_flat = target[y_min:y_max, x_min:x_max, channel].flatten()

        # inside the mask:
        # \Delta f = div v = \Delta g
        alpha = 1
        mat_b = laplacian.dot(source_flat)*alpha

        # outside the mask:
        # f = t
        mat_b[mask_flat==0] = target_flat[mask_flat==0]

        x = spsolve(mat_A, mat_b)
        x = x.reshape((y_range, x_range))
        x[x > 255] = 255
        x[x < 0] = 0
        x = x.astype('uint8')

        target[y_min:y_max, x_min:x_max, channel] = x

    return target
--------------------------------------------------------------------------------
/setup_cpp_ext.py:
--------------------------------------------------------------------------------
#
# Author: David Futschik
# Provided as part of the Chunkmogrify project, 2021.
#

import platform
import setuptools
import numpy as np
from setuptools import sandbox

platform_specific_flags = []
if platform.system() == "Windows":
    # MSVC has no /std:c++11 switch; /std:c++14 is the oldest standard it accepts
    platform_specific_flags += ["/permissive-", "/Ox", "/std:c++14"]
else:
    platform_specific_flags += ["-O3", "--std=c++11"]

ext_modules = [
    setuptools.Extension('_C_canvas',
        sources=['extensions/canvas_to_masks.cpp'],
        include_dirs=[np.get_include()],
        extra_compile_args=platform_specific_flags,
        language='c++'),
    setuptools.Extension('_C_heatmap',
        sources=['extensions/heatmap.cpp'],
        include_dirs=[np.get_include()],
        extra_compile_args=platform_specific_flags,
        language='c++')
]

def checked_build(force=False):
    def do_build():
        sandbox.run_setup('setup_cpp_ext.py', ['build_ext', '--inplace'])
    try:
        import _C_canvas
        import _C_heatmap
        if force: do_build()
    except ImportError:
        do_build()

if __name__ == "__main__":
    setuptools.setup(
        ext_modules=ext_modules
    )
--------------------------------------------------------------------------------
/styleclip_mapper.py:
--------------------------------------------------------------------------------
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Module

def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    rest_dim = [1] * (input.ndim - bias.ndim - 1)
    if input.ndim == 3:
        return (
            F.leaky_relu(
                input + bias.view(1, *rest_dim, bias.shape[0]), negative_slope=negative_slope
            )
            * scale
        )
    else:
        return (
            F.leaky_relu(
                input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
            )
            * scale
        )

class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)


class EqualLinear(nn.Module):
    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, self.bias * self.lr_mul)

        else:
            out = F.linear(
                input, self.weight * self.scale, bias=self.bias * self.lr_mul
            )

        return out

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
        )

class Mapper(Module):

    def __init__(self, opts):
        super(Mapper, self).__init__()

        self.opts = opts
        layers = [PixelNorm()]

        for i in range(4):
            layers.append(
                EqualLinear(
                    512, 512, lr_mul=0.01, activation='fused_lrelu'
                )
            )

        self.mapping = nn.Sequential(*layers)


    def forward(self, x):
        x = self.mapping(x)
        return x


class SingleMapper(Module):

    def __init__(self, opts):
        super(SingleMapper, self).__init__()

        self.opts = opts

        self.mapping = Mapper(opts)

    def forward(self, x):
        out = self.mapping(x)
        return out


class LevelsMapper(Module):

    def __init__(self, opts):
        super(LevelsMapper, self).__init__()

        self.opts = opts

        if not opts.no_coarse_mapper:
            # "course" (sic) matches the parameter names stored in the pretrained
            # StyleCLIP checkpoints; renaming it breaks load_state_dict(strict=True).
            self.course_mapping = Mapper(opts)
        if not opts.no_medium_mapper:
            self.medium_mapping = Mapper(opts)
        if not opts.no_fine_mapper:
            self.fine_mapping = Mapper(opts)

    def forward(self, x):
        x_coarse = x[:, :4, :]
        x_medium = x[:, 4:8, :]
        x_fine = x[:, 8:, :]

        if not self.opts.no_coarse_mapper:
            x_coarse = self.course_mapping(x_coarse)
        else:
            x_coarse = torch.zeros_like(x_coarse)
        if not self.opts.no_medium_mapper:
            x_medium = self.medium_mapping(x_medium)
        else:
            x_medium = torch.zeros_like(x_medium)
        if not self.opts.no_fine_mapper:
            x_fine = self.fine_mapping(x_fine)
        else:
            x_fine = torch.zeros_like(x_fine)


        out = torch.cat([x_coarse, x_medium, x_fine], dim=1)

        return out

def get_keys(d, name):
    if 'state_dict' in d:
        d = d['state_dict']
    d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
    return d_filt


class StyleCLIPMapper(nn.Module):

    def __init__(self, opts):
        super().__init__()
        self.opts = opts
        # Define architecture
        self.mapper = self.set_mapper()
        # Load weights if needed
        self.load_weights()

    def set_mapper(self):
        if self.opts.mapper_type == 'SingleMapper':
            mapper = SingleMapper(self.opts)
        elif self.opts.mapper_type == 'LevelsMapper':
            mapper = LevelsMapper(self.opts)
        else:
            raise Exception('{} is not a valid mapper'.format(self.opts.mapper_type))
        return mapper

    def load_weights(self):
        if self.opts.checkpoint_path is not None:
            print('Loading from checkpoint: {}'.format(self.opts.checkpoint_path))
            ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
            self.mapper.load_state_dict(get_keys(ckpt, 'mapper'), strict=True)

--------------------------------------------------------------------------------
/styleclip_presets.py:
--------------------------------------------------------------------------------
#
# Author: David Futschik
# Provided as part of the Chunkmogrify project, 2021.
#

from resources import resource_list


def pretrained_models():
    return {
        'afro': resource_list['styleclip_afro'],
        'bobcut': resource_list['styleclip_bobcut'],
        'mohawk': resource_list['styleclip_mohawk'],
        'bowlcut': resource_list['styleclip_bowlcut'],
        'curly_hair': resource_list['styleclip_curly_hair'],
    }
--------------------------------------------------------------------------------
/stylegan_legacy.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import click
import pickle
import re
import copy
import numpy as np
import torch
import dnnlib
from torch_utils import misc

#----------------------------------------------------------------------------

def load_network_pkl(f, force_fp16=False):
    data = _LegacyUnpickler(f).load()

    # Legacy TensorFlow pickle => convert.
    if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
        tf_G, tf_D, tf_Gs = data
        G = convert_tf_generator(tf_G)
        D = convert_tf_discriminator(tf_D)
        G_ema = convert_tf_generator(tf_Gs)
        data = dict(G=G, D=D, G_ema=G_ema)

    # Add missing fields.
    if 'training_set_kwargs' not in data:
        data['training_set_kwargs'] = None
    if 'augment_pipe' not in data:
        data['augment_pipe'] = None

    # Validate contents.
    assert isinstance(data['G'], torch.nn.Module)
    assert isinstance(data['D'], torch.nn.Module)
    assert isinstance(data['G_ema'], torch.nn.Module)
    assert isinstance(data['training_set_kwargs'], (dict, type(None)))
    assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))

    # Force FP16.
    if force_fp16:
        for key in ['G', 'D', 'G_ema']:
            old = data[key]
            kwargs = copy.deepcopy(old.init_kwargs)
            if key.startswith('G'):
                kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {}))
                kwargs.synthesis_kwargs.num_fp16_res = 4
                kwargs.synthesis_kwargs.conv_clamp = 256
            if key.startswith('D'):
                kwargs.num_fp16_res = 4
                kwargs.conv_clamp = 256
            if kwargs != old.init_kwargs:
                new = type(old)(**kwargs).eval().requires_grad_(False)
                misc.copy_params_and_buffers(old, new, require_all=True)
                data[key] = new
    return data

#----------------------------------------------------------------------------

class _TFNetworkStub(dnnlib.EasyDict):
    pass

class _LegacyUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if module == 'dnnlib.tflib.network' and name == 'Network':
            return _TFNetworkStub
        return super().find_class(module, name)

#----------------------------------------------------------------------------

def _collect_tf_params(tf_net):
    # pylint: disable=protected-access
    tf_params = dict()
    def recurse(prefix, tf_net):
        for name, value in tf_net.variables:
            tf_params[prefix + name] = value
        for name, comp in tf_net.components.items():
            recurse(prefix + name + '/', comp)
    recurse('', tf_net)
    return tf_params

#----------------------------------------------------------------------------

def _populate_module_params(module, *patterns):
    for name, tensor in misc.named_params_and_buffers(module):
        found = False
        value = None
        for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
            match = re.fullmatch(pattern, name)
            if match:
                found = True
                if value_fn is not None:
                    value = value_fn(*match.groups())
                break
        try:
            assert found
            if value is not None:
                tensor.copy_(torch.from_numpy(np.array(value)))
        except:
            print(name, list(tensor.shape))
            raise

#----------------------------------------------------------------------------

def convert_tf_generator(tf_G):
    if tf_G.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs.
    tf_kwargs = tf_G.static_kwargs
    known_kwargs = set()
    def kwarg(tf_name, default=None, none=None):
        known_kwargs.add(tf_name)
        val = tf_kwargs.get(tf_name, default)
        return val if val is not None else none

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        z_dim = kwarg('latent_size', 512),
        c_dim = kwarg('label_size', 0),
        w_dim = kwarg('dlatent_size', 512),
        img_resolution = kwarg('resolution', 1024),
        img_channels = kwarg('num_channels', 3),
        mapping_kwargs = dnnlib.EasyDict(
            num_layers = kwarg('mapping_layers', 8),
            embed_features = kwarg('label_fmaps', None),
            layer_features = kwarg('mapping_fmaps', None),
            activation = kwarg('mapping_nonlinearity', 'lrelu'),
            lr_multiplier = kwarg('mapping_lrmul', 0.01),
            w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
        ),
        synthesis_kwargs = dnnlib.EasyDict(
            channel_base = kwarg('fmap_base', 16384) * 2,
            channel_max = kwarg('fmap_max', 512),
            num_fp16_res = kwarg('num_fp16_res', 0),
            conv_clamp = kwarg('conv_clamp', None),
            architecture = kwarg('architecture', 'skip'),
            resample_filter = kwarg('resample_kernel', [1,3,3,1]),
            use_noise = kwarg('use_noise', True),
            activation = kwarg('nonlinearity', 'lrelu'),
        ),
    )

    # Check for unknown kwargs.
    kwarg('truncation_psi')
    kwarg('truncation_cutoff')
    kwarg('style_mixing_prob')
    kwarg('structure')
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_G)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
            # Fixed: the EasyDict key is `synthesis_kwargs`; `kwargs.synthesis.kwargs` raises AttributeError.
            kwargs.synthesis_kwargs.architecture = 'orig'
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    from training import networks
    G = networks.Generator(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    _populate_module_params(G,
        r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
        r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
        r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
        r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
        r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
        r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
        r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
        r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
        r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
        r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
        r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
        r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
        r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
        r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
        r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'.*\.resample_filter', None,
    )
    return G

#----------------------------------------------------------------------------

def convert_tf_discriminator(tf_D):
    if tf_D.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs.
    tf_kwargs = tf_D.static_kwargs
    known_kwargs = set()
    def kwarg(tf_name, default=None):
        known_kwargs.add(tf_name)
        return tf_kwargs.get(tf_name, default)

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        c_dim = kwarg('label_size', 0),
        img_resolution = kwarg('resolution', 1024),
        img_channels = kwarg('num_channels', 3),
        architecture = kwarg('architecture', 'resnet'),
        channel_base = kwarg('fmap_base', 16384) * 2,
        channel_max = kwarg('fmap_max', 512),
        num_fp16_res = kwarg('num_fp16_res', 0),
        conv_clamp = kwarg('conv_clamp', None),
        cmap_dim = kwarg('mapping_fmaps', None),
        block_kwargs = dnnlib.EasyDict(
            activation = kwarg('nonlinearity', 'lrelu'),
            resample_filter = kwarg('resample_kernel', [1,3,3,1]),
            freeze_layers = kwarg('freeze_layers', 0),
        ),
        mapping_kwargs = dnnlib.EasyDict(
            num_layers = kwarg('mapping_layers', 0),
            embed_features = kwarg('mapping_fmaps', None),
            layer_features = kwarg('mapping_fmaps', None),
            activation = kwarg('nonlinearity', 'lrelu'),
            lr_multiplier = kwarg('mapping_lrmul', 0.1),
        ),
        epilogue_kwargs = dnnlib.EasyDict(
            mbstd_group_size = kwarg('mbstd_group_size', None),
            mbstd_num_channels = kwarg('mbstd_num_features', 1),
            activation = kwarg('nonlinearity', 'lrelu'),
        ),
    )

    # Check for unknown kwargs.
    kwarg('structure')
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_D)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
            kwargs.architecture = 'orig'
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    from training import networks
    D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    _populate_module_params(D,
        r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
        r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
        r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
        r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
        r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
        r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
        r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
        r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
        r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
        r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
        r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
        r'.*\.resample_filter', None,
    )
    return D

#----------------------------------------------------------------------------

@click.command()
@click.option('--source', help='Input pickle', required=True, metavar='PATH')
@click.option('--dest', help='Output pickle', required=True, metavar='PATH')
@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
def convert_network_pickle(source, dest, force_fp16):
    """Convert legacy network pickle into the native PyTorch format.

    The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
    It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.

    Example:

    \b
    python stylegan_legacy.py \\
        --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
        --dest=stylegan2-cat-config-f.pkl
    """
    print(f'Loading "{source}"...')
    with dnnlib.util.open_url(source) as f:
        data = load_network_pkl(f, force_fp16=force_fp16)
    print(f'Saving "{dest}"...')
    with open(dest, 'wb') as f:
        pickle.dump(data, f)
    print('Done.')

#----------------------------------------------------------------------------

if __name__ == "__main__":
    convert_network_pickle() # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------

--------------------------------------------------------------------------------
/stylegan_tune.py:
--------------------------------------------------------------------------------
#
# Author: David Futschik
# Provided as part of the Chunkmogrify project, 2021.
#

# Straightforward Pivotal Tuning Implementation.
import torch
from lpips import LPIPS
from stylegan_project import dilation

# Changes the model itself. Make a clone if that is not desired.
# Note that mask does not change dynamically.
class PivotalTuning:
    def __init__(self, model, device, pivot_w, target_image, mask=None, alpha=1,
                 lambda_l2=1, lr=3e-4):
        self.model = model
        self.device = device
        self.w = pivot_w
        self.target = target_image
        self.mask = mask

        self.alpha = alpha
        self.lambda_l2 = lambda_l2
        self.initial_lr = lr
        self.optimizer = None
        self.current_iter = 0

    def step(self):
        if self.optimizer is None:
            self._init_opt()

        self.optimizer.zero_grad()
        current_image = self.model(self.w)
        loss = self._calc_loss(current_image)
        loss.backward()
        self.optimizer.step()

        self.current_iter += 1
        return self.model

    def iters_done(self):
        return self.current_iter

    def _calc_loss(self, x):
        if self.mask is not None:
            expanded_mask = dilation(self.mask, torch.ones(7, 7, device=self.device))
            x = x * expanded_mask + self.target * (1 - expanded_mask)

        l2_loss = torch.nn.MSELoss()(x, self.target)
        lp_loss = self.lpips_loss(x, self.target)
        sphere_loss = 0 # Seems to be disabled in PTI anyway.

        # lambda_l2 weights the pixel-wise term (it was previously accepted but ignored); LPIPS keeps unit weight.
        loss = self.lambda_l2 * l2_loss + lp_loss
        return loss

    def _init_opt(self):
        self.model.requires_grad_(True)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.initial_lr)
        self.lpips_loss = LPIPS().to(self.device)

--------------------------------------------------------------------------------
/torch_utils/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

# empty

--------------------------------------------------------------------------------
/torch_utils/custom_ops.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import glob
import torch
import torch.utils.cpp_extension
import importlib
import hashlib
import shutil
from pathlib import Path

from torch.utils.file_baton import FileBaton

#----------------------------------------------------------------------------
# Global options.

verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'

#----------------------------------------------------------------------------
# Internal helper funcs.

def _find_compiler_bindir():
    patterns = [
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
    ]
    for pattern in patterns:
        matches = sorted(glob.glob(pattern))
        if len(matches):
            return matches[-1]
    return None

#----------------------------------------------------------------------------
# Main entry point for compiling and loading C++/CUDA plugins.

_cached_plugins = dict()

def get_plugin(module_name, sources, **build_kwargs):
    assert verbosity in ['none', 'brief', 'full']

    # Already cached?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Print status.
    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)

    try: # pylint: disable=too-many-nested-blocks
        # Make sure we can find the necessary compiler binaries.
        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
            os.environ['PATH'] += ';' + compiler_bindir

        # Compile and load.
        verbose_build = (verbosity == 'full')

        # Incremental build md5sum trickery. Copies all the input source files
        # into a cached build directory under a combined md5 digest of the input
        # source files. Copying is done only if the combined digest has changed.
        # This keeps input file timestamps and filenames the same as in previous
        # extension builds, allowing for fast incremental rebuilds.
        #
        # This optimization is done only in case all the source files reside in
        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
        # environment variable is set (we take this as a signal that the user
        # actually cares about this.)
        source_dirs_set = set(os.path.dirname(source) for source in sources)
        if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
            all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))

            # Compute a combined hash digest for all source files in the same
            # custom op directory (usually .cu, .cpp, .py and .h files).
            hash_md5 = hashlib.md5()
            for src in all_source_files:
                with open(src, 'rb') as f:
                    hash_md5.update(f.read())
            build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
            digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())

            if not os.path.isdir(digest_build_dir):
                os.makedirs(digest_build_dir, exist_ok=True)
                baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
                if baton.try_acquire():
                    try:
                        for src in all_source_files:
                            shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
                    finally:
                        baton.release()
                else:
                    # Someone else is copying source files under the digest dir,
                    # wait until done and continue.
                    baton.wait()
            digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
            torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
                verbose=verbose_build, sources=digest_sources, **build_kwargs)
        else:
            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
        module = importlib.import_module(module_name)

    except:
        if verbosity == 'brief':
            print('Failed!')
        raise

    # Print status and add to cache.
    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')
    _cached_plugins[module_name] = module
    return module

#----------------------------------------------------------------------------

--------------------------------------------------------------------------------
/torch_utils/misc.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import re
import contextlib
import numpy as np
import torch
import warnings
import dnnlib

#----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.

_constant_cache = dict()

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device('cpu')
    if memory_format is None:
        memory_format = torch.contiguous_format

    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor

#----------------------------------------------------------------------------
# Replace NaN/Inf with specified numerical values.

try:
    nan_to_num = torch.nan_to_num # 1.8.0a0
except AttributeError:
    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
        assert isinstance(input, torch.Tensor)
        if posinf is None:
            posinf = torch.finfo(input.dtype).max
        if neginf is None:
            neginf = torch.finfo(input.dtype).min
        assert nan == 0
        return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)

#----------------------------------------------------------------------------
# Symbolic assert.

try:
    symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
except AttributeError:
    symbolic_assert = torch.Assert # 1.7.0

#----------------------------------------------------------------------------
# Context manager to suppress known warnings in torch.jit.trace().

class suppress_tracer_warnings(warnings.catch_warnings):
    def __enter__(self):
        super().__enter__()
        warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
        return self

#----------------------------------------------------------------------------
# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().

def assert_shape(tensor, ref_shape):
    if tensor.ndim != len(ref_shape):
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            pass
        elif isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
        elif isinstance(size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
        elif size != ref_size:
            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')

#----------------------------------------------------------------------------
# Function decorator that calls torch.autograd.profiler.record_function().

def profiled_function(fn):
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)
    decorator.__name__ = fn.__name__
    return decorator

#----------------------------------------------------------------------------
# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.

class InfiniteSampler(torch.utils.data.Sampler):
    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size

    def __iter__(self):
        order = np.arange(len(self.dataset))
        rnd = None
        window = 0
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            window = int(np.rint(order.size * self.window_size))

        idx = 0
        while True:
            i = idx % order.size
            if idx % self.num_replicas == self.rank:
                yield order[i]
                if window >= 2:
                    j = (i - rnd.randint(window)) % order.size
                    order[i], order[j] = order[j], order[i]
            idx += 1

#----------------------------------------------------------------------------
# Utilities for operating with torch.nn.Module parameters and buffers.

def params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.parameters()) + list(module.buffers())

def named_params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.named_parameters()) + list(module.named_buffers())

def copy_params_and_buffers(src_module, dst_module, require_all=False, rand_init_extra_channels=False):
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
    for name, tensor in named_params_and_buffers(dst_module):
        assert (name in src_tensors) or (not require_all)
        if name in src_tensors:
            # Only implement rand_init_extra_channels for torgb/fromrgb layers for now, to catch future mistakes.
            # Note: the partial copies in these two branches are commented out, so such tensors
            # currently keep their existing values in dst_module.
            if 'torgb' in name and rand_init_extra_channels and tensor.shape[0] != src_tensors[name].shape[0]:
                print(f"<{name}>: Initializing first {src_tensors[name].shape[0]} channels with existing weights")
                # tensor[:src_tensors[name].shape[0], ...].copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
            elif 'fromrgb' in name and rand_init_extra_channels and tensor.size() != src_tensors[name].size():
                assert tensor.ndim > 1 and tensor.shape[1] != src_tensors[name].shape[1], (tensor.shape, src_tensors[name].shape)
                print(f"<{name}>: Initializing first {src_tensors[name].shape[1]} channels with existing weights")
                # tensor[:, :src_tensors[name].shape[1], ...].copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
            elif tensor.size() != src_tensors[name].size():
                print(f"<{name}>: Tensors differ in size, not initializing")
                print(f"{tensor.shape} vs {src_tensors[name].shape}")
            else:
                # print(name, tensor.shape, src_tensors[name].shape)
                tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)

#----------------------------------------------------------------------------
# Context manager for easily enabling/disabling DistributedDataParallel
# synchronization.

@contextlib.contextmanager
def ddp_sync(module, sync):
    assert isinstance(module, torch.nn.Module)
    if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
        yield
    else:
        with module.no_sync():
            yield

#----------------------------------------------------------------------------
# Check DistributedDataParallel consistency across processes.

def check_ddp_consistency(module, ignore_regex=None):
    assert isinstance(module, torch.nn.Module)
    for name, tensor in named_params_and_buffers(module):
        fullname = type(module).__name__ + '.' + name
        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
            continue
        tensor = tensor.detach()
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname

#----------------------------------------------------------------------------
# Print summary table of module hierarchy.

def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks.
    entries = []
    nesting = [0]
    def pre_hook(_mod, _inputs):
        nesting[0] += 1
    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module.
    outputs = module(*inputs)
    for hook in hooks:
        hook.remove()

    # Identify unique outputs, parameters, and buffers.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}

    # Filter out redundant entries.
    if skip_redundant:
        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]

    # Construct table.
242 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 243 | rows += [['---'] * len(rows[0])] 244 | param_total = 0 245 | buffer_total = 0 246 | submodule_names = {mod: name for name, mod in module.named_modules()} 247 | for e in entries: 248 | name = '' if e.mod is module else submodule_names[e.mod] 249 | param_size = sum(t.numel() for t in e.unique_params) 250 | buffer_size = sum(t.numel() for t in e.unique_buffers) 251 | output_shapes = [str(list(t.shape)) for t in e.outputs] 252 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 253 | rows += [[ 254 | name + (':0' if len(e.outputs) >= 2 else ''), 255 | str(param_size) if param_size else '-', 256 | str(buffer_size) if buffer_size else '-', 257 | (output_shapes + ['-'])[0], 258 | (output_dtypes + ['-'])[0], 259 | ]] 260 | for idx in range(1, len(e.outputs)): 261 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 262 | param_total += param_size 263 | buffer_total += buffer_size 264 | rows += [['---'] * len(rows[0])] 265 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 266 | 267 | # Print table. 268 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 269 | print() 270 | for row in rows: 271 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) 272 | print() 273 | return outputs 274 | 275 | #---------------------------------------------------------------------------- 276 | -------------------------------------------------------------------------------- /torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto.
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments. 
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same shape, dtype, and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ?
xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "bias_act_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 9 | #include 10 | #include "bias_act.h" 11 | 12 | //------------------------------------------------------------------------ 13 | // Helpers. 14 | 15 | template struct InternalType; 16 | template <> struct InternalType { typedef double scalar_t; }; 17 | template <> struct InternalType { typedef float scalar_t; }; 18 | template <> struct InternalType { typedef float scalar_t; }; 19 | 20 | //------------------------------------------------------------------------ 21 | // CUDA kernel. 22 | 23 | template 24 | __global__ void bias_act_kernel(bias_act_kernel_params p) 25 | { 26 | typedef typename InternalType::scalar_t scalar_t; 27 | int G = p.grad; 28 | scalar_t alpha = (scalar_t)p.alpha; 29 | scalar_t gain = (scalar_t)p.gain; 30 | scalar_t clamp = (scalar_t)p.clamp; 31 | scalar_t one = (scalar_t)1; 32 | scalar_t two = (scalar_t)2; 33 | scalar_t expRange = (scalar_t)80; 34 | scalar_t halfExpRange = (scalar_t)40; 35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 37 | 38 | // Loop over elements. 39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 41 | { 42 | // Load. 43 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 48 | scalar_t yy = (gain != 0) ? yref / gain : 0; 49 | scalar_t y = 0; 50 | 51 | // Apply bias. 52 | ((G == 0) ? x : xref) += b; 53 | 54 | // linear 55 | if (A == 1) 56 | { 57 | if (G == 0) y = x; 58 | if (G == 1) y = x; 59 | } 60 | 61 | // relu 62 | if (A == 2) 63 | { 64 | if (G == 0) y = (x > 0) ? x : 0; 65 | if (G == 1) y = (yy > 0) ? 
x : 0; 66 | } 67 | 68 | // lrelu 69 | if (A == 3) 70 | { 71 | if (G == 0) y = (x > 0) ? x : x * alpha; 72 | if (G == 1) y = (yy > 0) ? x : x * alpha; 73 | } 74 | 75 | // tanh 76 | if (A == 4) 77 | { 78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } 79 | if (G == 1) y = x * (one - yy * yy); 80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 81 | } 82 | 83 | // sigmoid 84 | if (A == 5) 85 | { 86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); 87 | if (G == 1) y = x * yy * (one - yy); 88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 89 | } 90 | 91 | // elu 92 | if (A == 6) 93 | { 94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 97 | } 98 | 99 | // selu 100 | if (A == 7) 101 | { 102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); 104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); 105 | } 106 | 107 | // softplus 108 | if (A == 8) 109 | { 110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); 111 | if (G == 1) y = x * (one - exp(-yy)); 112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 113 | } 114 | 115 | // swish 116 | if (A == 9) 117 | { 118 | if (G == 0) 119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 120 | else 121 | { 122 | scalar_t c = exp(xref); 123 | scalar_t d = c + one; 124 | if (G == 1) 125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 126 | else 127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 129 | } 130 | } 131 | 132 | // Apply gain. 133 | y *= gain * dy; 134 | 135 | // Clamp. 
136 | if (clamp >= 0) 137 | { 138 | if (G == 0) 139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; 140 | else 141 | y = (yref > -clamp & yref < clamp) ? y : 0; 142 | } 143 | 144 | // Store. 145 | ((T*)p.y)[xi] = (T)y; 146 | } 147 | } 148 | 149 | //------------------------------------------------------------------------ 150 | // CUDA kernel selection. 151 | 152 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p) 153 | { 154 | if (p.act == 1) return (void*)bias_act_kernel; 155 | if (p.act == 2) return (void*)bias_act_kernel; 156 | if (p.act == 3) return (void*)bias_act_kernel; 157 | if (p.act == 4) return (void*)bias_act_kernel; 158 | if (p.act == 5) return (void*)bias_act_kernel; 159 | if (p.act == 6) return (void*)bias_act_kernel; 160 | if (p.act == 7) return (void*)bias_act_kernel; 161 | if (p.act == 8) return (void*)bias_act_kernel; 162 | if (p.act == 9) return (void*)bias_act_kernel; 163 | return NULL; 164 | } 165 | 166 | //------------------------------------------------------------------------ 167 | // Template specializations. 168 | 169 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 170 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 171 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 172 | 173 | //------------------------------------------------------------------------ 174 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. 
Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom PyTorch ops for efficient bias and activation.""" 10 | 11 | import os 12 | import sys 13 | import warnings 14 | import numpy as np 15 | import torch 16 | import dnnlib 17 | import traceback 18 | 19 | from .. import custom_ops 20 | from .. 
import misc 21 | 22 | #---------------------------------------------------------------------------- 23 | 24 | activation_funcs = { 25 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), 26 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), 27 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), 28 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), 29 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), 30 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), 31 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), 32 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), 33 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), 34 | } 35 | 36 | #---------------------------------------------------------------------------- 37 | 38 | _inited = False 39 | _plugin = None 40 | _null_tensor = torch.empty([0]) 41 | 42 | def _init(): 43 | global _inited, _plugin 44 | if not _inited: 45 | _inited = True 46 | sources = ['bias_act.cpp', 'bias_act.cu'] 47 | sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] 48 | try: 49 | _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) 50 | except Exception: 51 | warnings.warn('Failed to build
CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) 52 | return _plugin is not None 53 | 54 | #---------------------------------------------------------------------------- 55 | 56 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): 57 | r"""Fused bias and activation function. 58 | 59 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`, 60 | and scales the result by `gain`. Each of the steps is optional. In most cases, 61 | the fused op is considerably more efficient than performing the same calculation 62 | using standard PyTorch ops. It supports first and second order gradients, 63 | but not third order gradients. 64 | 65 | Args: 66 | x: Input activation tensor. Can be of any shape. 67 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 68 | as `x`. The shape must be known, and it must match the dimension of `x` 69 | corresponding to `dim`. 70 | dim: The dimension in `x` corresponding to the elements of `b`. 71 | The value of `dim` is ignored if `b` is not specified. 72 | act: Name of the activation function to evaluate, or `"linear"` to disable. 73 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. 74 | See `activation_funcs` for a full list. `None` is not allowed. 75 | alpha: Shape parameter for the activation function, or `None` to use the default. 76 | gain: Scaling factor for the output tensor, or `None` to use default. 77 | See `activation_funcs` for the default scaling of each activation function. 78 | If unsure, consider specifying 1. 79 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable 80 | the clamping (default). 81 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 82 | 83 | Returns: 84 | Tensor of the same shape and datatype as `x`. 
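A minimal pure-Python sketch of the four steps described above (add bias, activation, gain, clamp), for readers who want to check the semantics without PyTorch or a GPU. It is an illustration only and not part of this module: `bias_act_sketch` and `_SKETCH_ACTS` are hypothetical names, only a subset of `activation_funcs` is mirrored, and the input is assumed to be a 2-D list with the bias applied along `dim=1`.

```python
import math

# Hypothetical pure-Python table mirroring a subset of `activation_funcs`:
# name -> (function of (x, alpha), default alpha, default gain).
_SKETCH_ACTS = {
    'linear':  (lambda x, a: x, 0.0, 1.0),
    'relu':    (lambda x, a: x if x > 0 else 0.0, 0.0, math.sqrt(2)),
    'lrelu':   (lambda x, a: x if x > 0 else x * a, 0.2, math.sqrt(2)),
    'tanh':    (lambda x, a: math.tanh(x), 0.0, 1.0),
    'sigmoid': (lambda x, a: 1.0 / (1.0 + math.exp(-x)), 0.0, 1.0),
}

def bias_act_sketch(x, b=None, act='linear', alpha=None, gain=None, clamp=None):
    # x: 2-D list [N][C]; b: optional list of C biases, added along dim=1.
    func, def_alpha, def_gain = _SKETCH_ACTS[act]
    alpha = def_alpha if alpha is None else float(alpha)
    gain = def_gain if gain is None else float(gain)
    assert clamp is None or clamp >= 0
    out = []
    for row in x:
        if b is not None:
            assert len(b) == len(row)
            row = [v + bias for v, bias in zip(row, b)]       # 1. add bias
        row = [func(v, alpha) for v in row]                   # 2. activation
        if gain != 1:
            row = [v * gain for v in row]                     # 3. scale by gain
        if clamp is not None:
            row = [min(max(v, -clamp), clamp) for v in row]   # 4. clamp to [-clamp, +clamp]
        out.append(row)
    return out

# Usage: the bias turns [-1.0, 0.5] into [-0.5, 1.0]; lrelu with the default
# alpha=0.2 gives [-0.1, 1.0]; the default gain sqrt(2) then scales both values.
y = bias_act_sketch([[-1.0, 0.5]], b=[0.5, 0.5], act='lrelu')
print(y)
```

Note that, as in `_bias_act_ref`, the gain is applied before the clamp, so `clamp` bounds the final scaled output.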
85 | """ 86 | assert isinstance(x, torch.Tensor) 87 | assert impl in ['ref', 'cuda'] 88 | if impl == 'cuda' and x.device.type == 'cuda' and _init(): 89 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) 90 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) 91 | 92 | #---------------------------------------------------------------------------- 93 | 94 | @misc.profiled_function 95 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): 96 | """Slow reference implementation of `bias_act()` using standard PyTorch ops. 97 | """ 98 | assert isinstance(x, torch.Tensor) 99 | assert clamp is None or clamp >= 0 100 | spec = activation_funcs[act] 101 | alpha = float(alpha if alpha is not None else spec.def_alpha) 102 | gain = float(gain if gain is not None else spec.def_gain) 103 | clamp = float(clamp if clamp is not None else -1) 104 | 105 | # Add bias. 106 | if b is not None: 107 | assert isinstance(b, torch.Tensor) and b.ndim == 1 108 | assert 0 <= dim < x.ndim 109 | assert b.shape[0] == x.shape[dim] 110 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) 111 | 112 | # Evaluate activation function. 113 | alpha = float(alpha) 114 | x = spec.func(x, alpha=alpha) 115 | 116 | # Scale by gain. 117 | gain = float(gain) 118 | if gain != 1: 119 | x = x * gain 120 | 121 | # Clamp. 122 | if clamp >= 0: 123 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type 124 | return x 125 | 126 | #---------------------------------------------------------------------------- 127 | 128 | _bias_act_cuda_cache = dict() 129 | 130 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): 131 | """Fast CUDA implementation of `bias_act()` using custom ops. 132 | """ 133 | # Parse arguments.
134 | assert clamp is None or clamp >= 0 135 | spec = activation_funcs[act] 136 | alpha = float(alpha if alpha is not None else spec.def_alpha) 137 | gain = float(gain if gain is not None else spec.def_gain) 138 | clamp = float(clamp if clamp is not None else -1) 139 | 140 | # Lookup from cache. 141 | key = (dim, act, alpha, gain, clamp) 142 | if key in _bias_act_cuda_cache: 143 | return _bias_act_cuda_cache[key] 144 | 145 | # Forward op. 146 | class BiasActCuda(torch.autograd.Function): 147 | @staticmethod 148 | def forward(ctx, x, b): # pylint: disable=arguments-differ 149 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format 150 | x = x.contiguous(memory_format=ctx.memory_format) 151 | b = b.contiguous() if b is not None else _null_tensor 152 | y = x 153 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: 154 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) 155 | ctx.save_for_backward( 156 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 157 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 158 | y if 'y' in spec.ref else _null_tensor) 159 | return y 160 | 161 | @staticmethod 162 | def backward(ctx, dy): # pylint: disable=arguments-differ 163 | dy = dy.contiguous(memory_format=ctx.memory_format) 164 | x, b, y = ctx.saved_tensors 165 | dx = None 166 | db = None 167 | 168 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 169 | dx = dy 170 | if act != 'linear' or gain != 1 or clamp >= 0: 171 | dx = BiasActCudaGrad.apply(dy, x, b, y) 172 | 173 | if ctx.needs_input_grad[1]: 174 | db = dx.sum([i for i in range(dx.ndim) if i != dim]) 175 | 176 | return dx, db 177 | 178 | # Backward op. 
179 | class BiasActCudaGrad(torch.autograd.Function): 180 | @staticmethod 181 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ 182 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format 183 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) 184 | ctx.save_for_backward( 185 | dy if spec.has_2nd_grad else _null_tensor, 186 | x, b, y) 187 | return dx 188 | 189 | @staticmethod 190 | def backward(ctx, d_dx): # pylint: disable=arguments-differ 191 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format) 192 | dy, x, b, y = ctx.saved_tensors 193 | d_dy = None 194 | d_x = None 195 | d_b = None 196 | d_y = None 197 | 198 | if ctx.needs_input_grad[0]: 199 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) 200 | 201 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): 202 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) 203 | 204 | if spec.has_2nd_grad and ctx.needs_input_grad[2]: 205 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) 206 | 207 | return d_dy, d_x, d_b, d_y 208 | 209 | # Add to cache. 210 | _bias_act_cuda_cache[key] = BiasActCuda 211 | return BiasActCuda 212 | 213 | #---------------------------------------------------------------------------- 214 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.conv2d` that supports 10 | arbitrarily high order gradients with zero performance penalty.""" 11 | 12 | import warnings 13 | import contextlib 14 | import torch 15 | 16 | # pylint: disable=redefined-builtin 17 | # pylint: disable=arguments-differ 18 | # pylint: disable=protected-access 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | enabled = False # Enable the custom op by setting this to true. 23 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 24 | 25 | @contextlib.contextmanager 26 | def no_weight_gradients(): 27 | global weight_gradients_disabled 28 | old = weight_gradients_disabled 29 | weight_gradients_disabled = True 30 | yield 31 | weight_gradients_disabled = old 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 36 | if _should_use_custom_op(input): 37 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) 38 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 39 | 40 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): 41 | if _should_use_custom_op(input): 42 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) 43 | return torch.nn.functional.conv_transpose2d(input=input, 
weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) 44 | 45 | #---------------------------------------------------------------------------- 46 | 47 | def _should_use_custom_op(input): 48 | assert isinstance(input, torch.Tensor) 49 | if (not enabled) or (not torch.backends.cudnn.enabled): 50 | return False 51 | if input.device.type != 'cuda': 52 | return False 53 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.']): 54 | return True 55 | warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') 56 | return False 57 | 58 | def _tuple_of_ints(xs, ndim): 59 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 60 | assert len(xs) == ndim 61 | assert all(isinstance(x, int) for x in xs) 62 | return xs 63 | 64 | #---------------------------------------------------------------------------- 65 | 66 | _conv2d_gradfix_cache = dict() 67 | 68 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): 69 | # Parse arguments. 70 | ndim = 2 71 | weight_shape = tuple(weight_shape) 72 | stride = _tuple_of_ints(stride, ndim) 73 | padding = _tuple_of_ints(padding, ndim) 74 | output_padding = _tuple_of_ints(output_padding, ndim) 75 | dilation = _tuple_of_ints(dilation, ndim) 76 | 77 | # Lookup from cache. 78 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 79 | if key in _conv2d_gradfix_cache: 80 | return _conv2d_gradfix_cache[key] 81 | 82 | # Validate arguments. 
83 | assert groups >= 1 84 | assert len(weight_shape) == ndim + 2 85 | assert all(stride[i] >= 1 for i in range(ndim)) 86 | assert all(padding[i] >= 0 for i in range(ndim)) 87 | assert all(dilation[i] >= 1 for i in range(ndim)) 88 | if not transpose: 89 | assert all(output_padding[i] == 0 for i in range(ndim)) 90 | else: # transpose 91 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) 92 | 93 | # Helpers. 94 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) 95 | def calc_output_padding(input_shape, output_shape): 96 | if transpose: 97 | return [0, 0] 98 | return [ 99 | input_shape[i + 2] 100 | - (output_shape[i + 2] - 1) * stride[i] 101 | - (1 - 2 * padding[i]) 102 | - dilation[i] * (weight_shape[i + 2] - 1) 103 | for i in range(ndim) 104 | ] 105 | 106 | # Forward & backward. 107 | class Conv2d(torch.autograd.Function): 108 | @staticmethod 109 | def forward(ctx, input, weight, bias): 110 | assert weight.shape == weight_shape 111 | if not transpose: 112 | output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 113 | else: # transpose 114 | output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) 115 | ctx.save_for_backward(input, weight) 116 | return output 117 | 118 | @staticmethod 119 | def backward(ctx, grad_output): 120 | input, weight = ctx.saved_tensors 121 | grad_input = None 122 | grad_weight = None 123 | grad_bias = None 124 | 125 | if ctx.needs_input_grad[0]: 126 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 127 | grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None) 128 | assert grad_input.shape == input.shape 129 | 130 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 131 | grad_weight = Conv2dGradWeight.apply(grad_output,
input) 132 | assert grad_weight.shape == weight_shape 133 | 134 | if ctx.needs_input_grad[2]: 135 | grad_bias = grad_output.sum([0, 2, 3]) 136 | 137 | return grad_input, grad_weight, grad_bias 138 | 139 | # Gradient with respect to the weights. 140 | class Conv2dGradWeight(torch.autograd.Function): 141 | @staticmethod 142 | def forward(ctx, grad_output, input): 143 | op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') 144 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] 145 | grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) 146 | assert grad_weight.shape == weight_shape 147 | ctx.save_for_backward(grad_output, input) 148 | return grad_weight 149 | 150 | @staticmethod 151 | def backward(ctx, grad2_grad_weight): 152 | grad_output, input = ctx.saved_tensors 153 | grad2_grad_output = None 154 | grad2_input = None 155 | 156 | if ctx.needs_input_grad[0]: 157 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) 158 | assert grad2_grad_output.shape == grad_output.shape 159 | 160 | if ctx.needs_input_grad[1]: 161 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 162 | grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) 163 | assert grad2_input.shape == input.shape 164 | 165 | return grad2_grad_output, grad2_input 166 | 167 | _conv2d_gradfix_cache[key] = Conv2d 168 | return Conv2d 169 | 170 | #---------------------------------------------------------------------------- 171 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. 
All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """2D convolution with optional up/downsampling.""" 10 | 11 | import torch 12 | 13 | from .. import misc 14 | from . import conv2d_gradfix 15 | from . import upfirdn2d 16 | from .upfirdn2d import _parse_padding 17 | from .upfirdn2d import _get_filter_size 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def _get_weight_shape(w): 22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 23 | shape = [int(sz) for sz in w.shape] 24 | misc.assert_shape(w, shape) 25 | return shape 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 31 | """ 32 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 33 | 34 | # Flip weight if requested. 35 | if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 36 | w = w.flip([2, 3]) 37 | 38 | # Workaround performance pitfall in cuDNN 8.0.5, triggered when using 39 | # 1x1 kernel + memory_format=channels_last + less than 64 channels. 
40 | if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: 41 | if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: 42 | if out_channels <= 4 and groups == 1: 43 | in_shape = x.shape 44 | x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) 45 | x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) 46 | else: 47 | x = x.to(memory_format=torch.contiguous_format) 48 | w = w.to(memory_format=torch.contiguous_format) 49 | x = conv2d_gradfix.conv2d(x, w, groups=groups) 50 | return x.to(memory_format=torch.channels_last) 51 | 52 | # Otherwise => execute using conv2d_gradfix. 53 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 54 | return op(x, w, stride=stride, padding=padding, groups=groups) 55 | 56 | #---------------------------------------------------------------------------- 57 | 58 | @misc.profiled_function 59 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 60 | r"""2D convolution with optional up/downsampling. 61 | 62 | Padding is performed only once at the beginning, not between the operations. 63 | 64 | Args: 65 | x: Input tensor of shape 66 | `[batch_size, in_channels, in_height, in_width]`. 67 | w: Weight tensor of shape 68 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 69 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 70 | calling upfirdn2d.setup_filter(). None = identity (default). 71 | up: Integer upsampling factor (default: 1). 72 | down: Integer downsampling factor (default: 1). 73 | padding: Padding with respect to the upsampled image. Can be a single number 74 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 75 | (default: 0). 76 | groups: Split input channels into N groups (default: 1). 77 | flip_weight: False = convolution, True = correlation (default: True). 
78 | flip_filter: False = convolution, True = correlation (default: False). 79 | 80 | Returns: 81 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 82 | """ 83 | # Validate arguments. 84 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 85 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 86 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 87 | assert isinstance(up, int) and (up >= 1) 88 | assert isinstance(down, int) and (down >= 1) 89 | assert isinstance(groups, int) and (groups >= 1) 90 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 91 | fw, fh = _get_filter_size(f) 92 | px0, px1, py0, py1 = _parse_padding(padding) 93 | 94 | # Adjust padding to account for up/downsampling. 95 | if up > 1: 96 | px0 += (fw + up - 1) // 2 97 | px1 += (fw - up) // 2 98 | py0 += (fh + up - 1) // 2 99 | py1 += (fh - up) // 2 100 | if down > 1: 101 | px0 += (fw - down + 1) // 2 102 | px1 += (fw - down) // 2 103 | py0 += (fh - down + 1) // 2 104 | py1 += (fh - down) // 2 105 | 106 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 107 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 108 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 109 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 110 | return x 111 | 112 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 113 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 114 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 115 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 116 | return x 117 | 118 | # Fast path: downsampling only => use strided convolution. 
119 | if down > 1 and up == 1: 120 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 121 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 122 | return x 123 | 124 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 125 | if up > 1: 126 | if groups == 1: 127 | w = w.transpose(0, 1) 128 | else: 129 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 130 | w = w.transpose(1, 2) 131 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 132 | px0 -= kw - 1 133 | px1 -= kw - up 134 | py0 -= kh - 1 135 | py1 -= kh - up 136 | pxt = max(min(-px0, -px1), 0) 137 | pyt = max(min(-py0, -py1), 0) 138 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) 139 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) 140 | if down > 1: 141 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 142 | return x 143 | 144 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 145 | if up == 1 and down == 1: 146 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 147 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) 148 | 149 | # Fallback: Generic reference implementation. 
150 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 151 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 152 | if down > 1: 153 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 154 | return x 155 | 156 | #---------------------------------------------------------------------------- 157 | -------------------------------------------------------------------------------- /torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
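In `conv2d_gradfix.py`, `calc_output_padding()` chooses the `output_padding` that lets the transposed convolution used for the input gradient map `grad_output` back onto the original input shape. The sketch below replays that arithmetic for a single spatial axis in plain Python. It assumes PyTorch's documented output-size formulas for `conv2d` and `conv_transpose2d`; the helper names are illustrative and are not part of `torch_utils`.

```python
def conv_out_size(n, k, stride=1, padding=0, dilation=1):
    """Length of one spatial axis after conv2d (PyTorch's documented formula)."""
    return (n + 2 * padding - dilation * (k - 1) - 1) // stride + 1


def conv_transpose_out_size(n, k, stride=1, padding=0, dilation=1, output_padding=0):
    """Length of one spatial axis after conv_transpose2d (PyTorch's documented formula)."""
    return (n - 1) * stride - 2 * padding + dilation * (k - 1) + output_padding + 1


def calc_output_padding_1d(in_size, out_size, k, stride=1, padding=0, dilation=1):
    """Same expression as calc_output_padding() in conv2d_gradfix.py, for one axis."""
    return in_size - (out_size - 1) * stride - (1 - 2 * padding) - dilation * (k - 1)


if __name__ == "__main__":
    k, stride, padding = 3, 2, 1
    for n in range(5, 10):
        out = conv_out_size(n, k, stride, padding)
        p = calc_output_padding_1d(n, out, k, stride, padding)
        # The transposed convolution maps the gradient back to the input length.
        back = conv_transpose_out_size(out, k, stride, padding, output_padding=p)
        print(f"in={n} out={out} output_padding={p} recovered={back}")
```

With a strided convolution, several input lengths collapse onto the same output length. For example, lengths 5 and 6 both give 3 for `k=3, stride=2, padding=1`. `output_padding` is what tells them apart, and here it stays in `[0, stride)`, which satisfies the `transpose` branch of the argument checks.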
8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
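`_FusedMultiplyAdd.backward()` delegates to `_unbroadcast()`, which sums each incoming gradient over the axes along which its operand was broadcast in the forward pass. As an illustrative sketch, here is the same reduction ported to NumPy. NumPy stands in for PyTorch, and `unbroadcast` is our name for the port.

```python
import numpy as np


def unbroadcast(x, shape):
    """Reduce gradient `x` to `shape` by summing over broadcast axes (NumPy port of `_unbroadcast`)."""
    shape = tuple(shape)
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0
    # Axes to reduce: leading axes absent from `shape`, plus axes where `shape` has size 1.
    axes = tuple(i for i in range(x.ndim)
                 if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1))
    if axes:
        x = x.sum(axis=axes, keepdims=True)
    if extra_dims:
        # The leading axes now have size 1; fold them into the first remaining axis.
        x = x.reshape(-1, *x.shape[extra_dims + 1:])
    assert x.shape == shape
    return x


if __name__ == "__main__":
    # out = a * b + c with a: (4, 3), b: (3,), c: (1, 3); upstream gradient of ones.
    a = np.arange(12, dtype=np.float64).reshape(4, 3)
    b = np.array([1.0, 2.0, 3.0])
    dout = np.ones((4, 3))
    da = unbroadcast(dout * b, a.shape)  # no broadcasting, shape stays (4, 3)
    db = unbroadcast(dout * a, b.shape)  # column sums of a
    dc = unbroadcast(dout, (1, 3))       # row count, per column
    print(da.shape, db, dc)
```

The same rule covers mixed cases. For example, a `(2, 4, 3)` gradient reduced to `(4, 1)` is summed over axes 0 and 2.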
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import warnings 15 | import torch 16 | 17 | # pylint: disable=redefined-builtin 18 | # pylint: disable=arguments-differ 19 | # pylint: disable=protected-access 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | enabled = False # Enable the custom op by setting this to true. 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | def grid_sample(input, grid): 28 | if _should_use_custom_op(): 29 | return _GridSample2dForward.apply(input, grid) 30 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def _should_use_custom_op(): 35 | if not enabled: 36 | return False 37 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.']): 38 | return True 39 | warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. 
Falling back to torch.nn.functional.grid_sample().') 40 | return False 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | class _GridSample2dForward(torch.autograd.Function): 45 | @staticmethod 46 | def forward(ctx, input, grid): 47 | assert input.ndim == 4 48 | assert grid.ndim == 4 49 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 50 | ctx.save_for_backward(input, grid) 51 | return output 52 | 53 | @staticmethod 54 | def backward(ctx, grad_output): 55 | input, grid = ctx.saved_tensors 56 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 57 | return grad_input, grad_grid 58 | 59 | #---------------------------------------------------------------------------- 60 | 61 | class _GridSample2dBackward(torch.autograd.Function): 62 | @staticmethod 63 | def forward(ctx, grad_output, input, grid): 64 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 65 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 66 | ctx.save_for_backward(grid) 67 | return grad_input, grad_grid 68 | 69 | @staticmethod 70 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 71 | _ = grad2_grad_grid # unused 72 | grid, = ctx.saved_tensors 73 | grad2_grad_output = None 74 | grad2_input = None 75 | grad2_grid = None 76 | 77 | if ctx.needs_input_grad[0]: 78 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 79 | 80 | assert not ctx.needs_input_grad[2] 81 | return grad2_grad_output, grad2_input, grad2_grid 82 | 83 | #---------------------------------------------------------------------------- 84 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
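The four steps listed in the `upfirdn2d()` docstring are zero-insertion upsampling, padding or cropping, FIR filtering, and decimation. They are easiest to trace along a single axis. The following pure-Python sketch applies them to a 1D list. It is an illustration under that simplification, not the CUDA kernel, and `upfirdn1d` and `out_len` are hypothetical names. `out_len` restates the `outW` expression from `upfirdn2d.cpp`.

```python
def upfirdn1d(x, f, up=1, down=1, pad0=0, pad1=0, flip_filter=False, gain=1.0):
    """1D illustration of the upfirdn2d() sequence: upsample, pad/crop, filter, downsample."""
    # 1. Upsample by inserting up-1 zeros after each sample.
    y = []
    for v in x:
        y.append(v)
        y.extend([0.0] * (up - 1))
    # 2. Pad with zeros (positive values) or crop (negative values) on each side.
    y = [0.0] * max(pad0, 0) + y + [0.0] * max(pad1, 0)
    y = y[max(-pad0, 0):len(y) - max(-pad1, 0)]
    # 3. Filter. flip_filter=False means true convolution, so the taps are reversed
    #    before sliding; only outputs whose footprint lies fully inside y are kept.
    taps = list(f) if flip_filter else list(reversed(f))
    k = len(taps)
    y = [gain * sum(taps[j] * y[i + j] for j in range(k)) for i in range(len(y) - k + 1)]
    # 4. Downsample by keeping every down-th sample.
    return y[::down]


def out_len(n, k, up=1, down=1, pad0=0, pad1=0):
    """Output length, mirroring outW in upfirdn2d.cpp (agrees with C++ division when >= 1)."""
    return (n * up + pad0 + pad1 - k + down) // down


if __name__ == "__main__":
    x = [1, 2, 3, 4, 5, 6, 7, 8]
    y = upfirdn1d(x, [0.5, 0.5], down=2)
    print(y, len(y) == out_len(len(x), 2, down=2))
```

Take `upfirdn1d([1, 1, 1, 1], [0.25, 0.5, 0.25], up=2, pad0=2, pad1=0, gain=2.0)`. It uses the padding that `upsample2d()` derives for a 3-tap filter with `up=2`, plus the 1D analogue of its `gain*upx*upy` scaling. The call returns 8 samples and keeps the constant level of 1.0 everywhere except the first sample, where the zero padding enters the filter footprint.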
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/cuda/CUDAGuard.h> 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments. 19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 25 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 26 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 27 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 28 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 29 | 30 | // Create output tensor.
31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 32 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 33 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 34 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 35 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 36 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 37 | 38 | // Initialize CUDA kernel parameters. 39 | upfirdn2d_kernel_params p; 40 | p.x = x.data_ptr(); 41 | p.f = f.data_ptr<float>(); 42 | p.y = y.data_ptr(); 43 | p.up = make_int2(upx, upy); 44 | p.down = make_int2(downx, downy); 45 | p.pad0 = make_int2(padx0, pady0); 46 | p.flip = (flip) ? 1 : 0; 47 | p.gain = gain; 48 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 49 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 50 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 51 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 52 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 53 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 54 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 55 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 56 | 57 | // Choose CUDA kernel. 58 | upfirdn2d_kernel_spec spec; 59 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 60 | { 61 | spec = choose_upfirdn2d_kernel<scalar_t>(p); 62 | }); 63 | 64 | // Set looping options. 65 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 66 | p.loopMinor = spec.loopMinor; 67 | p.loopX = spec.loopX; 68 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 69 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 70 | 71 | // Compute grid size.
72 | dim3 blockSize, gridSize; 73 | if (spec.tileOutW < 0) // large 74 | { 75 | blockSize = dim3(4, 32, 1); 76 | gridSize = dim3( 77 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 78 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 79 | p.launchMajor); 80 | } 81 | else // small 82 | { 83 | blockSize = dim3(256, 1, 1); 84 | gridSize = dim3( 85 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 86 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 87 | p.launchMajor); 88 | } 89 | 90 | // Launch CUDA kernel. 91 | void* args[] = {&p}; 92 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 93 | return y; 94 | } 95 | 96 | //------------------------------------------------------------------------ 97 | 98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 99 | { 100 | m.def("upfirdn2d", &upfirdn2d); 101 | } 102 | 103 | //------------------------------------------------------------------------ 104 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <cuda_runtime.h> 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters.
13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom PyTorch ops for efficient resampling of 2D images.""" 10 | 11 | import os 12 | import sys 13 | import warnings 14 | import numpy as np 15 | import torch 16 | import traceback 17 | 18 | from ..
import custom_ops 19 | from .. import misc 20 | from . import conv2d_gradfix 21 | 22 | #---------------------------------------------------------------------------- 23 | 24 | _inited = False 25 | _plugin = None 26 | _init_already_failed = False 27 | 28 | def _init(): 29 | global _inited, _plugin, _init_already_failed 30 | if _init_already_failed: 31 | return False 32 | if not _inited: 33 | sources = ['upfirdn2d.cpp', 'upfirdn2d.cu'] 34 | sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] 35 | try: 36 | _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) 37 | except: 38 | _init_already_failed = True 39 | warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) 40 | return _plugin is not None 41 | 42 | def _parse_scaling(scaling): 43 | if isinstance(scaling, int): 44 | scaling = [scaling, scaling] 45 | assert isinstance(scaling, (list, tuple)) 46 | assert all(isinstance(x, int) for x in scaling) 47 | sx, sy = scaling 48 | assert sx >= 1 and sy >= 1 49 | return sx, sy 50 | 51 | def _parse_padding(padding): 52 | if isinstance(padding, int): 53 | padding = [padding, padding] 54 | assert isinstance(padding, (list, tuple)) 55 | assert all(isinstance(x, int) for x in padding) 56 | if len(padding) == 2: 57 | padx, pady = padding 58 | padding = [padx, padx, pady, pady] 59 | padx0, padx1, pady0, pady1 = padding 60 | return padx0, padx1, pady0, pady1 61 | 62 | def _get_filter_size(f): 63 | if f is None: 64 | return 1, 1 65 | assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] 66 | fw = f.shape[-1] 67 | fh = f.shape[0] 68 | with misc.suppress_tracer_warnings(): 69 | fw = int(fw) 70 | fh = int(fh) 71 | misc.assert_shape(f, [fh, fw][:f.ndim]) 72 | assert fw >= 1 and fh >= 1 73 | return fw, fh 74 | 75 | #---------------------------------------------------------------------------- 76 | 77 | def setup_filter(f, 
device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): 78 | r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. 79 | 80 | Args: 81 | f: Torch tensor, numpy array, or python list of the shape 82 | `[filter_height, filter_width]` (non-separable), 83 | `[filter_taps]` (separable), 84 | `[]` (impulse), or 85 | `None` (identity). 86 | device: Result device (default: cpu). 87 | normalize: Normalize the filter so that it retains the magnitude 88 | for constant input signal (DC)? (default: True). 89 | flip_filter: Flip the filter? (default: False). 90 | gain: Overall scaling factor for signal magnitude (default: 1). 91 | separable: Return a separable filter? (default: select automatically). 92 | 93 | Returns: 94 | Float32 tensor of the shape 95 | `[filter_height, filter_width]` (non-separable) or 96 | `[filter_taps]` (separable). 97 | """ 98 | # Validate. 99 | if f is None: 100 | f = 1 101 | f = torch.as_tensor(f, dtype=torch.float32) 102 | assert f.ndim in [0, 1, 2] 103 | assert f.numel() > 0 104 | if f.ndim == 0: 105 | f = f[np.newaxis] 106 | 107 | # Separable? 108 | if separable is None: 109 | separable = (f.ndim == 1 and f.numel() >= 8) 110 | if f.ndim == 1 and not separable: 111 | f = f.ger(f) 112 | assert f.ndim == (1 if separable else 2) 113 | 114 | # Apply normalize, flip, gain, and device. 115 | if normalize: 116 | f /= f.sum() 117 | if flip_filter: 118 | f = f.flip(list(range(f.ndim))) 119 | f = f * (gain ** (f.ndim / 2)) 120 | f = f.to(device=device) 121 | return f 122 | 123 | #---------------------------------------------------------------------------- 124 | 125 | def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): 126 | r"""Pad, upsample, filter, and downsample a batch of 2D images. 127 | 128 | Performs the following sequence of operations for each channel: 129 | 130 | 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). 131 | 132 | 2. 
Pad the image with the specified number of zeros on each side (`padding`). 133 | Negative padding corresponds to cropping the image. 134 | 135 | 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it 136 | so that the footprint of all output pixels lies within the input image. 137 | 138 | 4. Downsample the image by keeping every Nth pixel (`down`). 139 | 140 | This sequence of operations bears close resemblance to scipy.signal.upfirdn(). 141 | The fused op is considerably more efficient than performing the same calculation 142 | using standard PyTorch ops. It supports gradients of arbitrary order. 143 | 144 | Args: 145 | x: Float32/float64/float16 input tensor of the shape 146 | `[batch_size, num_channels, in_height, in_width]`. 147 | f: Float32 FIR filter of the shape 148 | `[filter_height, filter_width]` (non-separable), 149 | `[filter_taps]` (separable), or 150 | `None` (identity). 151 | up: Integer upsampling factor. Can be a single int or a list/tuple 152 | `[x, y]` (default: 1). 153 | down: Integer downsampling factor. Can be a single int or a list/tuple 154 | `[x, y]` (default: 1). 155 | padding: Padding with respect to the upsampled image. Can be a single number 156 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 157 | (default: 0). 158 | flip_filter: False = convolution, True = correlation (default: False). 159 | gain: Overall scaling factor for signal magnitude (default: 1). 160 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). 161 | 162 | Returns: 163 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 
164 | """ 165 | assert isinstance(x, torch.Tensor) 166 | assert impl in ['ref', 'cuda'] 167 | if impl == 'cuda' and x.device.type == 'cuda' and _init(): 168 | return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) 169 | return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) 170 | 171 | #---------------------------------------------------------------------------- 172 | 173 | @misc.profiled_function 174 | def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): 175 | """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. 176 | """ 177 | # Validate arguments. 178 | assert isinstance(x, torch.Tensor) and x.ndim == 4 179 | if f is None: 180 | f = torch.ones([1, 1], dtype=torch.float32, device=x.device) 181 | assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] 182 | assert f.dtype == torch.float32 and not f.requires_grad 183 | batch_size, num_channels, in_height, in_width = x.shape 184 | upx, upy = _parse_scaling(up) 185 | downx, downy = _parse_scaling(down) 186 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 187 | 188 | # Upsample by inserting zeros. 189 | x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) 190 | x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) 191 | x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) 192 | 193 | # Pad or crop. 194 | x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) 195 | x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] 196 | 197 | # Setup filter. 198 | f = f * (gain ** (f.ndim / 2)) 199 | f = f.to(x.dtype) 200 | if not flip_filter: 201 | f = f.flip(list(range(f.ndim))) 202 | 203 | # Convolve with the filter. 
204 | f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) 205 | if f.ndim == 4: 206 | x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) 207 | else: 208 | x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) 209 | x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) 210 | 211 | # Downsample by throwing away pixels. 212 | x = x[:, :, ::downy, ::downx] 213 | return x 214 | 215 | #---------------------------------------------------------------------------- 216 | 217 | _upfirdn2d_cuda_cache = dict() 218 | 219 | def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): 220 | """Fast CUDA implementation of `upfirdn2d()` using custom ops. 221 | """ 222 | # Parse arguments. 223 | upx, upy = _parse_scaling(up) 224 | downx, downy = _parse_scaling(down) 225 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 226 | 227 | # Lookup from cache. 228 | key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) 229 | if key in _upfirdn2d_cuda_cache: 230 | return _upfirdn2d_cuda_cache[key] 231 | 232 | # Forward op. 
233 | class Upfirdn2dCuda(torch.autograd.Function): 234 | @staticmethod 235 | def forward(ctx, x, f): # pylint: disable=arguments-differ 236 | assert isinstance(x, torch.Tensor) and x.ndim == 4 237 | if f is None: 238 | f = torch.ones([1, 1], dtype=torch.float32, device=x.device) 239 | assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] 240 | y = x 241 | if f.ndim == 2: 242 | y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) 243 | else: 244 | y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain)) 245 | y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain)) 246 | ctx.save_for_backward(f) 247 | ctx.x_shape = x.shape 248 | return y 249 | 250 | @staticmethod 251 | def backward(ctx, dy): # pylint: disable=arguments-differ 252 | f, = ctx.saved_tensors 253 | _, _, ih, iw = ctx.x_shape 254 | _, _, oh, ow = dy.shape 255 | fw, fh = _get_filter_size(f) 256 | p = [ 257 | fw - padx0 - 1, 258 | iw * upx - ow * downx + padx0 - upx + 1, 259 | fh - pady0 - 1, 260 | ih * upy - oh * downy + pady0 - upy + 1, 261 | ] 262 | dx = None 263 | df = None 264 | 265 | if ctx.needs_input_grad[0]: 266 | dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) 267 | 268 | assert not ctx.needs_input_grad[1] 269 | return dx, df 270 | 271 | # Add to cache. 272 | _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda 273 | return Upfirdn2dCuda 274 | 275 | #---------------------------------------------------------------------------- 276 | 277 | def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): 278 | r"""Filter a batch of 2D images using the given 2D FIR filter. 279 | 280 | By default, the result is padded so that its shape matches the input. 281 | User-specified padding is applied on top of that, with negative values 282 | indicating cropping. Pixels outside the image are assumed to be zero. 
283 | 284 | Args: 285 | x: Float32/float64/float16 input tensor of the shape 286 | `[batch_size, num_channels, in_height, in_width]`. 287 | f: Float32 FIR filter of the shape 288 | `[filter_height, filter_width]` (non-separable), 289 | `[filter_taps]` (separable), or 290 | `None` (identity). 291 | padding: Padding with respect to the output. Can be a single number or a 292 | list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 293 | (default: 0). 294 | flip_filter: False = convolution, True = correlation (default: False). 295 | gain: Overall scaling factor for signal magnitude (default: 1). 296 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). 297 | 298 | Returns: 299 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 300 | """ 301 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 302 | fw, fh = _get_filter_size(f) 303 | p = [ 304 | padx0 + fw // 2, 305 | padx1 + (fw - 1) // 2, 306 | pady0 + fh // 2, 307 | pady1 + (fh - 1) // 2, 308 | ] 309 | return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) 310 | 311 | #---------------------------------------------------------------------------- 312 | 313 | def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): 314 | r"""Upsample a batch of 2D images using the given 2D FIR filter. 315 | 316 | By default, the result is padded so that its shape is a multiple of the input. 317 | User-specified padding is applied on top of that, with negative values 318 | indicating cropping. Pixels outside the image are assumed to be zero. 319 | 320 | Args: 321 | x: Float32/float64/float16 input tensor of the shape 322 | `[batch_size, num_channels, in_height, in_width]`. 323 | f: Float32 FIR filter of the shape 324 | `[filter_height, filter_width]` (non-separable), 325 | `[filter_taps]` (separable), or 326 | `None` (identity). 327 | up: Integer upsampling factor. Can be a single int or a list/tuple 328 | `[x, y]` (default: 2).
329 | padding: Padding with respect to the output. Can be a single number or a 330 | list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 331 | (default: 0). 332 | flip_filter: False = convolution, True = correlation (default: False). 333 | gain: Overall scaling factor for signal magnitude (default: 1). 334 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). 335 | 336 | Returns: 337 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 338 | """ 339 | upx, upy = _parse_scaling(up) 340 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 341 | fw, fh = _get_filter_size(f) 342 | p = [ 343 | padx0 + (fw + upx - 1) // 2, 344 | padx1 + (fw - upx) // 2, 345 | pady0 + (fh + upy - 1) // 2, 346 | pady1 + (fh - upy) // 2, 347 | ] 348 | return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl) 349 | 350 | #---------------------------------------------------------------------------- 351 | 352 | def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): 353 | r"""Downsample a batch of 2D images using the given 2D FIR filter. 354 | 355 | By default, the result is padded so that its shape is a fraction of the input. 356 | User-specified padding is applied on top of that, with negative values 357 | indicating cropping. Pixels outside the image are assumed to be zero. 358 | 359 | Args: 360 | x: Float32/float64/float16 input tensor of the shape 361 | `[batch_size, num_channels, in_height, in_width]`. 362 | f: Float32 FIR filter of the shape 363 | `[filter_height, filter_width]` (non-separable), 364 | `[filter_taps]` (separable), or 365 | `None` (identity). 366 | down: Integer downsampling factor. Can be a single int or a list/tuple 367 | `[x, y]` (default: 2). 368 | padding: Padding with respect to the input. Can be a single number or a 369 | list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 370 | (default: 0).
371 | flip_filter: False = convolution, True = correlation (default: False). 372 | gain: Overall scaling factor for signal magnitude (default: 1). 373 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). 374 | 375 | Returns: 376 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 377 | """ 378 | downx, downy = _parse_scaling(down) 379 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 380 | fw, fh = _get_filter_size(f) 381 | p = [ 382 | padx0 + (fw - downx + 1) // 2, 383 | padx1 + (fw - downx) // 2, 384 | pady0 + (fh - downy + 1) // 2, 385 | pady1 + (fh - downy) // 2, 386 | ] 387 | return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) 388 | 389 | #---------------------------------------------------------------------------- 390 | -------------------------------------------------------------------------------- /torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for pickling Python code alongside other data. 10 | 11 | The pickled code is automatically imported into a separate Python module 12 | during unpickling. 
This way, any previously exported pickles will remain 13 | usable even if the original code is no longer available, or if the current 14 | version of the code is not consistent with what was originally pickled.""" 15 | 16 | import sys 17 | import pickle 18 | import io 19 | import inspect 20 | import copy 21 | import uuid 22 | import types 23 | import dnnlib 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | _version = 6 # internal version number 28 | _decorators = set() # {decorator_class, ...} 29 | _import_hooks = [] # [hook_function, ...] 30 | _module_to_src_dict = dict() # {module: src, ...} 31 | _src_to_module_dict = dict() # {src: module, ...} 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def persistent_class(orig_class): 36 | r"""Class decorator that extends a given class to save its source code 37 | when pickled. 38 | 39 | Example: 40 | 41 | from torch_utils import persistence 42 | 43 | @persistence.persistent_class 44 | class MyNetwork(torch.nn.Module): 45 | def __init__(self, num_inputs, num_outputs): 46 | super().__init__() 47 | self.fc = MyLayer(num_inputs, num_outputs) 48 | ... 49 | 50 | @persistence.persistent_class 51 | class MyLayer(torch.nn.Module): 52 | ... 53 | 54 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 55 | source code alongside other internal state (e.g., parameters, buffers, 56 | and submodules). This way, any previously exported pickle will remain 57 | usable even if the class definitions have been modified or are no 58 | longer available. 59 | 60 | The decorator saves the source code of the entire Python module 61 | containing the decorated class. It does *not* save the source code of 62 | any imported modules. Thus, the imported modules must be available 63 | during unpickling, also including `torch_utils.persistence` itself. 
64 | 65 | It is ok to call functions defined in the same module from the 66 | decorated class. However, if the decorated class depends on other 67 | classes defined in the same module, they must be decorated as well. 68 | This is illustrated in the above example in the case of `MyLayer`. 69 | 70 | It is also possible to employ the decorator just-in-time before 71 | calling the constructor. For example: 72 | 73 | cls = MyLayer 74 | if want_to_make_it_persistent: 75 | cls = persistence.persistent_class(cls) 76 | layer = cls(num_inputs, num_outputs) 77 | 78 | As an additional feature, the decorator also keeps track of the 79 | arguments that were used to construct each instance of the decorated 80 | class. The arguments can be queried via `obj.init_args` and 81 | `obj.init_kwargs`, and they are automatically pickled alongside other 82 | object state. A typical use case is to first unpickle a previous 83 | instance of a persistent class, and then upgrade it to use the latest 84 | version of the source code: 85 | 86 | with open('old_pickle.pkl', 'rb') as f: 87 | old_net = pickle.load(f) 88 | new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs) 89 | misc.copy_params_and_buffers(old_net, new_net, require_all=True) 90 | """ 91 | assert isinstance(orig_class, type) 92 | if is_persistent(orig_class): 93 | return orig_class 94 | 95 | assert orig_class.__module__ in sys.modules 96 | orig_module = sys.modules[orig_class.__module__] 97 | orig_module_src = _module_to_src(orig_module) 98 | 99 | class Decorator(orig_class): 100 | _orig_module_src = orig_module_src 101 | _orig_class_name = orig_class.__name__ 102 | 103 | def __init__(self, *args, **kwargs): 104 | super().__init__(*args, **kwargs) 105 | self._init_args = copy.deepcopy(args) 106 | self._init_kwargs = copy.deepcopy(kwargs) 107 | assert orig_class.__name__ in orig_module.__dict__ 108 | _check_pickleable(self.__reduce__()) 109 | 110 | @property 111 | def init_args(self): 112 | return
copy.deepcopy(self._init_args) 113 | 114 | @property 115 | def init_kwargs(self): 116 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) 117 | 118 | def __reduce__(self): 119 | fields = list(super().__reduce__()) 120 | fields += [None] * max(3 - len(fields), 0) 121 | if fields[0] is not _reconstruct_persistent_obj: 122 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) 123 | fields[0] = _reconstruct_persistent_obj # reconstruct func 124 | fields[1] = (meta,) # reconstruct args 125 | fields[2] = None # state dict 126 | return tuple(fields) 127 | 128 | Decorator.__name__ = orig_class.__name__ 129 | _decorators.add(Decorator) 130 | return Decorator 131 | 132 | #---------------------------------------------------------------------------- 133 | 134 | def is_persistent(obj): 135 | r"""Test whether the given object or class is persistent, i.e., 136 | whether it will save its source code when pickled. 137 | """ 138 | try: 139 | if obj in _decorators: 140 | return True 141 | except TypeError: 142 | pass 143 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck 144 | 145 | #---------------------------------------------------------------------------- 146 | 147 | def import_hook(hook): 148 | r"""Register an import hook that is called whenever a persistent object 149 | is being unpickled. A typical use case is to patch the pickled source 150 | code to avoid errors and inconsistencies when the API of some imported 151 | module has changed. 152 | 153 | The hook should have the following signature: 154 | 155 | hook(meta) -> modified meta 156 | 157 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 158 | 159 | type: Type of the persistent object, e.g. `'class'`. 160 | version: Internal version number of `torch_utils.persistence`. 161 | module_src: Original source code of the Python module. 162 | class_name: Class name in the original Python module.
163 | state: Internal state of the object. 164 | 165 | Example: 166 | 167 | @persistence.import_hook 168 | def wreck_my_network(meta): 169 | if meta.class_name == 'MyNetwork': 170 | print('MyNetwork is being imported. I will wreck it!') 171 | meta.module_src = meta.module_src.replace("True", "False") 172 | return meta 173 | """ 174 | assert callable(hook) 175 | _import_hooks.append(hook) 176 | 177 | #---------------------------------------------------------------------------- 178 | 179 | def _reconstruct_persistent_obj(meta): 180 | r"""Hook that is called internally by the `pickle` module to unpickle 181 | a persistent object. 182 | """ 183 | meta = dnnlib.EasyDict(meta) 184 | meta.state = dnnlib.EasyDict(meta.state) 185 | for hook in _import_hooks: 186 | meta = hook(meta) 187 | assert meta is not None 188 | 189 | assert meta.version == _version 190 | module = _src_to_module(meta.module_src) 191 | 192 | assert meta.type == 'class' 193 | orig_class = module.__dict__[meta.class_name] 194 | decorator_class = persistent_class(orig_class) 195 | obj = decorator_class.__new__(decorator_class) 196 | 197 | setstate = getattr(obj, '__setstate__', None) 198 | if callable(setstate): 199 | setstate(meta.state) # pylint: disable=not-callable 200 | else: 201 | obj.__dict__.update(meta.state) 202 | return obj 203 | 204 | #---------------------------------------------------------------------------- 205 | 206 | def _module_to_src(module): 207 | r"""Query the source code of a given Python module. 208 | """ 209 | src = _module_to_src_dict.get(module, None) 210 | if src is None: 211 | src = inspect.getsource(module) 212 | _module_to_src_dict[module] = src 213 | _src_to_module_dict[src] = module 214 | return src 215 | 216 | def _src_to_module(src): 217 | r"""Get or create a Python module for the given source code. 
218 | """ 219 | module = _src_to_module_dict.get(src, None) 220 | if module is None: 221 | module_name = "_imported_module_" + uuid.uuid4().hex 222 | module = types.ModuleType(module_name) 223 | sys.modules[module_name] = module 224 | _module_to_src_dict[module] = src 225 | _src_to_module_dict[src] = module 226 | exec(src, module.__dict__) # pylint: disable=exec-used 227 | return module 228 | 229 | #---------------------------------------------------------------------------- 230 | 231 | def _check_pickleable(obj): 232 | r"""Check that the given object is pickleable, raising an exception if 233 | it is not. This function is expected to be considerably more efficient 234 | than actually pickling the object. 235 | """ 236 | def recurse(obj): 237 | if isinstance(obj, (list, tuple, set)): 238 | return [recurse(x) for x in obj] 239 | if isinstance(obj, dict): 240 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 241 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 242 | return None # Python primitive types are pickleable. 243 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']: 244 | return None # NumPy arrays and PyTorch tensors are pickleable. 245 | if is_persistent(obj): 246 | return None # Persistent objects are pickleable, by virtue of the constructor check. 247 | return obj 248 | with io.BytesIO() as f: 249 | pickle.dump(recurse(obj), f) 250 | 251 | #---------------------------------------------------------------------------- 252 | -------------------------------------------------------------------------------- /torch_utils/training_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for reporting and collecting training statistics across 10 | multiple processes and devices. The interface is designed to minimize 11 | synchronization overhead as well as the amount of boilerplate in user 12 | code.""" 13 | 14 | import re 15 | import numpy as np 16 | import torch 17 | import dnnlib 18 | 19 | from . import misc 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 24 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 25 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 26 | _rank = 0 # Rank of the current process. 27 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 28 | _sync_called = False # Has _sync() been called yet? 29 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 30 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def init_multiprocessing(rank, sync_device): 35 | r"""Initializes `torch_utils.training_stats` for collecting statistics 36 | across multiple processes. 37 | 38 | This function must be called after 39 | `torch.distributed.init_process_group()` and before `Collector.update()`. 40 | The call is not necessary if multi-process collection is not needed. 41 | 42 | Args: 43 | rank: Rank of the current process. 44 | sync_device: PyTorch device to use for inter-process 45 | communication, or None to disable multi-process 46 | collection. Typically `torch.device('cuda', rank)`. 
47 | """ 48 | global _rank, _sync_device 49 | assert not _sync_called 50 | _rank = rank 51 | _sync_device = sync_device 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | @misc.profiled_function 56 | def report(name, value): 57 | r"""Broadcasts the given set of scalars to all interested instances of 58 | `Collector`, across device and process boundaries. 59 | 60 | This function is expected to be extremely cheap and can be safely 61 | called from anywhere in the training loop, loss function, or inside a 62 | `torch.nn.Module`. 63 | 64 | Warning: The current implementation expects the set of unique names to 65 | be consistent across processes. Please make sure that `report()` is 66 | called at least once for each unique name by each process, and in the 67 | same order. If a given process has no scalars to broadcast, it can do 68 | `report(name, [])` (empty list). 69 | 70 | Args: 71 | name: Arbitrary string specifying the name of the statistic. 72 | Averages are accumulated separately for each unique name. 73 | value: Arbitrary set of scalars. Can be a list, tuple, 74 | NumPy array, PyTorch tensor, or Python scalar. 75 | 76 | Returns: 77 | The same `value` that was passed in. 
78 | """ 79 | if name not in _counters: 80 | _counters[name] = dict() 81 | 82 | elems = torch.as_tensor(value) 83 | if elems.numel() == 0: 84 | return value 85 | 86 | elems = elems.detach().flatten().to(_reduce_dtype) 87 | moments = torch.stack([ 88 | torch.ones_like(elems).sum(), 89 | elems.sum(), 90 | elems.square().sum(), 91 | ]) 92 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 93 | moments = moments.to(_counter_dtype) 94 | 95 | device = moments.device 96 | if device not in _counters[name]: 97 | _counters[name][device] = torch.zeros_like(moments) 98 | _counters[name][device].add_(moments) 99 | return value 100 | 101 | #---------------------------------------------------------------------------- 102 | 103 | def report0(name, value): 104 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 105 | but ignores any scalars provided by the other processes. 106 | See `report()` for further details. 107 | """ 108 | report(name, value if _rank == 0 else []) 109 | return value 110 | 111 | #---------------------------------------------------------------------------- 112 | 113 | class Collector: 114 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 115 | computes their long-term averages (mean and standard deviation) over 116 | user-defined periods of time. 117 | 118 | The averages are first collected into internal counters that are not 119 | directly visible to the user. They are then copied to the user-visible 120 | state as a result of calling `update()` and can then be queried using 121 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 122 | internal counters for the next round, so that the user-visible state 123 | effectively reflects averages collected between the last two calls to 124 | `update()`. 125 | 126 | Args: 127 | regex: Regular expression defining which statistics to 128 | collect. The default is to collect everything. 
129 | keep_previous: Whether to retain the previous averages if no 130 | scalars were collected on a given round 131 | (default: True). 132 | """ 133 | def __init__(self, regex='.*', keep_previous=True): 134 | self._regex = re.compile(regex) 135 | self._keep_previous = keep_previous 136 | self._cumulative = dict() 137 | self._moments = dict() 138 | self.update() 139 | self._moments.clear() 140 | 141 | def names(self): 142 | r"""Returns the names of all statistics broadcasted so far that 143 | match the regular expression specified at construction time. 144 | """ 145 | return [name for name in _counters if self._regex.fullmatch(name)] 146 | 147 | def update(self): 148 | r"""Copies current values of the internal counters to the 149 | user-visible state and resets them for the next round. 150 | 151 | If `keep_previous=True` was specified at construction time, the 152 | operation is skipped for statistics that have received no scalars 153 | since the last update, retaining their previous averages. 154 | 155 | This method performs a number of GPU-to-CPU transfers and one 156 | `torch.distributed.all_reduce()`. It is intended to be called 157 | periodically in the main training loop, typically once every 158 | N training steps. 159 | """ 160 | if not self._keep_previous: 161 | self._moments.clear() 162 | for name, cumulative in _sync(self.names()): 163 | if name not in self._cumulative: 164 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 165 | delta = cumulative - self._cumulative[name] 166 | self._cumulative[name].copy_(cumulative) 167 | if float(delta[0]) != 0: 168 | self._moments[name] = delta 169 | 170 | def _get_delta(self, name): 171 | r"""Returns the raw moments that were accumulated for the given 172 | statistic between the last two calls to `update()`, or zero if 173 | no scalars were collected. 
174 | """ 175 | assert self._regex.fullmatch(name) 176 | if name not in self._moments: 177 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 178 | return self._moments[name] 179 | 180 | def num(self, name): 181 | r"""Returns the number of scalars that were accumulated for the given 182 | statistic between the last two calls to `update()`, or zero if 183 | no scalars were collected. 184 | """ 185 | delta = self._get_delta(name) 186 | return int(delta[0]) 187 | 188 | def mean(self, name): 189 | r"""Returns the mean of the scalars that were accumulated for the 190 | given statistic between the last two calls to `update()`, or NaN if 191 | no scalars were collected. 192 | """ 193 | delta = self._get_delta(name) 194 | if int(delta[0]) == 0: 195 | return float('nan') 196 | return float(delta[1] / delta[0]) 197 | 198 | def std(self, name): 199 | r"""Returns the standard deviation of the scalars that were 200 | accumulated for the given statistic between the last two calls to 201 | `update()`, or NaN if no scalars were collected. 202 | """ 203 | delta = self._get_delta(name) 204 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): 205 | return float('nan') 206 | if int(delta[0]) == 1: 207 | return float(0) 208 | mean = float(delta[1] / delta[0]) 209 | raw_var = float(delta[2] / delta[0]) 210 | return np.sqrt(max(raw_var - np.square(mean), 0)) 211 | 212 | def as_dict(self): 213 | r"""Returns the averages accumulated between the last two calls to 214 | `update()` as a `dnnlib.EasyDict`. The contents are as follows: 215 | 216 | dnnlib.EasyDict( 217 | NAME = dnnlib.EasyDict(num=INT, mean=FLOAT, std=FLOAT), 218 | ... 219 | ) 220 | """ 221 | stats = dnnlib.EasyDict() 222 | for name in self.names(): 223 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) 224 | return stats 225 | 226 | def __getitem__(self, name): 227 | r"""Convenience getter.
228 | `collector[name]` is a synonym for `collector.mean(name)`. 229 | """ 230 | return self.mean(name) 231 | 232 | #---------------------------------------------------------------------------- 233 | 234 | def _sync(names): 235 | r"""Synchronize the global cumulative counters across devices and 236 | processes. Called internally by `Collector.update()`. 237 | """ 238 | if len(names) == 0: 239 | return [] 240 | global _sync_called 241 | _sync_called = True 242 | 243 | # Collect deltas within current rank. 244 | deltas = [] 245 | device = _sync_device if _sync_device is not None else torch.device('cpu') 246 | for name in names: 247 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) 248 | for counter in _counters[name].values(): 249 | delta.add_(counter.to(device)) 250 | counter.copy_(torch.zeros_like(counter)) 251 | deltas.append(delta) 252 | deltas = torch.stack(deltas) 253 | 254 | # Sum deltas across ranks. 255 | if _sync_device is not None: 256 | torch.distributed.all_reduce(deltas) 257 | 258 | # Update cumulative values. 259 | deltas = deltas.cpu() 260 | for idx, name in enumerate(names): 261 | if name not in _cumulative: 262 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 263 | _cumulative[name].add_(deltas[idx]) 264 | 265 | # Return name-value pairs. 266 | return [(name, _cumulative[name]) for name in names] 267 | 268 | #---------------------------------------------------------------------------- 269 | -------------------------------------------------------------------------------- /w_directions.py: -------------------------------------------------------------------------------- 1 | # 2 | # Author: David Futschik 3 | # Provided as part of the Chunkmogrify project, 2021. 4 | # 5 | 6 | import torch 7 | 8 | # The second number is the W dimension up to which the change applies. 8 is a good default. 
9 | # (Lower number means only low-level features will be affected.) 10 | # The expected format is a numpy array of the shape (18, 512). 11 | 12 | 13 | # Synthetic because they are just found by mucking around in Z space. 14 | synthetic_glasses = torch.zeros((1, 512, )) 15 | synthetic_glasses[:, 30] = -1 16 | synthetic_glasses[:, 36] = -1 17 | 18 | # Empirical glasses: I manually labelled about 200 images and took the difference of the class means 19 | # as the direction. 20 | 21 | # Beard was found through gradient descent modification and subsequent PCA of differences. 22 | 23 | def known_directions(): 24 | return { 25 | 'smile': ('directions/smile_styleganv2.npy', 0, 5), 26 | 'age': ('directions/age.npy', 0, 8), 27 | 'gender': ('directions/gender.npy', 0, 5), 28 | 'eyes_open': ('directions/eyes_open.npy', 0, 12), 29 | 'nose_ratio': ('directions/nose_ratio.npy', 0, 8), 30 | 'nose_tip': ('directions/nose_tip.npy', 0, 12), 31 | 'lip_ratio': ('directions/lip_ratio.npy', 0, 12), 32 | 'eye_distance': ('directions/eye_distance.npy', 0, 8), 33 | 'eye_to_eyebrow_distance': ('directions/eye_eyebrow_distance.npy', 0, 5), 34 | 'eye_ratio': ('directions/eye_ratio.npy', 0, 8), 35 | 'mouth_open': ('directions/mouth_open.npy', 0, 8), 36 | 'pitch': ('directions/pitch.npy', 0, 5), 37 | # 'roll': ('directions/roll.npy', 5), 38 | 'synthetic_glasses': (synthetic_glasses, 2, 4), 39 | 'empirical_glasses': ('directions/empirical_glasses.npy', 2, 4), 40 | 'beard': ('directions/beard.npy', 6, 9), 41 | 'yaw': ('directions/yaw.npy', 0, 5), 42 | 'light': ('directions/light.npy', 5, 11), 43 | 'eye_color': ('directions/light.npy', 14, 15), 44 | # 'eig_all_0': (torch.load('components/all.pt')[:, 0].view(1, 512), 0, 18), 45 | # 'eig_all_1': (torch.load('components/all.pt')[:, 1].view(1, 512), 0, 18), 46 | # 'eig_all_2': (torch.load('components/all.pt')[:, 2].view(1, 512), 0, 18), 47 | # 'eig_all_3': (torch.load('components/all.pt')[:, 3].view(1, 512), 0, 18), 48 | # 'eig_all_4':
(torch.load('components/all.pt')[:, 4].view(1, 512), 0, 18), 49 | # 'eig_all_5': (torch.load('components/all.pt')[:, 5].view(1, 512), 0, 18), 50 | # 'eig_all_6': (torch.load('components/all.pt')[:, 6].view(1, 512), 0, 18), 51 | # 'eig_all_7': (torch.load('components/all.pt')[:, 7].view(1, 512), 0, 18), 52 | # 'eig_all_8': (torch.load('components/all.pt')[:, 8].view(1, 512), 0, 18), 53 | # 'eig_all_9': (torch.load('components/all.pt')[:, 9].view(1, 512), 0, 18), 54 | # 'eig_all_10': (torch.load('components/all.pt')[:, 10].view(1, 512), 0, 18), 55 | 56 | # 'eig_64_0': (torch.load('components/b64_conv0.pt')[:, 0].view(1, 512), 0, 18), 57 | # 'eig_64_1': (torch.load('components/b64_conv0.pt')[:, 1].view(1, 512), 0, 18), 58 | # 'eig_64_2': (torch.load('components/b64_conv0.pt')[:, 2].view(1, 512), 0, 18), 59 | # 'eig_64_3': (torch.load('components/b64_conv0.pt')[:, 3].view(1, 512), 0, 18), 60 | # 'eig_64_4': (torch.load('components/b64_conv0.pt')[:, 4].view(1, 512), 0, 18), 61 | # 'eig_64_5': (torch.load('components/b64_conv0.pt')[:, 5].view(1, 512), 0, 18), 62 | # 'eig_64_6': (torch.load('components/b64_conv0.pt')[:, 6].view(1, 512), 0, 18), 63 | # 'eig_64_7': (torch.load('components/b64_conv0.pt')[:, 7].view(1, 512), 0, 18), 64 | # 'eig_64_8': (torch.load('components/b64_conv0.pt')[:, 8].view(1, 512), 0, 18), 65 | # 'eig_64_9': (torch.load('components/b64_conv0.pt')[:, 9].view(1, 512), 0, 18), 66 | # 'eig_64_10': (torch.load('components/b64_conv0.pt')[:, 10].view(1, 512), 0, 18), 67 | 68 | # 'eig_16_0': (torch.load('components/b16_conv0.pt')[:, 0].view(1, 512), 0, 18), # Hair ? 69 | # 'eig_16_1': (torch.load('components/b16_conv0.pt')[:, 1].view(1, 512), 0, 18), # Gender ? 
70 | # 'eig_16_2': (torch.load('components/b16_conv0.pt')[:, 2].view(1, 512), 0, 18), 71 | # 'eig_16_3': (torch.load('components/b16_conv0.pt')[:, 3].view(1, 512), 0, 18), # Roll 72 | # 'eig_16_4': (torch.load('components/b16_conv0.pt')[:, 4].view(1, 512), 0, 11), # Age 73 | # 'eig_16_5': (torch.load('components/b16_conv0.pt')[:, 5].view(1, 512), 0, 18), 74 | # 'eig_16_6': (torch.load('components/b16_conv0.pt')[:, 6].view(1, 512), 0, 18), 75 | # 'eig_16_7': (torch.load('components/b16_conv0.pt')[:, 7].view(1, 512), 0, 18), 76 | # 'eig_16_8': (torch.load('components/b16_conv0.pt')[:, 8].view(1, 512), 0, 18), 77 | # 'eig_16_9': (torch.load('components/b16_conv0.pt')[:, 9].view(1, 512), 0, 18), # Weird age 78 | # 'eig_16_10': (torch.load('components/b16_conv0.pt')[:, 10].view(1, 512), 0, 18), # Pitch 79 | } -------------------------------------------------------------------------------- /widgets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/futscdav/Chunkmogrify/efdfdf3df0bb15e5e64de8575ab89baaaa9f5340/widgets/__init__.py -------------------------------------------------------------------------------- /widgets/mask_painter.py: -------------------------------------------------------------------------------- 1 | # 2 | # Author: David Futschik 3 | # Provided as part of the Chunkmogrify project, 2021. 4 | # 5 | 6 | import numpy as np 7 | 8 | from qtutil import * 9 | 10 | from PyQt5.QtWidgets import * 11 | from PyQt5.QtGui import * 12 | from PyQt5.QtCore import * 13 | 14 | from time import perf_counter as time 15 | from _C_canvas import canvas_to_masks 16 | 17 | 18 | # This is inefficient, but it was the quickest solution to the problem. 19 | # Ideally we would just paint along a path. 20 | # Alternative: use QCursor and scaled Pixmap, but it has its own set of problems. 
21 | class EllipseBrush: 22 | def __init__(self, height, width, visual_color=QColor.fromRgb(25, 127, 255)): 23 | self.image = QImage(width, height, QImage.Format_RGBA8888) 24 | self.image.fill(Qt.transparent) 25 | self.visual_color = visual_color 26 | self.pen = QPen(self.visual_color) 27 | self.pen.setWidth(3) 28 | self.pen.setCapStyle(Qt.RoundCap) 29 | 30 | self.clear_pen = QPen(Qt.transparent) 31 | self.clear_pen.setWidth(6) 32 | self.clear_pen.setCapStyle(Qt.RoundCap) 33 | self.last_visual_arguments = None 34 | 35 | def clear_last_visual(self, into_painter=None): 36 | # very blunt method, could just draw a big transparent ellipse 37 | # self.image.fill(Qt.transparent) 38 | if self.last_visual_arguments is not None: 39 | if into_painter is None: 40 | painter = QPainter(self.image) 41 | else: 42 | painter = into_painter 43 | painter.setCompositionMode(QPainter.CompositionMode_Clear) 44 | painter.setRenderHints( QPainter.Antialiasing ) # AA for the brush layer 45 | painter.setPen(self.clear_pen) 46 | 47 | args = self.last_visual_arguments 48 | args[2] += 3; args[3] += 3 49 | painter.setBrush(Qt.transparent) 50 | painter.drawEllipse(*args) 51 | if into_painter is None: 52 | painter.end() 53 | 54 | def update_visual(self, mx, my, size): 55 | painter = QPainter(self.image) 56 | self.clear_last_visual(painter) 57 | painter.setCompositionMode(QPainter.CompositionMode_Source) 58 | 59 | painter.setRenderHints( QPainter.Antialiasing ) # AA for the brush layer 60 | painter.setPen(self.pen) 61 | half_radius = float(size) / 2 62 | self.last_visual_arguments = [mx * self.image.width() - half_radius, my * self.image.height() - half_radius, size, size] 63 | painter.drawEllipse(*self.last_visual_arguments) 64 | painter.end() 65 | 66 | def draw(self, qpainter, x, y, color, size): 67 | half_radius = float(size) / 2 68 | qpainter.setBrush(color) 69 | qpainter.setPen(color) 70 | qpainter.drawEllipse(x - half_radius, y - half_radius, size, size) 71 | 72 | # Careful about choice of 
colors, because there is a clear bug in Qt where the alpha is NOT entirely ignored even with the correct 73 | # CompositionMode. 74 | COLORS = [ 75 | (0, 0, 0, 0), 76 | (128, 128, 128, 90), # 1 77 | (128, 255, 128, 90), # 2 78 | (255, 128, 128, 90), # 3 79 | (128, 128, 255, 90), # 4 80 | (150, 51 , 201, 90), # 5 81 | (201, 150, 51 , 90), # 6 82 | (60 , 119, 150, 90), # 7 83 | (77 , 255, 238, 90), # 8 84 | (255, 244, 20 , 90), # 9 85 | (6 , 68 , 17 , 90), # 10 86 | (196, 11 , 0 , 90), # 11 87 | (99 , 48 , 11 , 90), # 12 88 | ] 89 | 90 | # Don't store the buf! 91 | mask_buf = None 92 | def canvas_to_numpy_mask(qt_image, colors): 93 | global mask_buf 94 | asnp = qim2np(qt_image, swapaxes=False) 95 | 96 | # This is the numpy way: 97 | # color_cmps = np_colors[None, None, :, :] 98 | # # (1024, 1024, 4, 1) == (1, 1, 4, 8) 99 | # cmps = np.equal(asnp[:, :, :, None], color_cmps, out=cmp_buf) 100 | # logical_masks = cmps.all(axis=2)[:, :, 1:] 101 | # masks = logical_masks 102 | 103 | # C++ optimized way: 104 | 105 | if mask_buf is None: 106 | masks = canvas_to_masks(asnp, colors) 107 | mask_buf = masks 108 | else: 109 | masks = canvas_to_masks(asnp, colors, output_buffer=mask_buf) 110 | return masks[:, :, 1:] 111 | 112 | # Don't store the buffers anywhere! 113 | canvas_buf = None 114 | def numpy_mask_to_numpy_canvas(numpy_mask): 115 | global canvas_buf 116 | if canvas_buf is None: canvas_buf = np.zeros((numpy_mask.shape[0], numpy_mask.shape[1], 4), dtype=np.uint8) 117 | else: pass # canvas_buf.fill(0.) 
    # Use COLORS once there are multiple masks

    # The mask needs to have the index of the resulting color & be flattened
    numpy_mask = numpy_mask[:, :, :] * (np.arange(1, numpy_mask.shape[2] + 1))[None, None, :]
    numpy_mask = numpy_mask.sum(axis=2)

    np.choose(
        numpy_mask.astype(np.uint8)[:, :, None],
        COLORS,
        mode='clip',
        out=canvas_buf
    )
    return canvas_buf

class PainterWidget:
    def __init__(self, height, width, max_segments, reporting_fn):
        self.show_brush = False
        self.brush_size = 15
        self.brush = EllipseBrush(height, width)
        self.canvas = QImage(width, height, QImage.Format_RGBA8888)
        self.canvas.fill(Qt.transparent)
        self.update_callbacks = []
        self.color_index = 1
        self.last_known_paint_pos = None
        self.reporting_fn = reporting_fn
        self.max_segments = max_segments
        self.np_colors = np.array(COLORS).astype(np.uint8)[:self.max_segments+1]
        if self.max_segments > len(COLORS) - 2:
            print(f'Requested {self.max_segments} masks, but only {len(COLORS) - 2} colors are defined. Setting max segments to {len(COLORS) - 2}')
            self.max_segments = len(COLORS) - 2

        self._hide = False
        self._enabled = True
        self._active_buttons = {}

    def get_volatile_masks(self):
        asnp = qim2np(self.canvas, swapaxes=False)
        masks = canvas_to_masks(asnp, self.np_colors)
        return asnp, masks

    def draw(self, qpainter, qevent):
        if self._hide: return
        # this actually only draws the overlay with brush
        qpainter.setOpacity(1)
        # This overlays current context with the actual painting content
        qpainter.drawImage(qevent.rect(), self.canvas)
        # This overlays the current context with the brush visual
        qpainter.drawImage(qevent.rect(), self.brush.image)

    def change_brush_size(self, amount):
        self.brush_size += amount
        self.brush_size = 1 if self.brush_size < 1 else self.brush_size
        self.update_visuals()

    def change_active_color(self, dst_idx):
        self.color_index = dst_idx
        # Active color is between 1 and N
        if self.color_index < 1: self.color_index = 1
        if self.color_index >= min(self.max_segments + 1, len(COLORS)): self.color_index = min(self.max_segments + 1, len(COLORS)) - 1
        self.reporting_fn(f"Current color index: {self.color_index} - {{{COLORS[self.color_index]}}}")

    def active_color(self):
        return QColor.fromRgb(*COLORS[self.color_index])

    def rewrite_with_numpy_mask(self, target, writeback=False, actor=None):
        # The change was initiated by the GUI; this call is only a notification of it
        if actor == 'gui': return

        old_canvas = qim2np(self.canvas)
        old_canvas[:] = numpy_mask_to_numpy_canvas(target)[:]
        self.canvas = np2qim(old_canvas, do_copy=False)
        if writeback:
            for c in self.update_callbacks:
                c(target, actor='gui')

    # dx, dy in [0,1]
    # Called on mouse press and on each mouse move while a button is held down
    def paint_at(self, button, dy, dx):
        self.last_known_paint_pos = (dx, dy)
        self.update_visuals()

        csx, csy = dx * self.canvas.width(), dy * self.canvas.height()
        if button == MouseButton.LEFT:
            self._paint_at_pos(csx, csy, self.active_color())
        if button == MouseButton.RIGHT:
            self._paint_at_pos(csx, csy, QColor.fromRgb(*COLORS[0]))
        if button == MouseButton.MIDDLE:
            self._pick_color(csx, csy)

    def update_mouse_position(self, dy, dx):
        self.last_known_paint_pos = (dx, dy)
        if self.enabled():
            self.update_visuals()

    def update_visuals(self):
        if self.last_known_paint_pos is None: return
        self.brush.update_visual(self.last_known_paint_pos[0], self.last_known_paint_pos[1], self.brush_size)

    def clear_visuals(self):
        self.brush.clear_last_visual()

    def _paint_at_pos(self, x, y, color):
        qpainter = QPainter(self.canvas)
        qpainter.setCompositionMode(QPainter.CompositionMode_Source) # composition mode: Source (to prevent alpha multiplication)
        self.brush.draw(qpainter, x, y, color, self.brush_size)
        qpainter.end()
        # duplicate changes into the writeback buffer
        # this makes a copy after each paint call !
        # Notify all interested objects.
        numpy_mask = canvas_to_numpy_mask(self.canvas, self.np_colors)
        for c in self.update_callbacks:
            c(numpy_mask, actor='gui')

    def _pick_color(self, x, y):
        c = self.canvas.pixelColor(x, y)
        idx = None
        for i, ref_c in enumerate(COLORS):
            if c.getRgb() == ref_c:
                idx = i
        if idx is None: raise RuntimeError(f"Color index not found (color has wrong representation of {c.getRgb()})")
        self.change_active_color(idx)

    # Enabled will be queried for event forwarding.
    def enabled(self):
        return self._enabled

    def toggle(self, to):
        self._enabled = to
        if (not to):
            # If enabled = False then draw will no longer be called and visual might stick around.
            self.clear_visuals()
        else:
            # Re-enabled: the brush visual was not drawn while disabled, so redraw it at the last known mouse position.
            self.update_visuals()

    def hide(self):
        self._hide = True

    def show(self):
        self._hide = False

    def mouse_enter(self):
        pass

    def mouse_leave(self):
        self.clear_visuals()

    def mouse_zoom(self, angle):
        modifiers = QApplication.keyboardModifiers()
        if modifiers == Qt.ShiftModifier:
            self.change_active_color(self.color_index + 1 if angle > 0 else self.color_index - 1)
        else:
            self.change_brush_size(angle)

    def mouse_down(self, y, x, btn):
        self._active_buttons[btn] = True
        self.paint_at(btn, y, x)

    def mouse_up(self, y, x, btn):
        self._active_buttons[btn] = None

    def mouse_moved(self, y, x):
        if self.enabled():
            for (btn, active) in self._active_buttons.items():
                if active: self.paint_at(btn, y, x)
        self.update_mouse_position(y, x)

    def key_down(self, key):
        if key == '+':
            self.change_active_color(self.color_index + 1)
        if key == '-':
            self.change_active_color(self.color_index - 1)

--------------------------------------------------------------------------------
/widgets/workspace.py:
--------------------------------------------------------------------------------
#
# Author: David Futschik
# Provided as part of the Chunkmogrify project, 2021.
#

import numpy as np

from PyQt5.QtWidgets import *
from PyQt5.QtGui import *
from PyQt5.QtCore import *

from qtutil import *

class KeepArWidget(QWidget):
    def __init__(self, parent=None):
        super().__init__(parent)
        self.ar = 1. # 16 / 9

    def set_widget(self, widget):
        self.setLayout(QBoxLayout(QBoxLayout.LeftToRight, self))
        self.layout().addItem(QSpacerItem(0, 0))
        self.layout().addWidget(widget)
        self.layout().addItem(QSpacerItem(0, 0))

    def set_ar(self, ar):
        self.ar = ar

    def resizeEvent(self, event):
        w, h = event.size().width(), event.size().height()

        if w / h > self.ar:
            self.layout().setDirection(QBoxLayout.LeftToRight)
            widget_stretch = h * self.ar
            outer_stretch = (w - widget_stretch) / 2 + 0.5
        else:
            self.layout().setDirection(QBoxLayout.TopToBottom)
            widget_stretch = w / self.ar
            outer_stretch = (h - widget_stretch) / 2 + 0.5

        # QBoxLayout.setStretch() takes integer stretch factors; the + 0.5 above makes int() round.
        self.layout().setStretch(0, int(outer_stretch))
        self.layout().setStretch(1, int(widget_stretch))
        self.layout().setStretch(2, int(outer_stretch))


class WorkspaceWidget(QWidget):
    def __init__(self, app, initial_image, forward_widgets=[], parent=None):
        super().__init__(parent)
        self.pixmap = None
        self.resize_hint = 2048
        self.aspect_ratio = 1. # 16. / 9.
        self.resize_to_aspect = True
        self.painting_enabled = True
        self.setLayout(QBoxLayout(QBoxLayout.LeftToRight, self))

        self.empty_pixmap_label = "Start by opening an image file. Go to "
        if initial_image:
            self.content = npy_loader(initial_image)
            self.has_content = True
        else:
            # uint8, matching what set_content() asserts
            self.content = np.zeros((1024, 1024, 4), dtype=np.uint8)
            hbox = QHBoxLayout()
            hbox.addStretch()
            hbox.addWidget(QLabel(self.empty_pixmap_label))
            hbox.addStretch()
            self.layout().addLayout(hbox)
            self.placeholder_content = hbox
            self.has_content = False

        self.overlay = None
        self.overlay_enabled = True
        self.forward_widgets = forward_widgets

        self.inset_border_params = None
        self.scale_factor = 1

        policy = QSizePolicy(QSizePolicy.Preferred, QSizePolicy.Preferred)
        policy.setHeightForWidth(True)
        self.setSizePolicy(policy)

        self._update_image()

    def set_content(self, image: np.ndarray, update=True):
        assert image.dtype == np.uint8, "Workspace widget expects content to be a uint8 numpy array."
        self.content = image
        if not self.has_content: # Did not previously have content and showed placeholder.
            self.layout().removeItem(self.placeholder_content)
            destroy_layout(self.placeholder_content)
            self.has_content = True
        if update: self._update_image()

    def set_overlay(self, overlay: np.ndarray, update=True):
        assert overlay.dtype == np.uint8, "Workspace widget expects overlay to be a uint8 numpy array."
        assert overlay.shape[2] == 4, "Workspace overlay must be a 4-channel (RGBA) image."
        self.overlay = overlay
        if update: self._update_image()

    def _update_image(self):
        if self.overlay is not None and self.overlay_enabled:
            # Alpha as an (H, W, 1) array scaled to [0, 1], so it broadcasts over R, G and B.
            alpha = self.overlay[:, :, 3:4].astype(np.float32) / 255.
            with_overlay = self.overlay[:, :, 0:3] * alpha + (1 - alpha) * self.content[:, :, 0:3]
            image = with_overlay.astype(np.uint8)
        else:
            image = self.content
        self.pixmap = QPixmap(np2qim(image))
        self.update()

    def toggle_overlay(self, val):
        self.overlay_enabled = val
        self._update_image()

    def current_image_as_numpy(self):
        return qim2np(self.pixmap.toImage()).copy()

    def get_current(self, with_overlay):
        npy = self.content[:, :, 0:3]
        if with_overlay and self.overlay is not None:
            alpha = self.overlay[:, :, 3:4].astype(np.float32) / 255.
            npy = (self.overlay[:, :, 0:3] * alpha + (1 - alpha) * self.content[:, :, 0:3]).astype(np.uint8)
        return npy

    def inset_border(self, pxs, color):
        self.inset_border_params = {
            'px': pxs,
            'color': color
        }
        self.update()

    def set_scale_factor(self, factor):
        self.scale_factor = factor
        self.update()

    def paintEvent(self, event):
        pixmap = self.pixmap

        qpainter = QPainter(self)
        qpainter.drawPixmap(event.rect(), pixmap)
        if self.inset_border_params:
            rect = event.rect()
            rect.adjust(0, 0, -1, -1)
            qpainter.setPen(
                QPen(
                    QColor(*self.inset_border_params['color'], 255),
                    self.inset_border_params['px'])
            )
            qpainter.drawRect(rect)

        for w in self.forward_widgets:
            w.draw(qpainter, event)

        if self.resize_to_aspect:
            self.resize_to_aspect = False
            w, h = self.resize_to()
        qpainter.end()

    def is_pos_outside_bounds(self, pos):
        x, y = pos.x(), pos.y()
        mx, my = self.width(), self.height()
        if x < 0 or y < 0: return True
        if x >= mx or y >= my: return True
        return False

    def resize_to(self):
        # figure out what the limiting dimension is
        w, h = self.width(), self.height()
        w = (int(1.0 / self.height_to_width_ratio() * self.height()))
        w = int(w * self.scale_factor)
        h = int(h * self.scale_factor)
        return (w, h)

    def height_to_width_ratio(self):
        if self.pixmap is None:
            return self.aspect_ratio
        return float(self.pixmap.height()) / self.pixmap.width()

    def sizeHint(self):
        base_width = int(self.resize_hint * self.scale_factor)
        return QSize(base_width, self.heightForWidth(base_width))

    def resizeEvent(self, event):
        if int(self.width() * self.height_to_width_ratio()) != self.height():
            self.resize_to_aspect = True
        self.resize_to_aspect = True

    def heightForWidth(self, width):
        return int(self.height_to_width_ratio() * width)

    def enterEvent(self, event):
        self.setMouseTracking(True)
        self.setFocus()
        for w in self.forward_widgets:
            if w.enabled():
                w.mouse_enter()
        self.update()

    def leaveEvent(self, event):
        self.setMouseTracking(False)
        for w in self.forward_widgets:
            if w.enabled():
                w.mouse_leave()
        self.update()

    def mouseMoveEvent(self, event):
        if self.is_pos_outside_bounds(event.localPos()):
            # call this if you don't want to continue drawing when you return into the widget
            # self.stop_painting()
            return
        sx = (event.pos().x() / self.width())
        sy = (event.pos().y() / self.height())

        for w in self.forward_widgets:
            # if w.enabled():
            # Right now, let's propagate mouse events even when disabled. The widget should decide.
            # This is a workaround so that if the forwarded widget gets enabled after being disabled,
            # it will still have a chance to know where the cursor is currently.
            w.mouse_moved(sy, sx)
        self.update()

    def wheelEvent(self, event):
        for w in self.forward_widgets:
            if w.enabled():
                w.mouse_zoom(event.angleDelta().y() * 1/8)
        self.update()

    def mousePressEvent(self, event):
        btn = MouseButton(event.button())

        sx = event.localPos().x() / self.width()
        sy = event.localPos().y() / self.height()

        for w in self.forward_widgets:
            if w.enabled():
                w.mouse_down(sy, sx, btn)
        self.update()

    def mouseReleaseEvent(self, event):
        sx = event.localPos().x() / self.width()
        sy = event.localPos().y() / self.height()
        btn = MouseButton(event.button())
        for w in self.forward_widgets:
            if w.enabled():
                w.mouse_up(sy, sx, btn)
        self.update()

    def keyPressEvent(self, event):
        # This doesn't always work, and some keys are not represented, but overall it's enough
        key = event.text()
        for w in self.forward_widgets:
            if w.enabled():
                w.key_down(key)
        self.update()

--------------------------------------------------------------------------------
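The mask painter in widgets/mask_painter.py stores segments as colors on an RGBA canvas: canvas_to_numpy_mask turns the canvas into one boolean mask per palette color (its commented-out "numpy way" shows the idea), and numpy_mask_to_numpy_canvas goes back by weighting each mask with its 1-based color index, summing, and picking a palette row per pixel. The sketch below reproduces both directions in plain NumPy. It assumes, without having checked extensions/canvas_to_masks.cpp, that the C++ canvas_to_masks performs the same exact per-pixel RGBA match; the helper names canvas_to_masks_numpy and masks_to_canvas are hypothetical, and the decoder uses integer-array indexing into the palette in place of np.choose.

```python
import numpy as np

# Hypothetical helpers (not part of Chunkmogrify) mirroring canvas_to_numpy_mask()
# and numpy_mask_to_numpy_canvas() with plain NumPy.

def canvas_to_masks_numpy(canvas: np.ndarray, palette: np.ndarray) -> np.ndarray:
    """(H, W, 4) uint8 canvas + (K, 4) uint8 palette -> (H, W, K) bool masks.

    Channel k is True where the pixel's RGBA value equals palette[k] exactly.
    """
    # (H, W, 1, 4) == (1, 1, K, 4) broadcasts to (H, W, K, 4); all four channels must match.
    return (canvas[:, :, None, :] == palette[None, None, :, :]).all(axis=-1)


def masks_to_canvas(masks: np.ndarray, palette: np.ndarray) -> np.ndarray:
    """(H, W, N) bool masks + (N + 1, 4) uint8 palette -> (H, W, 4) uint8 canvas."""
    # Weight mask n by its 1-based palette index and sum, giving one label per pixel
    # (0 = background). Overlapping masks can sum past the palette, so clamp the
    # labels the way np.choose(..., mode='clip') does.
    labels = (masks * np.arange(1, masks.shape[2] + 1)[None, None, :]).sum(axis=2)
    labels = np.clip(labels, 0, len(palette) - 1)
    return palette[labels]


palette = np.array([(0, 0, 0, 0), (128, 128, 128, 90), (128, 255, 128, 90)], dtype=np.uint8)
canvas = np.zeros((2, 2, 4), dtype=np.uint8)
canvas[0, 1] = palette[1]
canvas[1, 0] = palette[2]

# Drop the background channel, as canvas_to_numpy_mask does with masks[:, :, 1:].
masks = canvas_to_masks_numpy(canvas, palette)[:, :, 1:]
print(masks[:, :, 0].astype(int).tolist())  # [[0, 1], [0, 0]]
print(bool((masks_to_canvas(masks, palette) == canvas).all()))  # True
```

Because the match is exact on all four channels, any anti-aliased or alpha-multiplied pixel falls out of every mask, which is why the painter draws with CompositionMode_Source and warns about the choice of colors.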
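WorkspaceWidget._update_image and get_current blend an RGBA overlay over the RGB content. That step can be checked in isolation with the hypothetical stand-alone helper composite_overlay below, a minimal sketch assuming 8-bit inputs with alpha in 0..255, as in the COLORS table used by the mask painter.

```python
import numpy as np

def composite_overlay(content: np.ndarray, overlay: np.ndarray) -> np.ndarray:
    """Blend an (H, W, 4) uint8 RGBA overlay over the RGB channels of an (H, W, 3 or 4) uint8 image."""
    # Slice 3:4 keeps a trailing axis of length 1, so alpha broadcasts over R, G and B.
    alpha = overlay[:, :, 3:4].astype(np.float32) / 255.0
    rgb = overlay[:, :, :3] * alpha + content[:, :, :3] * (1.0 - alpha)
    # Round before casting back, so float32 error cannot truncate e.g. 165 down to 164.
    return np.rint(rgb).astype(np.uint8)


content = np.full((1, 2, 3), 255, dtype=np.uint8)  # two white pixels
overlay = np.zeros((1, 2, 4), dtype=np.uint8)
overlay[0, 0] = (0, 0, 0, 90)    # black at the painter's alpha of 90
overlay[0, 1] = (10, 20, 30, 0)  # fully transparent
print(composite_overlay(content, overlay)[0].tolist())  # [[165, 165, 165], [255, 255, 255]]
```

Keeping the alpha channel as a length-1 axis (rather than indexing it away) lets NumPy broadcasting apply the same weight to all three color channels without an explicit loop.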
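The aspect-ratio logic in KeepArWidget.resizeEvent is pure arithmetic, so it can be sketched without Qt. keep_ar_stretches is a hypothetical function returning the layout direction name and the three stretch factors that resizeEvent passes to QBoxLayout.setStretch(); it assumes ar is width / height, as the 16 / 9 comment suggests.

```python
def keep_ar_stretches(w: int, h: int, ar: float):
    """Layout direction plus (outer, widget, outer) stretch factors for a width/height ratio ar."""
    if w / h > ar:
        # Container is wider than the target ratio: the child takes the full height,
        # and the leftover width is split between left and right spacers.
        widget = h * ar
        outer = (w - widget) / 2 + 0.5
        direction = "LeftToRight"
    else:
        # Container is taller than the target ratio: the child takes the full width,
        # and the leftover height is split between top and bottom spacers.
        widget = w / ar
        outer = (h - widget) / 2 + 0.5
        direction = "TopToBottom"
    # Stretch factors are integers; adding 0.5 before int() rounds to the nearest value.
    return direction, (int(outer), int(widget), int(outer))


print(keep_ar_stretches(1920, 1080, 1.0))  # ('LeftToRight', (420, 1080, 420))
print(keep_ar_stretches(800, 1000, 2.0))   # ('TopToBottom', (300, 400, 300))
```

Stretch factors are relative weights, so outer + widget + outer only needs to be proportional to the container size along the padded axis; here it happens to equal it, up to rounding.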