├── .gitignore ├── README.md ├── models ├── __init__.py ├── discriminator.py ├── encoders │ ├── __init__.py │ ├── helpers.py │ ├── model_irse.py │ └── psp_encoders.py ├── latent_codes_pool.py ├── networks_stylegan2.py ├── op │ ├── __init__.py │ ├── conv2d_gradfix.py │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── psp.py ├── stylegan3-main.code-workspace └── utils.py ├── run.sh ├── sketch ├── Thumbs.db ├── checkpoints │ └── stylegan_pretrain │ │ └── avg.pt ├── dataset_release │ └── latent │ │ ├── id_1.pt │ │ ├── id_10.pt │ │ ├── id_11.pt │ │ ├── id_12.pt │ │ ├── id_13.pt │ │ ├── id_14.pt │ │ ├── id_15.pt │ │ ├── id_2.pt │ │ ├── id_3.pt │ │ ├── id_4.pt │ │ ├── id_5.pt │ │ ├── id_6.pt │ │ ├── id_7.pt │ │ ├── id_8.pt │ │ └── id_9.pt ├── dnnlib │ ├── __init__.py │ └── util.py ├── experiments │ └── inference.json ├── generate.py ├── generate.sh ├── stylesketch_utils │ ├── conv.py │ ├── prepare_stylegan.py │ └── stylesketch.py └── torch_utils │ ├── __init__.py │ ├── custom_ops.py │ ├── misc.py │ ├── ops │ ├── __init__.py │ ├── bias_act.cpp │ ├── bias_act.cu │ ├── bias_act.h │ ├── bias_act.py │ ├── conv2d_gradfix.py │ ├── conv2d_resample.py │ ├── filtered_lrelu.cpp │ ├── filtered_lrelu.cu │ ├── filtered_lrelu.h │ ├── filtered_lrelu.py │ ├── filtered_lrelu_ns.cu │ ├── filtered_lrelu_rd.cu │ ├── filtered_lrelu_wr.cu │ ├── fma.py │ ├── grid_sample_gradfix.py │ ├── upfirdn2d.cpp │ ├── upfirdn2d.cu │ ├── upfirdn2d.h │ └── upfirdn2d.py │ ├── persistence.py │ └── training_stats.py └── utils ├── __init__.py ├── alignment.py ├── common.py ├── data_utils.py ├── distributed.py ├── inception_utils.py ├── model_utils.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # READ THIS BEFORE YOU REFACTOR ME 2 | # 3 | # setup.py uses the list of patterns in this file to decide 4 | # what to delete, but it's not 100% sound. So, for example, 5 | # if you delete aten/build/ because it's redundant with build/, 6 | # aten/build/ will stop being cleaned. So be careful when 7 | # refactoring this file! 
8 | 9 | ## PyTorch 10 | 11 | .coverage 12 | coverage.xml 13 | .dmypy.json 14 | .gradle 15 | .hypothesis 16 | .mypy_cache 17 | .additional_ci_files 18 | /.extracted_scripts/ 19 | **/.pytorch_specified_test_cases.csv 20 | **/.pytorch-disabled-tests.json 21 | **/.pytorch-slow-tests.json 22 | */*.pyc 23 | */*.so* 24 | */**/__pycache__ 25 | */**/*.dylib* 26 | */**/*.pyc 27 | */**/*.pyd 28 | */**/*.so* 29 | */**/**/*.pyc 30 | */**/**/**/*.pyc 31 | */**/**/**/**/*.pyc 32 | aten/build/ 33 | aten/src/ATen/Config.h 34 | aten/src/ATen/cuda/CUDAConfig.h 35 | benchmarks/.data 36 | caffe2/cpp_test/ 37 | dist/ 38 | docs/build/ 39 | docs/cpp/src 40 | docs/src/**/* 41 | docs/cpp/build 42 | docs/cpp/source/api 43 | docs/cpp/source/html/ 44 | docs/cpp/source/latex/ 45 | docs/source/compile/generated/ 46 | docs/source/generated/ 47 | docs/source/compile/generated/ 48 | log 49 | usage_log.txt 50 | test-reports/ 51 | test/*.bak 52 | test/**/*.bak 53 | test/.coverage 54 | test/.hypothesis/ 55 | test/cpp/api/mnist 56 | test/custom_operator/model.pt 57 | test/jit_hooks/*.pt 58 | test/data/legacy_modules.t7 59 | test/data/*.pt 60 | test/forward_backward_compatibility/nightly_schemas.txt 61 | dropout_model.pt 62 | test/generated_type_hints_smoketest.py 63 | test/htmlcov 64 | test/cpp_extensions/install/ 65 | third_party/build/ 66 | tools/coverage_plugins_package/pip-wheel-metadata/ 67 | tools/shared/_utils_internal.py 68 | tools/fast_nvcc/wrap_nvcc.sh 69 | tools/fast_nvcc/wrap_nvcc.bat 70 | tools/fast_nvcc/tmp/ 71 | torch.egg-info/ 72 | torch/_C/__init__.pyi 73 | torch/_C/_nn.pyi 74 | torch/_C/_VariableFunctions.pyi 75 | torch/_VF.pyi 76 | torch/return_types.pyi 77 | torch/nn/functional.pyi 78 | torch/utils/data/datapipes/datapipe.pyi 79 | torch/csrc/autograd/generated/* 80 | torch/csrc/lazy/generated/*.[!m]* 81 | torch_compile_debug/ 82 | # Listed manually because some files in this directory are not generated 83 | torch/testing/_internal/generated/annotated_fn_args.py 84 | torch/testing/_internal/data/*.pt 85 | torch/csrc/api/include/torch/version.h 86 | torch/csrc/cudnn/cuDNN.cpp 87 | torch/csrc/generated 88 | torch/csrc/generic/TensorMethods.cpp 89 | torch/csrc/jit/generated/* 90 | torch/csrc/jit/fuser/config.h 91 | torch/csrc/nn/THCUNN.cpp 92 | torch/csrc/nn/THCUNN.cwrap 93 | torch/bin/ 94 | torch/cmake/ 95 | torch/lib/*.a* 96 | torch/lib/*.dll* 97 | torch/lib/*.exe* 98 | torch/lib/*.dylib* 99 | torch/lib/*.h 100 | torch/lib/*.lib 101 | torch/lib/*.pdb 102 | torch/lib/*.so* 103 | torch/lib/protobuf*.pc 104 | torch/lib/build 105 | torch/lib/caffe2/ 106 | torch/lib/cmake 107 | torch/lib/include 108 | torch/lib/pkgconfig 109 | torch/lib/protoc 110 | torch/lib/protobuf/ 111 | torch/lib/tmp_install 112 | torch/lib/torch_shm_manager 113 | torch/lib/site-packages/ 114 | torch/lib/python* 115 | torch/lib64 116 | torch/include/ 117 | torch/share/ 118 | torch/test/ 119 | torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h 120 | torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h 121 | torch/version.py 122 | minifier_launcher.py 123 | # Root level file used in CI to specify certain env configs. 
124 | # E.g., see .circleci/config.yaml 125 | env 126 | .circleci/scripts/COMMIT_MSG 127 | scripts/release_notes/*.json 128 | sccache-stats*.json 129 | lint.json 130 | 131 | # These files get copied over on invoking setup.py 132 | torchgen/packaged/* 133 | !torchgen/packaged/README.md 134 | 135 | #pretrained 136 | sketch/checkpoints/stylegan_pretrain/*.pt 137 | sketch/model_dir/*.pth 138 | 139 | # IPython notebook checkpoints 140 | .ipynb_checkpoints 141 | 142 | # Editor temporaries 143 | *.swa 144 | *.swb 145 | *.swc 146 | *.swd 147 | *.swe 148 | *.swf 149 | *.swg 150 | *.swh 151 | *.swi 152 | *.swj 153 | *.swk 154 | *.swl 155 | *.swm 156 | *.swn 157 | *.swo 158 | *.swp 159 | *~ 160 | .~lock.* 161 | 162 | # macOS dir files 163 | .DS_Store 164 | 165 | # Ninja files 166 | .ninja_deps 167 | .ninja_log 168 | compile_commands.json 169 | *.egg-info/ 170 | docs/source/scripts/activation_images/ 171 | docs/source/scripts/quantization_backend_configs/ 172 | 173 | ## General 174 | 175 | # Compiled Object files 176 | *.slo 177 | *.lo 178 | *.o 179 | *.cuo 180 | *.obj 181 | 182 | # Compiled Dynamic libraries 183 | *.so 184 | *.dylib 185 | *.dll 186 | 187 | # Compiled Static libraries 188 | *.lai 189 | *.la 190 | *.a 191 | *.lib 192 | 193 | # Compiled protocol buffers 194 | *.pb.h 195 | *.pb.cc 196 | *_pb2.py 197 | 198 | # Compiled python 199 | *.pyc 200 | *.pyd 201 | 202 | # Compiled MATLAB 203 | *.mex* 204 | 205 | # IPython notebook checkpoints 206 | .ipynb_checkpoints 207 | 208 | # Editor temporaries 209 | *.swn 210 | *.swo 211 | *.swp 212 | *~ 213 | 214 | # NFS handle files 215 | **/.nfs* 216 | 217 | # Sublime Text settings 218 | *.sublime-workspace 219 | *.sublime-project 220 | 221 | # Eclipse Project settings 222 | *.*project 223 | .settings 224 | 225 | # QtCreator files 226 | *.user 227 | 228 | # PyCharm files 229 | .idea 230 | 231 | # GDB history 232 | .gdb_history 233 | 234 | ## Caffe2 235 | 236 | # build, distribute, and bins (+ python proto bindings) 237 | build/ 238 | # Allow tools/build/ for build support. 239 | !tools/build/ 240 | build_host_protoc 241 | build_android 242 | build_ios 243 | .build_debug/* 244 | .build_release/* 245 | .build_profile/* 246 | distribute/* 247 | *.testbin 248 | *.bin 249 | cmake_build 250 | .cmake_build 251 | gen 252 | .setuptools-cmake-build 253 | .pytest_cache 254 | aten/build/* 255 | 256 | # Bram 257 | plsdontbreak 258 | 259 | # Generated documentation 260 | docs/_site 261 | docs/gathered 262 | _site 263 | doxygen 264 | docs/dev 265 | 266 | # LevelDB files 267 | *.sst 268 | *.ldb 269 | LOCK 270 | CURRENT 271 | MANIFEST-* 272 | 273 | # generated version file 274 | caffe2/version.py 275 | 276 | # setup.py intermediates 277 | .eggs 278 | caffe2.egg-info 279 | MANIFEST 280 | 281 | # Atom/Watchman required file 282 | .watchmanconfig 283 | .watchman 284 | 285 | # Files generated by CLion 286 | cmake-build-debug 287 | 288 | # BEGIN NOT-CLEAN-FILES (setup.py handles this marker. Do not change.) 289 | # 290 | # Below files are not deleted by "setup.py clean". 
291 | 292 | # Downloaded bazel 293 | tools/bazel 294 | 295 | # Visual Studio Code files 296 | .vs 297 | /.vscode/* 298 | !/.vscode/extensions.json 299 | !/.vscode/settings_recommended.json 300 | 301 | # YouCompleteMe config file 302 | .ycm_extra_conf.py 303 | 304 | # Files generated when a patch is rejected 305 | *.orig 306 | *.rej 307 | 308 | # Files generated by ctags 309 | CTAGS 310 | GTAGS 311 | GRTAGS 312 | GSYMS 313 | GPATH 314 | tags 315 | TAGS 316 | 317 | 318 | # ccls file 319 | .ccls-cache/ 320 | 321 | # clang tooling storage location 322 | .clang-format-bin 323 | .clang-tidy-bin 324 | .lintbin 325 | 326 | # clangd background index 327 | .clangd/ 328 | .cache/ 329 | 330 | # bazel symlinks 331 | bazel-* 332 | 333 | # xla repo 334 | xla/ 335 | 336 | # direnv, posh-direnv 337 | .env 338 | .envrc 339 | .psenvrc 340 | 341 | # generated shellcheck directories 342 | .shellcheck_generated*/ 343 | 344 | # zip archives 345 | *.zip 346 | 347 | # core dump files 348 | **/core.[1-9]* 349 | 350 | # Generated if you use the pre-commit script for clang-tidy 351 | pr.diff 352 | 353 | # coverage files 354 | */**/.coverage.* 355 | 356 | # buck generated files 357 | .buckd/ 358 | .lsp-buck-out/ 359 | .lsp.buckd/ 360 | buck-out/ 361 | 362 | # Downloaded libraries 363 | third_party/ruy/ 364 | third_party/glog/ 365 | 366 | # Virtualenv 367 | venv/ 368 | 369 | # Log files 370 | *.log 371 | sweep/ 372 | 373 | # Android build artifacts 374 | android/pytorch_android/.cxx 375 | android/pytorch_android_torchvision/.cxx 376 | 377 | # Pyre configs (for internal usage) 378 | .pyre_configuration 379 | .pyre_configuration.codenav 380 | .arcconfig 381 | .stable_pyre_client 382 | .pyre_client -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Stylized Face Sketch Extraction via Generative Prior with Limited Data (EUROGRAPHICS 2024) 2 | 3 | ### [EG2024] Official repository of StyleSketch [[SKSF-A](https://github.com/kwanyun/SKSF-A)] [[Project Page](https://kwanyun.github.io/stylesketch_project/)] [[Paper](https://arxiv.org/abs/2403.11263)] 4 | ![teaser2](https://github.com/kwanyun/StyleSketch/assets/68629563/e5368677-fbd4-4942-9385-ed7cc14de603) 5 | 6 | ### Getting Started 7 | * install dependency 8 | ```bash 9 | bash run.sh 10 | ``` 11 | * Put styleGAN related checkpoints folder in stylesketch/sketch folder 12 | ex) stylesketch/sketch/checkpoints/stylegan_pretrain 13 | 14 | https://drive.google.com/file/d/1X--a491Q6reEBV50XfyYqQ86yDxI44nd/view?usp=drive_link 15 | 16 | 17 | * Put pretrained StyleSketch weights in model_dir 18 | ex) stylesketch/sketch/model_dir 19 | https://drive.google.com/file/d/17AgaRzSwXi3c5tmTZztrGGifyHGKrrQu/view?usp=drive_link 20 | 21 | 22 | ### How to inference Scripts 23 | Move to sketch folder and run generate.py with the style to extract 24 | ```bash 25 | cd sketch 26 | python generate.py --train_data sketch_MJ 27 | python generate.py --train_data pencil_sj 28 | ``` 29 | ### How to make the w^+ code to extract sketches? 30 | In our experiment, we used [e4e](https://github.com/omertov/encoder4editing) followed by optimization. This can be replaced by different inversion methods. 31 | 32 | 33 | ### SKSF-A Sketch Data 34 | SKSF-A consists of seven distinct styles drawn by professional artists, each containing 134 identities and corresponding sketches. 
35 | 36 | ### [SKSF-A](https://github.com/kwanyun/SKSF-A) 37 | 38 | ### Acknowledgments 39 | our codes were borrowed from [DatasetGAN](https://github.com/nv-tlabs/datasetGAN_release) 40 | 41 | ### If you use this code or SKSF-A for your research, please cite our paper: 42 | ```bash 43 | @inproceedings{yun2024stylized, 44 | title={Stylized Face Sketch Extraction via Generative Prior with Limited Data}, 45 | author={Yun, Kwan and Seo, Kwanggyoon and Seo, Chang Wook and Yoon, Soyeon and Kim, Seongcheol and Ji, Soohyun and Ashtari, Amirsaman and Noh, Junyong}, 46 | booktitle={Computer Graphics Forum}, 47 | pages={e15045}, 48 | year={2024}, 49 | organization={Wiley Online Library} 50 | } 51 | ``` 52 | 53 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwanyun/StyleSketch/0214b46d51a7cd6e50674b4e4461adb3f720aedc/models/__init__.py -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class LatentCodesDiscriminator(nn.Module): 5 | def __init__(self, style_dim, n_mlp): 6 | super().__init__() 7 | 8 | self.style_dim = style_dim 9 | 10 | layers = [] 11 | for i in range(n_mlp-1): 12 | layers.append( 13 | nn.Linear(style_dim, style_dim) 14 | ) 15 | layers.append(nn.LeakyReLU(0.2)) 16 | layers.append(nn.Linear(512, 1)) 17 | self.mlp = nn.Sequential(*layers) 18 | 19 | def forward(self, w): 20 | return self.mlp(w) 21 | -------------------------------------------------------------------------------- /models/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwanyun/StyleSketch/0214b46d51a7cd6e50674b4e4461adb3f720aedc/models/encoders/__init__.py -------------------------------------------------------------------------------- /models/encoders/helpers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 5 | 6 | """ 7 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 8 | """ 9 | 10 | 11 | class Flatten(Module): 12 | def forward(self, input): 13 | return input.view(input.size(0), -1) 14 | 15 | 16 | def l2_norm(input, axis=1): 17 | norm = torch.norm(input, 2, axis, True) 18 | output = torch.div(input, norm) 19 | return output 20 | 21 | 22 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 23 | """ A named tuple describing a ResNet block. 
""" 24 | 25 | 26 | def get_block(in_channel, depth, num_units, stride=2): 27 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 28 | 29 | 30 | def get_blocks(num_layers): 31 | if num_layers == 50: 32 | blocks = [ 33 | get_block(in_channel=64, depth=64, num_units=3), 34 | get_block(in_channel=64, depth=128, num_units=4), 35 | get_block(in_channel=128, depth=256, num_units=14), 36 | get_block(in_channel=256, depth=512, num_units=3) 37 | ] 38 | elif num_layers == 100: 39 | blocks = [ 40 | get_block(in_channel=64, depth=64, num_units=3), 41 | get_block(in_channel=64, depth=128, num_units=13), 42 | get_block(in_channel=128, depth=256, num_units=30), 43 | get_block(in_channel=256, depth=512, num_units=3) 44 | ] 45 | elif num_layers == 152: 46 | blocks = [ 47 | get_block(in_channel=64, depth=64, num_units=3), 48 | get_block(in_channel=64, depth=128, num_units=8), 49 | get_block(in_channel=128, depth=256, num_units=36), 50 | get_block(in_channel=256, depth=512, num_units=3) 51 | ] 52 | else: 53 | raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) 54 | return blocks 55 | 56 | 57 | class SEModule(Module): 58 | def __init__(self, channels, reduction): 59 | super(SEModule, self).__init__() 60 | self.avg_pool = AdaptiveAvgPool2d(1) 61 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 62 | self.relu = ReLU(inplace=True) 63 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 64 | self.sigmoid = Sigmoid() 65 | 66 | def forward(self, x): 67 | module_input = x 68 | x = self.avg_pool(x) 69 | x = self.fc1(x) 70 | x = self.relu(x) 71 | x = self.fc2(x) 72 | x = self.sigmoid(x) 73 | return module_input * x 74 | 75 | 76 | class bottleneck_IR(Module): 77 | def __init__(self, in_channel, depth, stride): 78 | super(bottleneck_IR, self).__init__() 79 | if in_channel == depth: 80 | self.shortcut_layer = MaxPool2d(1, stride) 81 | else: 82 | self.shortcut_layer = Sequential( 83 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 84 | BatchNorm2d(depth) 85 | ) 86 | self.res_layer = Sequential( 87 | BatchNorm2d(in_channel), 88 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 89 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 90 | ) 91 | 92 | def forward(self, x): 93 | shortcut = self.shortcut_layer(x) 94 | res = self.res_layer(x) 95 | return res + shortcut 96 | 97 | 98 | class bottleneck_IR_SE(Module): 99 | def __init__(self, in_channel, depth, stride): 100 | super(bottleneck_IR_SE, self).__init__() 101 | if in_channel == depth: 102 | self.shortcut_layer = MaxPool2d(1, stride) 103 | else: 104 | self.shortcut_layer = Sequential( 105 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 106 | BatchNorm2d(depth) 107 | ) 108 | self.res_layer = Sequential( 109 | BatchNorm2d(in_channel), 110 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 111 | PReLU(depth), 112 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 113 | BatchNorm2d(depth), 114 | SEModule(depth, 16) 115 | ) 116 | 117 | def forward(self, x): 118 | shortcut = self.shortcut_layer(x) 119 | res = self.res_layer(x) 120 | return res + shortcut 121 | 122 | 123 | def _upsample_add(x, y): 124 | """Upsample and add two feature maps. 125 | Args: 126 | x: (Variable) top feature map to be upsampled. 127 | y: (Variable) lateral feature map. 128 | Returns: 129 | (Variable) added feature map. 
130 | Note in PyTorch, when input size is odd, the upsampled feature map 131 | with `F.upsample(..., scale_factor=2, mode='nearest')` 132 | maybe not equal to the lateral feature map size. 133 | e.g. 134 | original input size: [N,_,15,15] -> 135 | conv2d feature map size: [N,_,8,8] -> 136 | upsampled feature map size: [N,_,16,16] 137 | So we choose bilinear upsample which supports arbitrary output sizes. 138 | """ 139 | _, _, H, W = y.size() 140 | return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y 141 | -------------------------------------------------------------------------------- /models/encoders/model_irse.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 2 | from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 3 | 4 | """ 5 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 6 | """ 7 | 8 | 9 | class Backbone(Module): 10 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 11 | super(Backbone, self).__init__() 12 | assert input_size in [112, 224], "input_size should be 112 or 224" 13 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 14 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 15 | blocks = get_blocks(num_layers) 16 | if mode == 'ir': 17 | unit_module = bottleneck_IR 18 | elif mode == 'ir_se': 19 | unit_module = bottleneck_IR_SE 20 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 21 | BatchNorm2d(64), 22 | PReLU(64)) 23 | if input_size == 112: 24 | self.output_layer = Sequential(BatchNorm2d(512), 25 | Dropout(drop_ratio), 26 | Flatten(), 27 | Linear(512 * 7 * 7, 512), 28 | BatchNorm1d(512, affine=affine)) 29 | else: 30 | self.output_layer = Sequential(BatchNorm2d(512), 31 | Dropout(drop_ratio), 32 | Flatten(), 33 | Linear(512 * 14 * 14, 512), 34 | BatchNorm1d(512, affine=affine)) 35 | 36 | modules = [] 37 | for block in blocks: 38 | for bottleneck in block: 39 | modules.append(unit_module(bottleneck.in_channel, 40 | bottleneck.depth, 41 | bottleneck.stride)) 42 | self.body = Sequential(*modules) 43 | 44 | def forward(self, x): 45 | x = self.input_layer(x) 46 | x = self.body(x) 47 | x = self.output_layer(x) 48 | return l2_norm(x) 49 | 50 | 51 | def IR_50(input_size): 52 | """Constructs a ir-50 model.""" 53 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 54 | return model 55 | 56 | 57 | def IR_101(input_size): 58 | """Constructs a ir-101 model.""" 59 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 60 | return model 61 | 62 | 63 | def IR_152(input_size): 64 | """Constructs a ir-152 model.""" 65 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 66 | return model 67 | 68 | 69 | def IR_SE_50(input_size): 70 | """Constructs a ir_se-50 model.""" 71 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 72 | return model 73 | 74 | 75 | def IR_SE_101(input_size): 76 | """Constructs a ir_se-101 model.""" 77 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 78 | return model 79 | 80 | 81 | def IR_SE_152(input_size): 82 | """Constructs a ir_se-152 model.""" 83 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 84 | 
return model 85 | -------------------------------------------------------------------------------- /models/encoders/psp_encoders.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module 7 | 8 | from models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add 9 | 10 | 11 | 12 | 13 | 14 | 15 | class ProgressiveStage(Enum): 16 | WTraining = 0 17 | Delta1Training = 1 18 | Delta2Training = 2 19 | Delta3Training = 3 20 | Delta4Training = 4 21 | Delta5Training = 5 22 | Delta6Training = 6 23 | Delta7Training = 7 24 | Delta8Training = 8 25 | Delta9Training = 9 26 | Delta10Training = 10 27 | Delta11Training = 11 28 | Delta12Training = 12 29 | Delta13Training = 13 30 | Delta14Training = 14 31 | Delta15Training = 15 32 | Delta16Training = 16 33 | Delta17Training = 17 34 | Inference = 18 35 | 36 | 37 | class EqualLinears(nn.Module): 38 | def __init__( 39 | self, in_dim, out_dim, bias=False, bias_init=0, lr_mul=1, activation=None 40 | ): 41 | super().__init__() 42 | 43 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 44 | 45 | if bias: 46 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 47 | 48 | else: 49 | self.bias = None 50 | 51 | self.activation = activation 52 | 53 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 54 | self.lr_mul = lr_mul 55 | 56 | def forward(self, input): 57 | if self.activation: 58 | out = F.linear(input, self.weight * self.scale) 59 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 60 | 61 | else: 62 | out = F.linear( 63 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 64 | ) 65 | 66 | 67 | return out 68 | 69 | def __repr__(self): 70 | return ( 71 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' 72 | ) 73 | 74 | class GradualStyleBlock(Module): 75 | def __init__(self, in_c, out_c, spatial): 76 | super(GradualStyleBlock, self).__init__() 77 | self.out_c = out_c 78 | self.spatial = spatial 79 | num_pools = int(np.log2(spatial)) 80 | modules = [] 81 | modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), 82 | nn.LeakyReLU()] 83 | for i in range(num_pools - 1): 84 | modules += [ 85 | Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), 86 | nn.LeakyReLU() 87 | ] 88 | self.convs = nn.Sequential(*modules) 89 | self.linear = EqualLinears(out_c, out_c, bias=True,lr_mul=1) 90 | 91 | def forward(self, x): 92 | x = self.convs(x) 93 | x = x.view(-1, self.out_c) 94 | x = self.linear(x) 95 | return x 96 | 97 | 98 | class GradualStyleEncoder(Module): 99 | def __init__(self, num_layers, mode='ir', opts=None): 100 | super(GradualStyleEncoder, self).__init__() 101 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 102 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 103 | blocks = get_blocks(num_layers) 104 | if mode == 'ir': 105 | unit_module = bottleneck_IR 106 | elif mode == 'ir_se': 107 | unit_module = bottleneck_IR_SE 108 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 109 | BatchNorm2d(64), 110 | PReLU(64)) 111 | modules = [] 112 | for block in blocks: 113 | for bottleneck in block: 114 | modules.append(unit_module(bottleneck.in_channel, 115 | bottleneck.depth, 116 | bottleneck.stride)) 117 | self.body = Sequential(*modules) 118 | 119 | self.styles = nn.ModuleList() 120 | log_size = 
int(math.log(opts.stylegan_size, 2)) 121 | self.style_count = 2 * log_size - 2 122 | self.coarse_ind = 3 123 | self.middle_ind = 7 124 | for i in range(self.style_count): 125 | if i < self.coarse_ind: 126 | style = GradualStyleBlock(512, 512, 16) 127 | elif i < self.middle_ind: 128 | style = GradualStyleBlock(512, 512, 32) 129 | else: 130 | style = GradualStyleBlock(512, 512, 64) 131 | self.styles.append(style) 132 | self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) 133 | self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) 134 | 135 | def forward(self, x): 136 | x = self.input_layer(x) 137 | 138 | latents = [] 139 | modulelist = list(self.body._modules.values()) 140 | for i, l in enumerate(modulelist): 141 | x = l(x) 142 | if i == 6: 143 | c1 = x 144 | elif i == 20: 145 | c2 = x 146 | elif i == 23: 147 | c3 = x 148 | 149 | for j in range(self.coarse_ind): 150 | latents.append(self.styles[j](c3)) 151 | 152 | p2 = _upsample_add(c3, self.latlayer1(c2)) 153 | for j in range(self.coarse_ind, self.middle_ind): 154 | latents.append(self.styles[j](p2)) 155 | 156 | p1 = _upsample_add(p2, self.latlayer2(c1)) 157 | for j in range(self.middle_ind, self.style_count): 158 | latents.append(self.styles[j](p1)) 159 | 160 | out = torch.stack(latents, dim=1) 161 | return out 162 | 163 | 164 | class Encoder4Editing(Module): 165 | def __init__(self, num_layers, mode='ir', opts=None): 166 | super(Encoder4Editing, self).__init__() 167 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 168 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 169 | blocks = get_blocks(num_layers) 170 | if mode == 'ir': 171 | unit_module = bottleneck_IR 172 | elif mode == 'ir_se': 173 | unit_module = bottleneck_IR_SE 174 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 175 | BatchNorm2d(64), 176 | PReLU(64)) 177 | modules = [] 178 | for block in blocks: 179 | for bottleneck in block: 180 | modules.append(unit_module(bottleneck.in_channel, 181 | bottleneck.depth, 182 | bottleneck.stride)) 183 | self.body = Sequential(*modules) 184 | 185 | self.styles = nn.ModuleList() 186 | log_size = int(math.log(opts.stylegan_size, 2)) 187 | self.style_count = 2 * log_size - 2 188 | self.coarse_ind = 3 189 | self.middle_ind = 7 190 | 191 | for i in range(self.style_count): 192 | if i < self.coarse_ind: 193 | style = GradualStyleBlock(512, 512, 16) 194 | elif i < self.middle_ind: 195 | style = GradualStyleBlock(512, 512, 32) 196 | else: 197 | style = GradualStyleBlock(512, 512, 64) 198 | self.styles.append(style) 199 | 200 | self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) 201 | self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) 202 | 203 | self.progressive_stage = ProgressiveStage.Inference 204 | 205 | def get_deltas_starting_dimensions(self): 206 | ''' Get a list of the initial dimension of every delta from which it is applied ''' 207 | return list(range(self.style_count)) # Each dimension has a delta applied to it 208 | 209 | def set_progressive_stage(self, new_stage: ProgressiveStage): 210 | self.progressive_stage = new_stage 211 | print('Changed progressive stage to: ', new_stage) 212 | 213 | def forward(self, x): 214 | x = self.input_layer(x) 215 | 216 | modulelist = list(self.body._modules.values()) 217 | for i, l in enumerate(modulelist): 218 | x = l(x) 219 | if i == 6: 220 | c1 = x 221 | elif i == 20: 222 | c2 = x 223 | elif i == 23: 224 | c3 = x 225 | 226 | # Infer main W and duplicate it 
227 | w0 = self.styles[0](c3) 228 | w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2) 229 | stage = self.progressive_stage.value 230 | features = c3 231 | for i in range(1, min(stage + 1, self.style_count)): # Infer additional deltas 232 | if i == self.coarse_ind: 233 | p2 = _upsample_add(c3, self.latlayer1(c2)) # FPN's middle features 234 | features = p2 235 | elif i == self.middle_ind: 236 | p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features 237 | features = p1 238 | delta_i = self.styles[i](features) 239 | w[:, i] += delta_i 240 | return w 241 | 242 | 243 | class BackboneEncoderUsingLastLayerIntoW(Module): 244 | def __init__(self, num_layers, mode='ir', opts=None): 245 | super(BackboneEncoderUsingLastLayerIntoW, self).__init__() 246 | print('Using BackboneEncoderUsingLastLayerIntoW') 247 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 248 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 249 | blocks = get_blocks(num_layers) 250 | if mode == 'ir': 251 | unit_module = bottleneck_IR 252 | elif mode == 'ir_se': 253 | unit_module = bottleneck_IR_SE 254 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 255 | BatchNorm2d(64), 256 | PReLU(64)) 257 | self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1)) 258 | self.linear = EqualLinears(512, 512, bias=True, lr_mul=1) 259 | modules = [] 260 | for block in blocks: 261 | for bottleneck in block: 262 | modules.append(unit_module(bottleneck.in_channel, 263 | bottleneck.depth, 264 | bottleneck.stride)) 265 | self.body = Sequential(*modules) 266 | log_size = int(math.log(opts.stylegan_size, 2)) 267 | self.style_count = 2 * log_size - 2 268 | 269 | def forward(self, x): 270 | x = self.input_layer(x) 271 | x = self.body(x) 272 | x = self.output_pool(x) 273 | x = x.view(-1, 512) 274 | x = self.linear(x) 275 | return x.repeat(self.style_count, 1, 1).permute(1, 0, 2) -------------------------------------------------------------------------------- /models/latent_codes_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class LatentCodesPool: 6 | """This class implements latent codes buffer that stores previously generated w latent codes. 7 | This buffer enables us to update discriminators using a history of generated w's 8 | rather than the ones produced by the latest encoder. 9 | """ 10 | 11 | def __init__(self, pool_size): 12 | """Initialize the ImagePool class 13 | Parameters: 14 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created 15 | """ 16 | self.pool_size = pool_size 17 | if self.pool_size > 0: # create an empty pool 18 | self.num_ws = 0 19 | self.ws = [] 20 | 21 | def query(self, ws): 22 | """Return w's from the pool. 23 | Parameters: 24 | ws: the latest generated w's from the generator 25 | Returns w's from the buffer. 26 | By 50/100, the buffer will return input w's. 27 | By 50/100, the buffer will return w's previously stored in the buffer, 28 | and insert the current w's to the buffer. 
29 | """ 30 | if self.pool_size == 0: # if the buffer size is 0, do nothing 31 | return ws 32 | return_ws = [] 33 | for w in ws: # ws.shape: (batch, 512) or (batch, n_latent, 512) 34 | # w = torch.unsqueeze(image.data, 0) 35 | if w.ndim == 2: 36 | i = random.randint(0, len(w) - 1) # apply a random latent index as a candidate 37 | w = w[i] 38 | self.handle_w(w, return_ws) 39 | return_ws = torch.stack(return_ws, 0) # collect all the images and return 40 | return return_ws 41 | 42 | def handle_w(self, w, return_ws): 43 | if self.num_ws < self.pool_size: # if the buffer is not full; keep inserting current codes to the buffer 44 | self.num_ws = self.num_ws + 1 45 | self.ws.append(w) 46 | return_ws.append(w) 47 | else: 48 | p = random.uniform(0, 1) 49 | if p > 0.5: # by 50% chance, the buffer will return a previously stored latent code, and insert the current code into the buffer 50 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 51 | tmp = self.ws[random_id].clone() 52 | self.ws[random_id] = w 53 | return_ws.append(tmp) 54 | else: # by another 50% chance, the buffer will return the current image 55 | return_ws.append(w) 56 | -------------------------------------------------------------------------------- /models/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /models/op/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import warnings 3 | 4 | import torch 5 | from torch import autograd 6 | from torch.nn import functional as F 7 | 8 | enabled = True 9 | weight_gradients_disabled = False 10 | 11 | 12 | @contextlib.contextmanager 13 | def no_weight_gradients(): 14 | global weight_gradients_disabled 15 | 16 | old = weight_gradients_disabled 17 | weight_gradients_disabled = True 18 | yield 19 | weight_gradients_disabled = old 20 | 21 | 22 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 23 | if could_use_op(input): 24 | return conv2d_gradfix( 25 | transpose=False, 26 | weight_shape=weight.shape, 27 | stride=stride, 28 | padding=padding, 29 | output_padding=0, 30 | dilation=dilation, 31 | groups=groups, 32 | ).apply(input, weight, bias) 33 | 34 | return F.conv2d( 35 | input=input, 36 | weight=weight, 37 | bias=bias, 38 | stride=stride, 39 | padding=padding, 40 | dilation=dilation, 41 | groups=groups, 42 | ) 43 | 44 | 45 | def conv_transpose2d( 46 | input, 47 | weight, 48 | bias=None, 49 | stride=1, 50 | padding=0, 51 | output_padding=0, 52 | groups=1, 53 | dilation=1, 54 | ): 55 | if could_use_op(input): 56 | return conv2d_gradfix( 57 | transpose=True, 58 | weight_shape=weight.shape, 59 | stride=stride, 60 | padding=padding, 61 | output_padding=output_padding, 62 | groups=groups, 63 | dilation=dilation, 64 | ).apply(input, weight, bias) 65 | 66 | return F.conv_transpose2d( 67 | input=input, 68 | weight=weight, 69 | bias=bias, 70 | stride=stride, 71 | padding=padding, 72 | output_padding=output_padding, 73 | dilation=dilation, 74 | groups=groups, 75 | ) 76 | 77 | 78 | def could_use_op(input): 79 | if (not enabled) or (not torch.backends.cudnn.enabled): 80 | return False 81 | 82 | if input.device.type != "cuda": 83 | return False 84 | 85 | if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]): 86 | return True 87 | 88 | 
warnings.warn( 89 | f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()." 90 | ) 91 | 92 | return False 93 | 94 | 95 | def ensure_tuple(xs, ndim): 96 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 97 | 98 | return xs 99 | 100 | 101 | conv2d_gradfix_cache = dict() 102 | 103 | 104 | def conv2d_gradfix( 105 | transpose, weight_shape, stride, padding, output_padding, dilation, groups 106 | ): 107 | ndim = 2 108 | weight_shape = tuple(weight_shape) 109 | stride = ensure_tuple(stride, ndim) 110 | padding = ensure_tuple(padding, ndim) 111 | output_padding = ensure_tuple(output_padding, ndim) 112 | dilation = ensure_tuple(dilation, ndim) 113 | 114 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 115 | if key in conv2d_gradfix_cache: 116 | return conv2d_gradfix_cache[key] 117 | 118 | common_kwargs = dict( 119 | stride=stride, padding=padding, dilation=dilation, groups=groups 120 | ) 121 | 122 | def calc_output_padding(input_shape, output_shape): 123 | if transpose: 124 | return [0, 0] 125 | 126 | return [ 127 | input_shape[i + 2] 128 | - (output_shape[i + 2] - 1) * stride[i] 129 | - (1 - 2 * padding[i]) 130 | - dilation[i] * (weight_shape[i + 2] - 1) 131 | for i in range(ndim) 132 | ] 133 | 134 | class Conv2d(autograd.Function): 135 | @staticmethod 136 | def forward(ctx, input, weight, bias): 137 | if not transpose: 138 | out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 139 | 140 | else: 141 | out = F.conv_transpose2d( 142 | input=input, 143 | weight=weight, 144 | bias=bias, 145 | output_padding=output_padding, 146 | **common_kwargs, 147 | ) 148 | 149 | ctx.save_for_backward(input, weight) 150 | 151 | return out 152 | 153 | @staticmethod 154 | def backward(ctx, grad_output): 155 | input, weight = ctx.saved_tensors 156 | grad_input, grad_weight, grad_bias = None, None, None 157 | 158 | if ctx.needs_input_grad[0]: 159 | p = calc_output_padding( 160 | input_shape=input.shape, output_shape=grad_output.shape 161 | ) 162 | grad_input = conv2d_gradfix( 163 | transpose=(not transpose), 164 | weight_shape=weight_shape, 165 | output_padding=p, 166 | **common_kwargs, 167 | ).apply(grad_output, weight, None) 168 | 169 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 170 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 171 | 172 | if ctx.needs_input_grad[2]: 173 | grad_bias = grad_output.sum((0, 2, 3)) 174 | 175 | return grad_input, grad_weight, grad_bias 176 | 177 | class Conv2dGradWeight(autograd.Function): 178 | @staticmethod 179 | def forward(ctx, grad_output, input): 180 | op = torch._C._jit_get_operation( 181 | "aten::cudnn_convolution_backward_weight" 182 | if not transpose 183 | else "aten::cudnn_convolution_transpose_backward_weight" 184 | ) 185 | flags = [ 186 | torch.backends.cudnn.benchmark, 187 | torch.backends.cudnn.deterministic, 188 | torch.backends.cudnn.allow_tf32, 189 | ] 190 | grad_weight = op( 191 | weight_shape, 192 | grad_output, 193 | input, 194 | padding, 195 | stride, 196 | dilation, 197 | groups, 198 | *flags, 199 | ) 200 | ctx.save_for_backward(grad_output, input) 201 | 202 | return grad_weight 203 | 204 | @staticmethod 205 | def backward(ctx, grad_grad_weight): 206 | grad_output, input = ctx.saved_tensors 207 | grad_grad_output, grad_grad_input = None, None 208 | 209 | if ctx.needs_input_grad[0]: 210 | grad_grad_output = Conv2d.apply(input, grad_grad_weight, None) 211 | 212 | if ctx.needs_input_grad[1]: 213 | p = 
calc_output_padding( 214 | input_shape=input.shape, output_shape=grad_output.shape 215 | ) 216 | grad_grad_input = conv2d_gradfix( 217 | transpose=(not transpose), 218 | weight_shape=weight_shape, 219 | output_padding=p, 220 | **common_kwargs, 221 | ).apply(grad_output, grad_grad_weight, None) 222 | 223 | return grad_grad_output, grad_grad_input 224 | 225 | conv2d_gradfix_cache[key] = Conv2d 226 | 227 | return Conv2d 228 | -------------------------------------------------------------------------------- /models/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, bias, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | if bias: 39 | grad_bias = grad_input.sum(dim).detach() 40 | 41 | else: 42 | grad_bias = empty 43 | 44 | return grad_input, grad_bias 45 | 46 | @staticmethod 47 | def backward(ctx, gradgrad_input, gradgrad_bias): 48 | out, = ctx.saved_tensors 49 | gradgrad_out = fused.fused_bias_act( 50 | gradgrad_input.contiguous(), 51 | gradgrad_bias, 52 | out, 53 | 3, 54 | 1, 55 | ctx.negative_slope, 56 | ctx.scale, 57 | ) 58 | 59 | return gradgrad_out, None, None, None, None 60 | 61 | 62 | class FusedLeakyReLUFunction(Function): 63 | @staticmethod 64 | def forward(ctx, input, bias, negative_slope, scale): 65 | empty = input.new_empty(0) 66 | 67 | ctx.bias = bias is not None 68 | 69 | if bias is None: 70 | bias = empty 71 | 72 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 73 | ctx.save_for_backward(out) 74 | ctx.negative_slope = negative_slope 75 | ctx.scale = scale 76 | 77 | return out 78 | 79 | @staticmethod 80 | def backward(ctx, grad_output): 81 | out, = ctx.saved_tensors 82 | 83 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 84 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale 85 | ) 86 | 87 | if not ctx.bias: 88 | grad_bias = None 89 | 90 | return grad_input, grad_bias, None, None 91 | 92 | 93 | class FusedLeakyReLU(nn.Module): 94 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): 95 | super().__init__() 96 | 97 | if bias: 98 | self.bias = nn.Parameter(torch.zeros(channel)) 99 | 100 | else: 101 | self.bias = None 102 | 103 | self.negative_slope = negative_slope 104 | self.scale = scale 105 | 106 | def forward(self, input): 107 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 108 | 109 | 110 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): 111 | if input.device.type == "cpu": 112 | if bias is not None: 113 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 114 | return ( 115 | F.leaky_relu( 116 | input + bias.view(1, bias.shape[0], 
*rest_dim), negative_slope=0.2 117 | ) 118 | * scale 119 | ) 120 | 121 | else: 122 | return F.leaky_relu(input, negative_slope=0.2) * scale 123 | 124 | else: 125 | return FusedLeakyReLUFunction.apply( 126 | input.contiguous(), bias, negative_slope, scale 127 | ) 128 | -------------------------------------------------------------------------------- /models/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | torch::Tensor fused_bias_act_op(const torch::Tensor &input, 6 | const torch::Tensor &bias, 7 | const torch::Tensor &refer, int act, int grad, 8 | float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) \ 11 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 12 | #define CHECK_CONTIGUOUS(x) \ 13 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 14 | #define CHECK_INPUT(x) \ 15 | CHECK_CUDA(x); \ 16 | CHECK_CONTIGUOUS(x) 17 | 18 | torch::Tensor fused_bias_act(const torch::Tensor &input, 19 | const torch::Tensor &bias, 20 | const torch::Tensor &refer, int act, int grad, 21 | float alpha, float scale) { 22 | CHECK_INPUT(input); 23 | CHECK_INPUT(bias); 24 | 25 | at::DeviceGuard guard(input.device()); 26 | 27 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 28 | } 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 32 | } -------------------------------------------------------------------------------- /models/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | 15 | #include 16 | #include 17 | 18 | template 19 | static __global__ void 20 | fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b, 21 | const scalar_t *p_ref, int act, int grad, scalar_t alpha, 22 | scalar_t scale, int loop_x, int size_x, int step_b, 23 | int size_b, int use_bias, int use_ref) { 24 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 25 | 26 | scalar_t zero = 0.0; 27 | 28 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; 29 | loop_idx++, xi += blockDim.x) { 30 | scalar_t x = p_x[xi]; 31 | 32 | if (use_bias) { 33 | x += p_b[(xi / step_b) % size_b]; 34 | } 35 | 36 | scalar_t ref = use_ref ? p_ref[xi] : zero; 37 | 38 | scalar_t y; 39 | 40 | switch (act * 10 + grad) { 41 | default: 42 | case 10: 43 | y = x; 44 | break; 45 | case 11: 46 | y = x; 47 | break; 48 | case 12: 49 | y = 0.0; 50 | break; 51 | 52 | case 30: 53 | y = (x > 0.0) ? x : x * alpha; 54 | break; 55 | case 31: 56 | y = (ref > 0.0) ? x : x * alpha; 57 | break; 58 | case 32: 59 | y = 0.0; 60 | break; 61 | } 62 | 63 | out[xi] = y * scale; 64 | } 65 | } 66 | 67 | torch::Tensor fused_bias_act_op(const torch::Tensor &input, 68 | const torch::Tensor &bias, 69 | const torch::Tensor &refer, int act, int grad, 70 | float alpha, float scale) { 71 | int curDevice = -1; 72 | cudaGetDevice(&curDevice); 73 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 74 | 75 | auto x = input.contiguous(); 76 | auto b = bias.contiguous(); 77 | auto ref = refer.contiguous(); 78 | 79 | int use_bias = b.numel() ? 
1 : 0; 80 | int use_ref = ref.numel() ? 1 : 0; 81 | 82 | int size_x = x.numel(); 83 | int size_b = b.numel(); 84 | int step_b = 1; 85 | 86 | for (int i = 1 + 1; i < x.dim(); i++) { 87 | step_b *= x.size(i); 88 | } 89 | 90 | int loop_x = 4; 91 | int block_size = 4 * 32; 92 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 93 | 94 | auto y = torch::empty_like(x); 95 | 96 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 97 | x.scalar_type(), "fused_bias_act_kernel", [&] { 98 | fused_bias_act_kernel<<>>( 99 | y.data_ptr(), x.data_ptr(), 100 | b.data_ptr(), ref.data_ptr(), act, grad, alpha, 101 | scale, loop_x, size_x, step_b, size_b, use_bias, use_ref); 102 | }); 103 | 104 | return y; 105 | } -------------------------------------------------------------------------------- /models/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 5 | const torch::Tensor &kernel, int up_x, int up_y, 6 | int down_x, int down_y, int pad_x0, int pad_x1, 7 | int pad_y0, int pad_y1); 8 | 9 | #define CHECK_CUDA(x) \ 10 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) \ 12 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 13 | #define CHECK_INPUT(x) \ 14 | CHECK_CUDA(x); \ 15 | CHECK_CONTIGUOUS(x) 16 | 17 | torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel, 18 | int up_x, int up_y, int down_x, int down_y, int pad_x0, 19 | int pad_x1, int pad_y0, int pad_y1) { 20 | CHECK_INPUT(input); 21 | CHECK_INPUT(kernel); 22 | 23 | at::DeviceGuard guard(input.device()); 24 | 25 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, 26 | pad_y0, pad_y1); 27 | } 28 | 29 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 30 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 31 | } -------------------------------------------------------------------------------- /models/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | from collections import abc 2 | import os 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | upfirdn2d_op = load( 12 | "upfirdn2d", 13 | sources=[ 14 | os.path.join(module_path, "upfirdn2d.cpp"), 15 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class UpFirDn2dBackward(Function): 21 | @staticmethod 22 | def forward( 23 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 24 | ): 25 | 26 | up_x, up_y = up 27 | down_x, down_y = down 28 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 29 | 30 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 31 | 32 | grad_input = upfirdn2d_op.upfirdn2d( 33 | grad_output, 34 | grad_kernel, 35 | down_x, 36 | down_y, 37 | up_x, 38 | up_y, 39 | g_pad_x0, 40 | g_pad_x1, 41 | g_pad_y0, 42 | g_pad_y1, 43 | ) 44 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 45 | 46 | ctx.save_for_backward(kernel) 47 | 48 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 49 | 50 | ctx.up_x = up_x 51 | ctx.up_y = up_y 52 | ctx.down_x = down_x 53 | ctx.down_y = down_y 54 | ctx.pad_x0 = pad_x0 55 | ctx.pad_x1 = pad_x1 56 | ctx.pad_y0 = pad_y0 57 | ctx.pad_y1 = pad_y1 58 | ctx.in_size = in_size 59 | ctx.out_size = out_size 60 | 61 | return grad_input 62 | 63 | 
@staticmethod 64 | def backward(ctx, gradgrad_input): 65 | kernel, = ctx.saved_tensors 66 | 67 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 68 | 69 | gradgrad_out = upfirdn2d_op.upfirdn2d( 70 | gradgrad_input, 71 | kernel, 72 | ctx.up_x, 73 | ctx.up_y, 74 | ctx.down_x, 75 | ctx.down_y, 76 | ctx.pad_x0, 77 | ctx.pad_x1, 78 | ctx.pad_y0, 79 | ctx.pad_y1, 80 | ) 81 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 82 | gradgrad_out = gradgrad_out.view( 83 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 84 | ) 85 | 86 | return gradgrad_out, None, None, None, None, None, None, None, None 87 | 88 | 89 | class UpFirDn2d(Function): 90 | @staticmethod 91 | def forward(ctx, input, kernel, up, down, pad): 92 | up_x, up_y = up 93 | down_x, down_y = down 94 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 95 | 96 | kernel_h, kernel_w = kernel.shape 97 | batch, channel, in_h, in_w = input.shape 98 | ctx.in_size = input.shape 99 | 100 | input = input.reshape(-1, in_h, in_w, 1) 101 | 102 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 103 | 104 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 105 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 106 | ctx.out_size = (out_h, out_w) 107 | 108 | ctx.up = (up_x, up_y) 109 | ctx.down = (down_x, down_y) 110 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 111 | 112 | g_pad_x0 = kernel_w - pad_x0 - 1 113 | g_pad_y0 = kernel_h - pad_y0 - 1 114 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 115 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 116 | 117 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 118 | 119 | out = upfirdn2d_op.upfirdn2d( 120 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 121 | ) 122 | # out = out.view(major, out_h, out_w, minor) 123 | out = out.view(-1, channel, out_h, out_w) 124 | 125 | return out 126 | 127 | @staticmethod 128 | def backward(ctx, grad_output): 129 | kernel, grad_kernel = ctx.saved_tensors 130 | 131 | grad_input = None 132 | 133 | if ctx.needs_input_grad[0]: 134 | grad_input = UpFirDn2dBackward.apply( 135 | grad_output, 136 | kernel, 137 | grad_kernel, 138 | ctx.up, 139 | ctx.down, 140 | ctx.pad, 141 | ctx.g_pad, 142 | ctx.in_size, 143 | ctx.out_size, 144 | ) 145 | 146 | return grad_input, None, None, None, None 147 | 148 | 149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 150 | if not isinstance(up, abc.Iterable): 151 | up = (up, up) 152 | 153 | if not isinstance(down, abc.Iterable): 154 | down = (down, down) 155 | 156 | if len(pad) == 2: 157 | pad = (pad[0], pad[1], pad[0], pad[1]) 158 | 159 | if input.device.type == "cpu": 160 | out = upfirdn2d_native(input, kernel, *up, *down, *pad) 161 | 162 | else: 163 | out = UpFirDn2d.apply(input, kernel, up, down, pad) 164 | 165 | return out 166 | 167 | 168 | def upfirdn2d_native( 169 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 170 | ): 171 | _, channel, in_h, in_w = input.shape 172 | input = input.reshape(-1, in_h, in_w, 1) 173 | 174 | _, in_h, in_w, minor = input.shape 175 | kernel_h, kernel_w = kernel.shape 176 | 177 | out = input.view(-1, in_h, 1, in_w, 1, minor) 178 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 179 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 180 | 181 | out = F.pad( 182 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 183 | ) 184 | out = out[ 185 | :, 186 | 
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 187 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 188 | :, 189 | ] 190 | 191 | out = out.permute(0, 3, 1, 2) 192 | out = out.reshape( 193 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 194 | ) 195 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 196 | out = F.conv2d(out, w) 197 | out = out.reshape( 198 | -1, 199 | minor, 200 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 201 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 202 | ) 203 | out = out.permute(0, 2, 3, 1) 204 | out = out[:, ::down_y, ::down_x, :] 205 | 206 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 207 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 208 | 209 | return out.view(-1, channel, out_h, out_w) 210 | -------------------------------------------------------------------------------- /models/psp.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | 3 | matplotlib.use('Agg') 4 | import torch 5 | from torch import nn 6 | from models.encoders import psp_encoders 7 | from models.networks_stylegan2 import Generator 8 | from configs.paths_config import model_paths 9 | 10 | 11 | def get_keys(d, name): 12 | if 'state_dict' in d: 13 | d = d['state_dict'] 14 | d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} 15 | return d_filt 16 | 17 | 18 | class pSp(nn.Module): 19 | 20 | def __init__(self, opts): 21 | super(pSp, self).__init__() 22 | self.opts = opts 23 | # Define architecture 24 | self.encoder = self.set_encoder() 25 | self.decoder = Generator(512, 0,512,opts.stylegan_size,3) 26 | self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 27 | # Load weights if needed 28 | self.load_weights() 29 | 30 | def set_encoder(self): 31 | if self.opts.encoder_type == 'GradualStyleEncoder': 32 | encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts) 33 | elif self.opts.encoder_type == 'Encoder4Editing': 34 | encoder = psp_encoders.Encoder4Editing(50, 'ir_se', self.opts) 35 | elif self.opts.encoder_type == 'SingleStyleCodeEncoder': 36 | encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts) 37 | else: 38 | raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type)) 39 | return encoder 40 | 41 | def load_weights(self): 42 | if self.opts.checkpoint_path is not None: 43 | print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.checkpoint_path)) 44 | ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') 45 | self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=False) 46 | self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True) 47 | self.__load_latent_avg(ckpt) 48 | else: 49 | print('Loading encoders weights from irse50!') 50 | encoder_ckpt = torch.load(model_paths['ir_se50']) 51 | self.encoder.load_state_dict(encoder_ckpt, strict=False) 52 | print('Loading decoder weights from pretrained!') 53 | ckpt = torch.load(self.opts.stylegan_weights) 54 | self.decoder.load_state_dict(ckpt['g_ema'], strict=False) 55 | self.__load_latent_avg(ckpt, repeat=self.encoder.style_count) 56 | 57 | def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True, 58 | inject_latent=None, return_latents=False, alpha=None): 59 | if input_code: 60 | codes = x 61 | else: 62 | codes = self.encoder(x) 63 | # normalize with respect to the center of an average face 64 | if self.opts.start_from_latent_avg: 65 
| if codes.ndim == 2: 66 | codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :] 67 | else: 68 | codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1) 69 | 70 | if latent_mask is not None: 71 | for i in latent_mask: 72 | if inject_latent is not None: 73 | if alpha is not None: 74 | codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i] 75 | else: 76 | codes[:, i] = inject_latent[:, i] 77 | else: 78 | codes[:, i] = 0 79 | 80 | input_is_latent = not input_code 81 | images, result_latent = self.decoder([codes], 82 | input_is_latent=input_is_latent, 83 | randomize_noise=randomize_noise, 84 | return_latents=return_latents) 85 | 86 | if resize: 87 | images = self.face_pool(images) 88 | 89 | if return_latents: 90 | return images, result_latent 91 | else: 92 | return images 93 | 94 | def __load_latent_avg(self, ckpt, repeat=None): 95 | if 'latent_avg' in ckpt: 96 | self.latent_avg = ckpt['latent_avg'].to(self.opts.device) 97 | elif self.opts.start_from_latent_avg: 98 | # Compute mean code based on a large number of latents (10,000 here) 99 | with torch.no_grad(): 100 | self.latent_avg = self.decoder.mean_latent(10000).to(self.opts.device) 101 | else: 102 | self.latent_avg = None 103 | if repeat is not None and self.latent_avg is not None: 104 | self.latent_avg = self.latent_avg.repeat(repeat, 1) -------------------------------------------------------------------------------- /models/stylegan3-main.code-workspace: -------------------------------------------------------------------------------- 1 | { 2 | "folders": [ 3 | { 4 | "path": "../.." 5 | } 6 | ] 7 | } -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | 2 | pip install numpy==1.21.6 3 | pip install click==7.1.1 4 | pip install pillow==8.3.1 5 | pip install scipy 6 | #pip install torch==1.9.1 7 | #pip install torchvision==0.10.1 8 | pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html 9 | pip install requests==2.26.0 10 | pip install tqdm==4.62.2 11 | pip install ninja==1.10.2 12 | pip install matplotlib==3.4.2 13 | pip install imageio==2.9.0 14 | pip install imgui==1.3.0 15 | pip install glfw==2.2.0 16 | pip install opencv-python 17 | pip install pyopengl==3.1.5 18 | pip install cuda-python 19 | pip install ftfy regex tqdm 20 | apt-get -y update && apt-get -y upgrade 21 | apt-get -y install git 22 | pip install git+https://github.com/openai/CLIP.git 23 | apt-get -y install libgl1-mesa-glx 24 | #apt install cmake 25 | #pip install scikit-video opencv-python dlib 26 | pip install tensorboard 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /sketch/Thumbs.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwanyun/StyleSketch/0214b46d51a7cd6e50674b4e4461adb3f720aedc/sketch/Thumbs.db -------------------------------------------------------------------------------- /sketch/checkpoints/stylegan_pretrain/avg.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwanyun/StyleSketch/0214b46d51a7cd6e50674b4e4461adb3f720aedc/sketch/checkpoints/stylegan_pretrain/avg.pt -------------------------------------------------------------------------------- /sketch/dataset_release/latent/id_1.pt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwanyun/StyleSketch/0214b46d51a7cd6e50674b4e4461adb3f720aedc/sketch/dataset_release/latent/id_1.pt -------------------------------------------------------------------------------- /sketch/dataset_release/latent/id_10.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwanyun/StyleSketch/0214b46d51a7cd6e50674b4e4461adb3f720aedc/sketch/dataset_release/latent/id_10.pt -------------------------------------------------------------------------------- /sketch/dataset_release/latent/id_11.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwanyun/StyleSketch/0214b46d51a7cd6e50674b4e4461adb3f720aedc/sketch/dataset_release/latent/id_11.pt -------------------------------------------------------------------------------- /sketch/dataset_release/latent/id_12.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwanyun/StyleSketch/0214b46d51a7cd6e50674b4e4461adb3f720aedc/sketch/dataset_release/latent/id_12.pt -------------------------------------------------------------------------------- /sketch/dataset_release/latent/id_13.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwanyun/StyleSketch/0214b46d51a7cd6e50674b4e4461adb3f720aedc/sketch/dataset_release/latent/id_13.pt -------------------------------------------------------------------------------- /sketch/dataset_release/latent/id_14.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwanyun/StyleSketch/0214b46d51a7cd6e50674b4e4461adb3f720aedc/sketch/dataset_release/latent/id_14.pt -------------------------------------------------------------------------------- /sketch/dataset_release/latent/id_15.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwanyun/StyleSketch/0214b46d51a7cd6e50674b4e4461adb3f720aedc/sketch/dataset_release/latent/id_15.pt -------------------------------------------------------------------------------- /sketch/dataset_release/latent/id_2.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwanyun/StyleSketch/0214b46d51a7cd6e50674b4e4461adb3f720aedc/sketch/dataset_release/latent/id_2.pt -------------------------------------------------------------------------------- /sketch/dataset_release/latent/id_3.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwanyun/StyleSketch/0214b46d51a7cd6e50674b4e4461adb3f720aedc/sketch/dataset_release/latent/id_3.pt -------------------------------------------------------------------------------- /sketch/dataset_release/latent/id_4.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwanyun/StyleSketch/0214b46d51a7cd6e50674b4e4461adb3f720aedc/sketch/dataset_release/latent/id_4.pt -------------------------------------------------------------------------------- /sketch/dataset_release/latent/id_5.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kwanyun/StyleSketch/0214b46d51a7cd6e50674b4e4461adb3f720aedc/sketch/dataset_release/latent/id_5.pt -------------------------------------------------------------------------------- /sketch/dataset_release/latent/id_6.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwanyun/StyleSketch/0214b46d51a7cd6e50674b4e4461adb3f720aedc/sketch/dataset_release/latent/id_6.pt -------------------------------------------------------------------------------- /sketch/dataset_release/latent/id_7.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwanyun/StyleSketch/0214b46d51a7cd6e50674b4e4461adb3f720aedc/sketch/dataset_release/latent/id_7.pt -------------------------------------------------------------------------------- /sketch/dataset_release/latent/id_8.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwanyun/StyleSketch/0214b46d51a7cd6e50674b4e4461adb3f720aedc/sketch/dataset_release/latent/id_8.pt -------------------------------------------------------------------------------- /sketch/dataset_release/latent/id_9.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwanyun/StyleSketch/0214b46d51a7cd6e50674b4e4461adb3f720aedc/sketch/dataset_release/latent/id_9.pt -------------------------------------------------------------------------------- /sketch/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
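# EasyDict is a small dict subclass whose keys can also be read and written as
# attributes (d.key instead of d['key']); make_cache_dir_path returns a path under
# dnnlib's download/cache directory. Both live in util.py and are re-exported below.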
8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /sketch/experiments/inference.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "model_dir/", 3 | "batch_size": 4, 4 | "dim": [1024, 1024], 5 | "max_training": 16, 6 | "ada_augment" : true, 7 | "d_regularize" :true, 8 | "annotation_data_from_w": true, 9 | "annotation_sketch_path": "./dataset_release/annotation/training_data/pseudoface_processed3/", 10 | "saved_latent": "./checkpoints/stylegan_pretrain/stylegan2-ffhq-1024x1024.pt", 11 | "image_latent_path": "./dataset_release/latent/", 12 | "stylegan_checkpoint": "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-1024x1024.pkl", 13 | "stylesketch_path":"./model_dir/", 14 | "model_num": 1 15 | } -------------------------------------------------------------------------------- /sketch/generate.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 4 | import json 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | 9 | from stylesketch_utils.stylesketch import SketchGenerator,Discriminator 10 | from stylesketch_utils.prepare_stylegan import prepare_stylegan 11 | from PIL import Image 12 | from utils.utils import latent_to_image, oht_to_scalar_regression 13 | from tqdm import tqdm 14 | import argparse 15 | import glob 16 | 17 | 18 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 19 | device_ids = [0,1] 20 | 21 | 22 | def parallelize(model): 23 | """ 24 | Distribute a model across multiple GPUs. 25 | """ 26 | if torch.cuda.device_count() > 1: 27 | print("Using", torch.cuda.device_count(), "GPUs") 28 | model = nn.DataParallel(model, device_ids=device_ids) 29 | return model 30 | 31 | 32 | def generate_data(args, name_train_style="Disney_sketch_MJ"): 33 | checkpoint_path = args['stylesketch_path'] 34 | result_path = os.path.join(checkpoint_path, 'samples' ) 35 | 36 | if os.path.exists(result_path): 37 | pass 38 | else: 39 | os.system('mkdir -p %s' % (result_path)) 40 | print('Experiment folder created at: %s' % (result_path)) 41 | 42 | g_all, avg_latent = prepare_stylegan(args['stylegan_checkpoint'],args['saved_latent']) 43 | 44 | 45 | generator = SketchGenerator() 46 | generator = generator.cuda() 47 | generator = parallelize(generator) 48 | 49 | checkpoint = torch.load(os.path.join(checkpoint_path, f'model_{name_train_style}.pth')) 50 | generator.load_state_dict(checkpoint['model_state_dict'],strict =True) 51 | generator.eval() 52 | 53 | with torch.no_grad(): 54 | #get sd2 latents to sketch 55 | latents_to_sketch = glob.glob(os.path.join(args['image_latent_path'],'*')) 56 | latents_to_sketch = [latent for latent in latents_to_sketch if latent.endswith('pt')] 57 | print( "num_sketches: ", len(latents_to_sketch)) 58 | 59 | for latent_dir in tqdm(latents_to_sketch): 60 | single_latent = torch.load(latent_dir) 61 | latent_input = single_latent.float().to(device) 62 | latent_name = latent_dir.split('/')[-1].split('.')[0] 63 | #extract features from sd model 64 | img, affine_layers = latent_to_image(g_all, latent_input, 3,dim=args['dim'][1],use_style_latents=args['annotation_data_from_w']) 65 | affine_layers.append(img.transpose(0,3,1,2)) 66 | affine_layers = [torch.from_numpy(x).type(torch.FloatTensor).to(device) for x in affine_layers] 67 | 68 | #generate sketch from 
features 69 | sketch_image = generator(affine_layers) 70 | 71 | #save sketch 72 | sketch_image = oht_to_scalar_regression(sketch_image.squeeze()) 73 | sketch_image = sketch_image.cpu().detach().numpy() 74 | image_label_name = os.path.join(result_path, f'Sketch_{latent_name}_{name_train_style}.png') 75 | sketch = Image.fromarray(sketch_image.astype('uint8')) 76 | sketch.save(image_label_name) 77 | 78 | #save images corresponding to sketch 79 | img = Image.fromarray(np.asarray(img.squeeze())) 80 | image_name_ori = os.path.join(result_path,f'Image_{latent_name}.png') 81 | img.save(image_name_ori) 82 | 83 | 84 | 85 | if __name__ == '__main__': 86 | parser = argparse.ArgumentParser() 87 | parser.add_argument('--exp', type=str, default = "experiments/inference.json") 88 | parser.add_argument('--train_data', type=str) 89 | 90 | args = parser.parse_args() 91 | opts = json.load(open(args.exp, 'r')) 92 | 93 | 94 | 95 | generate_data(opts, args.train_data) -------------------------------------------------------------------------------- /sketch/generate.sh: -------------------------------------------------------------------------------- 1 | python generate.py --train_data ani_sketch_MJ 2 | python generate.py --train_data Disney_sketch_MJ 3 | python generate.py --train_data sketch_MJ 4 | python generate.py --train_data Rough_sketch_MJ 5 | python generate.py --train_data GHIBLI_sketch_MJ 6 | python generate.py --train_data pencil_sj 7 | python generate.py --train_data low_sj 8 | python generate.py --train_data apd 9 | python generate.py --train_data cufs -------------------------------------------------------------------------------- /sketch/stylesketch_utils/conv.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn import functional as F 3 | import torch 4 | from torchvision import models 5 | import torchvision 6 | class conv_block(nn.Module): 7 | def __init__(self,ch_in,ch_out): 8 | super(conv_block,self).__init__() 9 | self.conv = nn.Sequential( 10 | nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=False), 11 | nn.BatchNorm2d(ch_out), 12 | nn.LeakyReLU(negative_slope=0.2,inplace=True), 13 | nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=False), 14 | nn.BatchNorm2d(ch_out), 15 | nn.LeakyReLU(negative_slope=0.2,inplace=True), 16 | ) 17 | def forward(self,x): 18 | x = self.conv(x) 19 | return x 20 | 21 | class conv_down(nn.Module): 22 | def __init__(self,ch_in,ch_out): 23 | super(conv_down,self).__init__() 24 | self.conv = nn.Sequential( 25 | nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=2,padding=1,bias=False), 26 | nn.BatchNorm2d(ch_out), 27 | nn.LeakyReLU(negative_slope=0.2,inplace=True), 28 | nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=False), 29 | nn.BatchNorm2d(ch_out), 30 | nn.LeakyReLU(negative_slope=0.2,inplace=True), 31 | ) 32 | def forward(self,x): 33 | x = self.conv(x) 34 | return x 35 | 36 | class conv_res(nn.Module): 37 | def __init__(self,ch_in,ch_out): 38 | super(conv_res,self).__init__() 39 | self.conv = nn.Sequential( 40 | nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=False), 41 | nn.BatchNorm2d(ch_out), 42 | nn.LeakyReLU(negative_slope=0.2), 43 | nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=False), 44 | nn.BatchNorm2d(ch_out), 45 | nn.LeakyReLU(negative_slope=0.2), 46 | ) 47 | def forward(self,x): 48 | c = self.conv(x) 49 | return c+x 50 | 51 | class half_res(nn.Module): 52 | def __init__(self,ch_in,ch_out): 53 | 
super(half_res,self).__init__() 54 | self.conv = nn.Sequential( 55 | nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=False), 56 | nn.BatchNorm2d(ch_out), 57 | nn.LeakyReLU(negative_slope=0.2), 58 | ) 59 | def forward(self,x): 60 | c = self.conv(x) 61 | return c+x 62 | 63 | class up_conv(nn.Module): 64 | def __init__(self,ch_in,ch_out): 65 | super(up_conv,self).__init__() 66 | self.up = nn.Sequential( 67 | nn.Upsample(scale_factor=2), 68 | nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=False), 69 | nn.BatchNorm2d(ch_out), 70 | nn.LeakyReLU(negative_slope=0.2,inplace=True), 71 | ) 72 | 73 | def forward(self,x): 74 | x = self.up(x) 75 | return x 76 | 77 | class Attention_block(nn.Module): 78 | def __init__(self, F_g, F_l, F_int): 79 | super(Attention_block, self).__init__() 80 | self.W_g = nn.Sequential( 81 | nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=False), 82 | nn.BatchNorm2d(F_int) 83 | ) 84 | 85 | self.W_x = nn.Sequential( 86 | nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=False), 87 | nn.BatchNorm2d(F_int) 88 | ) 89 | 90 | self.psi = nn.Sequential( 91 | nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=False), 92 | nn.BatchNorm2d(1), 93 | nn.Sigmoid() 94 | ) 95 | 96 | self.relu = nn.LeakyReLU(negative_slope=0.2,inplace=True) 97 | 98 | def forward(self, g, x): 99 | g1 = self.W_g(g) 100 | x1 = self.W_x(x) 101 | psi = self.relu(g1 + x1) 102 | psi = self.psi(psi) 103 | return x * psi 104 | 105 | 106 | class AttU_Net(nn.Module): 107 | def __init__(self, img_ch=3, output_ch=1): 108 | super(AttU_Net, self).__init__() 109 | 110 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 111 | 112 | self.Conv1 = conv_block(ch_in=img_ch, ch_out=64) 113 | self.Conv2 = conv_block(ch_in=64, ch_out=128) 114 | self.Conv3 = conv_block(ch_in=128, ch_out=256) 115 | self.Conv4 = conv_block(ch_in=256, ch_out=512) 116 | self.Conv5 = conv_block(ch_in=512, ch_out=1024) 117 | 118 | self.Up5 = up_conv(ch_in=1024, ch_out=512) 119 | self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256) 120 | self.Up_conv5 = conv_block(ch_in=1024, ch_out=512) 121 | 122 | self.Up4 = up_conv(ch_in=512, ch_out=256) 123 | self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128) 124 | self.Up_conv4 = conv_block(ch_in=512, ch_out=256) 125 | 126 | self.Up3 = up_conv(ch_in=256, ch_out=128) 127 | self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64) 128 | self.Up_conv3 = conv_block(ch_in=256, ch_out=128) 129 | 130 | self.Up2 = up_conv(ch_in=128, ch_out=64) 131 | self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32) 132 | self.Up_conv2 = conv_block(ch_in=128, ch_out=64) 133 | 134 | self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0) 135 | self.sigmoid = nn.Sigmoid() 136 | def forward(self, x): 137 | # encoding path 138 | x1 = self.Conv1(x) 139 | 140 | x2 = self.Maxpool(x1) 141 | x2 = self.Conv2(x2) 142 | 143 | x3 = self.Maxpool(x2) 144 | x3 = self.Conv3(x3) 145 | 146 | x4 = self.Maxpool(x3) 147 | x4 = self.Conv4(x4) 148 | 149 | x5 = self.Maxpool(x4) 150 | x5 = self.Conv5(x5) 151 | 152 | # decoding + concat path 153 | d5 = self.Up5(x5) 154 | x4 = self.Att5(g=d5, x=x4) 155 | d5 = torch.cat((x4, d5), dim=1) 156 | d5 = self.Up_conv5(d5) 157 | 158 | d4 = self.Up4(d5) 159 | x3 = self.Att4(g=d4, x=x3) 160 | d4 = torch.cat((x3, d4), dim=1) 161 | d4 = self.Up_conv4(d4) 162 | 163 | d3 = self.Up3(d4) 164 | x2 = self.Att3(g=d3, x=x2) 165 | d3 = torch.cat((x2, d3), dim=1) 166 | d3 = self.Up_conv3(d3) 167 | 168 | d2 = self.Up2(d3) 169 | x1 = 
self.Att2(g=d2, x=x1) 170 | d2 = torch.cat((x1, d2), dim=1) 171 | d2 = self.Up_conv2(d2) 172 | 173 | d1 = self.Conv_1x1(d2) 174 | d1 = self.sigmoid(d1) 175 | 176 | return d1 177 | 178 | 179 | import torch 180 | import math 181 | import torch.nn as nn 182 | import torch.nn.functional as F 183 | class BasicConv(nn.Module): 184 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False): 185 | super(BasicConv, self).__init__() 186 | self.out_channels = out_planes 187 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 188 | self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None 189 | self.relu = nn.ReLU() if relu else None 190 | def forward(self, x): 191 | x = self.conv(x) 192 | if self.bn is not None: 193 | x = self.bn(x) 194 | if self.relu is not None: 195 | x = self.relu(x) 196 | return x 197 | class Flatten(nn.Module): 198 | def forward(self, x): 199 | return x.view(x.size(0), -1) 200 | class ChannelGate(nn.Module): 201 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): 202 | super(ChannelGate, self).__init__() 203 | self.gate_channels = gate_channels 204 | self.mlp = nn.Sequential( 205 | Flatten(), 206 | nn.Linear(gate_channels, gate_channels // reduction_ratio), 207 | nn.ReLU(), 208 | nn.Linear(gate_channels // reduction_ratio, gate_channels) 209 | ) 210 | self.pool_types = pool_types 211 | def forward(self, x): 212 | channel_att_sum = None 213 | for pool_type in self.pool_types: 214 | if pool_type=='avg': 215 | avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 216 | channel_att_raw = self.mlp( avg_pool ) 217 | elif pool_type=='max': 218 | max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 219 | channel_att_raw = self.mlp( max_pool ) 220 | elif pool_type=='lp': 221 | lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 222 | channel_att_raw = self.mlp( lp_pool ) 223 | elif pool_type=='lse': 224 | # LSE pool only 225 | lse_pool = logsumexp_2d(x) 226 | channel_att_raw = self.mlp( lse_pool ) 227 | if channel_att_sum is None: 228 | channel_att_sum = channel_att_raw 229 | else: 230 | channel_att_sum = channel_att_sum + channel_att_raw 231 | scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) 232 | return x * scale 233 | def logsumexp_2d(tensor): 234 | tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) 235 | s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) 236 | outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() 237 | return outputs 238 | 239 | class ChannelPool(nn.Module): 240 | def forward(self, x): 241 | return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) 242 | 243 | class SpatialGate(nn.Module): 244 | def __init__(self): 245 | super(SpatialGate, self).__init__() 246 | kernel_size = 7 247 | self.compress = ChannelPool() 248 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) 249 | def forward(self, x): 250 | x_compress = self.compress(x) 251 | x_out = self.spatial(x_compress) 252 | scale = F.sigmoid(x_out) # broadcasting 253 | return x * scale 254 | 255 | class CBAM(nn.Module): 256 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): 257 | super(CBAM, self).__init__() 
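# CBAM applies channel attention (ChannelGate) first and, unless no_spatial is set,
# spatial attention (SpatialGate) second, following Woo et al., "CBAM: Convolutional
# Block Attention Module" (ECCV 2018).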
258 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) 259 | self.no_spatial=no_spatial 260 | if not no_spatial: 261 | self.SpatialGate = SpatialGate() 262 | def forward(self, x): 263 | x_out = self.ChannelGate(x) 264 | if not self.no_spatial: 265 | x_out = self.SpatialGate(x_out) 266 | return x_out -------------------------------------------------------------------------------- /sketch/stylesketch_utils/prepare_stylegan.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | from collections import OrderedDict 4 | import torch 5 | import torch.nn as nn 6 | import dnnlib 7 | from models.networks_stylegan2 import MappingNetwork,SynthesisNetwork 8 | import os 9 | 10 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 11 | 12 | def prepare_stylegan(_, saved_dir): 13 | avg_latent_dir = saved_dir.rsplit('/',1)[0]+"/avg.pt" 14 | avg_latent = torch.load(avg_latent_dir) 15 | 16 | g_all = nn.Sequential(OrderedDict([ 17 | ('g_mapping', MappingNetwork(512,0,512,18)), 18 | ('g_synthesis', SynthesisNetwork(512,1024,3)) 19 | ])) 20 | 21 | g_all.load_state_dict(torch.load(saved_dir, map_location=device)) 22 | g_all.eval().cuda() 23 | 24 | 25 | return g_all, avg_latent -------------------------------------------------------------------------------- /sketch/stylesketch_utils/stylesketch.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | from stylesketch_utils.conv import conv_block,conv_down,conv_res,half_res,up_conv,Attention_block,CBAM 6 | 7 | 8 | class SketchGenerator(nn.Module): 9 | def __init__(self, image_size=1024): 10 | super(SketchGenerator, self).__init__() 11 | self.image_size = image_size 12 | self.output_channels = 8 13 | assert self.image_size == 1024 or 512 or 256 , 'Resolution error' 14 | 15 | if self.image_size == 1024: 16 | self.tensor_resolutions = [ 1024,512,256,128,64,32,16,8] 17 | self.tensor_shape = [4,8,8,16,16,32,32,64,64,128,128,256,256,512,512,1024,1024,1024] 18 | 19 | elif self.image_size == 512: 20 | self.tensor_resolutions = [ 512, 512, 512, 512, 512, 512, 256, 256, 128, 128, 64, 64 ] 21 | 22 | 23 | self.image_resolutions = [ 32, 64, 128, 256, 512, 1024] 24 | self.image_channels_to_up=[self.output_channels*32,self.output_channels*32,self.output_channels*32,self.output_channels*32,self.output_channels*24,self.output_channels*8] 25 | 26 | #for upsample 27 | self.up_512_512_0 = up_conv(512,512) 28 | self.up_512_512_1 = up_conv(512,512) 29 | self.up_512_256 = up_conv(512,256) 30 | self.up_256_256 = up_conv(256,256) 31 | self.up_256_128 = up_conv(256,128) 32 | self.up_128_128 = up_conv(128,128) 33 | self.up_128_64 = up_conv(128,64) 34 | self.up_64_64 = up_conv(64,64) 35 | 36 | #attention block 37 | self.Att0_0 = Attention_block(F_g=512, F_l=512, F_int=256) 38 | self.Att0_1 = Attention_block(F_g=512, F_l=512, F_int=256) 39 | self.Att1_0 = Attention_block(F_g=256, F_l=256, F_int=128) 40 | self.Att1_1 = Attention_block(F_g=256, F_l=256, F_int=128) 41 | self.Att2_0 = Attention_block(F_g=128, F_l=128, F_int=64) 42 | self.Att2_1 = Attention_block(F_g=128, F_l=128, F_int=64) 43 | self.Att3_0 = Attention_block(F_g=64, F_l=64, F_int=32) 44 | self.Att3_1 = Attention_block(F_g=64, F_l=64, F_int=32) 45 | 46 | 47 | # 1x1 conv for input 48 | self.conv1x1_256_0 = nn.Conv2d(self.tensor_resolutions[1], self.tensor_resolutions[2], 1, stride=1) 49 | self.conv1x1_256_1 = nn.Conv2d(self.tensor_resolutions[1], 
self.tensor_resolutions[2], 1, stride=1) 50 | self.conv1x1_256_2 = nn.Conv2d(self.tensor_resolutions[1], self.tensor_resolutions[2], 1, stride=1) 51 | self.conv1x1_256_3 = nn.Conv2d(self.tensor_resolutions[1], self.tensor_resolutions[2], 1, stride=1) 52 | self.conv1x1_128_0 = nn.Conv2d(self.tensor_resolutions[1], self.tensor_resolutions[3], 1, stride=1) 53 | self.conv1x1_128_1 = nn.Conv2d(self.tensor_resolutions[1], self.tensor_resolutions[3], 1, stride=1) 54 | self.conv1x1_128_2 = nn.Conv2d(self.tensor_resolutions[1], self.tensor_resolutions[3], 1, stride=1) 55 | self.conv1x1_128_3 = nn.Conv2d(self.tensor_resolutions[1], self.tensor_resolutions[3], 1, stride=1) 56 | self.conv1x1_64_0 = nn.Conv2d(self.tensor_resolutions[2], self.tensor_resolutions[4], 1, stride=1) 57 | self.conv1x1_64_1 = nn.Conv2d(self.tensor_resolutions[2], self.tensor_resolutions[4], 1, stride=1) 58 | self.conv1x1_64_2 = nn.Conv2d(self.tensor_resolutions[3], self.tensor_resolutions[4], 1, stride=1) 59 | self.conv1x1_64_3 = nn.Conv2d(self.tensor_resolutions[3], self.tensor_resolutions[4], 1, stride=1) 60 | self.conv1x1_32_0 = nn.Conv2d(self.tensor_resolutions[4], self.tensor_resolutions[5], 1, stride=1) 61 | self.conv1x1_32_1 = nn.Conv2d(self.tensor_resolutions[4], self.tensor_resolutions[5], 1, stride=1) 62 | 63 | # 3x3 conv for calc and channel reduce 64 | self.conv0_0 = conv_block(1024, 512) 65 | self.conv0_1 = conv_block(1024, 512) 66 | self.conv1_0 = conv_block(512, 256) 67 | self.conv1_1 = conv_block(512, 256) 68 | self.conv2_0 = conv_block(256, 128) 69 | self.conv2_1 = conv_block(256, 128) 70 | self.conv3 = conv_block(128, 64) 71 | 72 | 73 | self.cbam = CBAM(131) 74 | self.Conv_1x1 = nn.Conv2d(131, 1, kernel_size=1, stride=1, padding=0) 75 | self.sigmoid = nn.Sigmoid() 76 | 77 | 78 | 79 | def forward(self, input): 80 | """ 81 | input: List[Torch.tensor] --> Torch.tensor shape : bs, c, h, w 82 | """ 83 | #level_0 = torch.cat([input[0], input[0]], dim = 1) #0&0 cat, 4 84 | level_0 = input[0] 85 | level_1 = torch.cat([self.conv1x1_256_0(input[1]), self.conv1x1_256_1(input[2])], dim = 1) #1&2 cat, 8 >512 86 | level_2 = torch.cat([self.conv1x1_256_2(input[3]), self.conv1x1_256_3(input[4])], dim = 1) #3&4 conv11, cat, 16 >512 87 | level_3 = torch.cat([self.conv1x1_128_0(input[5]), self.conv1x1_128_1(input[6])], dim = 1) #0&0 conv11, cat,32 256 88 | level_4 = torch.cat([self.conv1x1_128_2(input[7]), self.conv1x1_128_3(input[8])], dim = 1) #0&0 conv11, cat, 64 256 89 | level_5 = torch.cat([self.conv1x1_64_0(input[9]), self.conv1x1_64_1(input[10])], dim = 1) #0&0 conv11, cat,128 128 90 | level_6 = torch.cat([self.conv1x1_64_2(input[11]), self.conv1x1_64_3(input[12])], dim = 1) #0&0 conv11, cat, 256 128 91 | level_7 = torch.cat([self.conv1x1_32_0(input[13]), self.conv1x1_32_1(input[14])], dim = 1) #0&0 conv11, cat, 512 64 92 | level_8 = torch.cat([input[15], input[16]], dim = 1) #0&0 conv11, cat, 1024 93 | 94 | up_0 = self.up_512_512_0(level_0) #512x8x8 95 | att_0 = self.Att0_0(g=up_0, x=level_1) 96 | cat_0 = torch.cat((att_0, up_0), dim=1) #1024x8x8 97 | up_1 = self.conv0_0(cat_0) #512x8x8 98 | 99 | up_1 = self.up_512_512_1(up_1) 100 | att_1 = self.Att0_1(g=up_1, x=level_2) 101 | cat_1 = torch.cat((att_1, up_1), dim=1) #1024x16x16 102 | up_2 = self.conv0_1(cat_1) #512x16x16 103 | 104 | up_2 = self.up_512_256(up_2) #256x32x32 105 | att_2 = self.Att1_0(g=up_2, x=level_3) 106 | cat_2 = torch.cat((att_2, up_2), dim=1) #512x32x32 107 | up_3 = self.conv1_0(cat_2) #256x32x32 108 | 109 | up_3 = self.up_256_256(up_3) #256x64x64 110 
| att_3 = self.Att1_1(g=up_3, x=level_4) 111 | cat_3 = torch.cat((att_3, up_3), dim=1) #512x64x64 112 | up_4 = self.conv1_1(cat_3) #256x64x64 113 | 114 | up_4 = self.up_256_128(up_4) #128x128x128 115 | att_4 = self.Att2_0(g=up_4, x=level_5) 116 | cat_4 = torch.cat((att_4, up_4), dim=1) #256x128x128 117 | up_5 = self.conv2_0(cat_4) #128x128x128 118 | 119 | up_5 = self.up_128_128(up_5) #128x256x256 120 | att_5 = self.Att2_1(g=up_5, x=level_6) 121 | cat_5 = torch.cat((att_5, up_5), dim=1) #256x256x256 122 | up_6 = self.conv2_1(cat_5) #128x256x256 123 | 124 | up_6 = self.up_128_64(up_6) #64x512x512 125 | att_6 = self.Att3_0(g=up_6, x=level_7) 126 | cat_6 = torch.cat((att_6, up_6), dim=1) #128x512x512 127 | up_7 = self.conv3(cat_6) #64x512x512 128 | 129 | up_7 = self.up_64_64(up_7) #64x1024x1024 130 | att_7 = self.Att3_1(g=up_7, x=level_8) 131 | cat_7 = torch.cat((att_7, up_7,input[17]), dim=1) #(128+3)x1024x1024 132 | cbam_out = self.cbam(cat_7) 133 | sketch = self.Conv_1x1(cbam_out) 134 | 135 | return self.sigmoid(sketch) 136 | 137 | class Discriminator(nn.Module): 138 | def __init__(self): 139 | super().__init__() 140 | 141 | self.layers = nn.Sequential( 142 | # input is (nc) x i x i 143 | nn.Conv2d(2, 32, 3, stride = 2, padding = 1,bias=True), 144 | nn.LeakyReLU(0.2, inplace=True), 145 | # state size. (ndf) x i/2 x i/2 146 | conv_down(32,64), 147 | half_res(64,64), 148 | # state size. (ndf) x i/4 x i/4 149 | conv_down(64,128), 150 | half_res(128,128), 151 | # state size. (ndf*2) x i/8 x i/8 152 | conv_down(128,256), 153 | conv_res(256,256), 154 | # state size. (ndf*4) x i/16 x i/16 155 | conv_down(256,512), 156 | conv_res(512,512), 157 | # state size. (ndf*4) x i/32 x i/32 158 | conv_down(512,512), 159 | conv_res(512,512), 160 | # state size. (ndf*8) x i/64 x i/64 161 | nn.Conv2d(512, 1, 1, 1, 0), 162 | nn.Sigmoid() 163 | ) 164 | 165 | def forward(self,a,b): 166 | img_input = torch.cat((a, b), 1) 167 | return self.layers(img_input) 168 | 169 | -------------------------------------------------------------------------------- /sketch/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /sketch/torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
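# This module compiles, loads, and caches the C++/CUDA extensions used by
# torch_utils.ops (for example the bias_act_plugin requested from ops/bias_act.py)
# through torch.utils.cpp_extension; see get_plugin() below.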
8 | 9 | import glob 10 | import hashlib 11 | import importlib 12 | import os 13 | import re 14 | import shutil 15 | import uuid 16 | 17 | import torch 18 | import torch.utils.cpp_extension 19 | from torch.utils.file_baton import FileBaton 20 | 21 | #---------------------------------------------------------------------------- 22 | # Global options. 23 | 24 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 25 | 26 | #---------------------------------------------------------------------------- 27 | # Internal helper funcs. 28 | 29 | def _find_compiler_bindir(): 30 | patterns = [ 31 | 'C:/Program Files*/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 32 | 'C:/Program Files*/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 33 | 'C:/Program Files*/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 34 | 'C:/Program Files*/Microsoft Visual Studio */vc/bin', 35 | ] 36 | for pattern in patterns: 37 | matches = sorted(glob.glob(pattern)) 38 | if len(matches): 39 | return matches[-1] 40 | return None 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | def _get_mangled_gpu_name(): 45 | name = torch.cuda.get_device_name().lower() 46 | out = [] 47 | for c in name: 48 | if re.match('[a-z0-9_-]+', c): 49 | out.append(c) 50 | else: 51 | out.append('-') 52 | return ''.join(out) 53 | 54 | #---------------------------------------------------------------------------- 55 | # Main entry point for compiling and loading C++/CUDA plugins. 56 | 57 | _cached_plugins = dict() 58 | 59 | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): 60 | assert verbosity in ['none', 'brief', 'full'] 61 | if headers is None: 62 | headers = [] 63 | if source_dir is not None: 64 | sources = [os.path.join(source_dir, fname) for fname in sources] 65 | headers = [os.path.join(source_dir, fname) for fname in headers] 66 | 67 | # Already cached? 68 | if module_name in _cached_plugins: 69 | return _cached_plugins[module_name] 70 | 71 | # Print status. 72 | if verbosity == 'full': 73 | print(f'Setting up PyTorch plugin "{module_name}"...') 74 | elif verbosity == 'brief': 75 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) 76 | verbose_build = (verbosity == 'full') 77 | 78 | # Compile and load. 79 | try: # pylint: disable=too-many-nested-blocks 80 | # Make sure we can find the necessary compiler binaries. 81 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 82 | compiler_bindir = _find_compiler_bindir() 83 | if compiler_bindir is None: 84 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') 85 | os.environ['PATH'] += ';' + compiler_bindir 86 | 87 | # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either 88 | # break the build or unnecessarily restrict what's available to nvcc. 89 | # Unset it to let nvcc decide based on what's available on the 90 | # machine. 91 | os.environ['TORCH_CUDA_ARCH_LIST'] = '' 92 | 93 | # Incremental build md5sum trickery. Copies all the input source files 94 | # into a cached build directory under a combined md5 digest of the input 95 | # source files. Copying is done only if the combined digest has changed. 96 | # This keeps input file timestamps and filenames the same as in previous 97 | # extension builds, allowing for fast incremental rebuilds. 
98 | # 99 | # This optimization is done only in case all the source files reside in 100 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 101 | # environment variable is set (we take this as a signal that the user 102 | # actually cares about this.) 103 | # 104 | # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work 105 | # around the *.cu dependency bug in ninja config. 106 | # 107 | all_source_files = sorted(sources + headers) 108 | all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) 109 | if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): 110 | 111 | # Compute combined hash digest for all source files. 112 | hash_md5 = hashlib.md5() 113 | for src in all_source_files: 114 | with open(src, 'rb') as f: 115 | hash_md5.update(f.read()) 116 | 117 | # Select cached build directory name. 118 | source_digest = hash_md5.hexdigest() 119 | build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access 120 | cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') 121 | 122 | if not os.path.isdir(cached_build_dir): 123 | tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' 124 | os.makedirs(tmpdir) 125 | for src in all_source_files: 126 | shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) 127 | try: 128 | os.replace(tmpdir, cached_build_dir) # atomic 129 | except OSError: 130 | # source directory already exists, delete tmpdir and its contents. 131 | shutil.rmtree(tmpdir) 132 | if not os.path.isdir(cached_build_dir): raise 133 | 134 | # Compile. 135 | cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] 136 | torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, 137 | verbose=verbose_build, sources=cached_sources, **build_kwargs) 138 | else: 139 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 140 | 141 | # Load. 142 | module = importlib.import_module(module_name) 143 | 144 | except: 145 | if verbosity == 'brief': 146 | print('Failed!') 147 | raise 148 | 149 | # Print status and add to cache dict. 150 | if verbosity == 'full': 151 | print(f'Done setting up PyTorch plugin "{module_name}".') 152 | elif verbosity == 'brief': 153 | print('Done.') 154 | _cached_plugins[module_name] = module 155 | return module 156 | 157 | #---------------------------------------------------------------------------- 158 | -------------------------------------------------------------------------------- /sketch/torch_utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import re 10 | import contextlib 11 | import numpy as np 12 | import torch 13 | import warnings 14 | import dnnlib 15 | 16 | #---------------------------------------------------------------------------- 17 | # Cached construction of constant tensors. 
Avoids CPU=>GPU copy when the 18 | # same constant is used multiple times. 19 | 20 | _constant_cache = dict() 21 | 22 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 23 | value = np.asarray(value) 24 | if shape is not None: 25 | shape = tuple(shape) 26 | if dtype is None: 27 | dtype = torch.get_default_dtype() 28 | if device is None: 29 | device = torch.device('cpu') 30 | if memory_format is None: 31 | memory_format = torch.contiguous_format 32 | 33 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) 34 | tensor = _constant_cache.get(key, None) 35 | if tensor is None: 36 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 37 | if shape is not None: 38 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 39 | tensor = tensor.contiguous(memory_format=memory_format) 40 | _constant_cache[key] = tensor 41 | return tensor 42 | 43 | #---------------------------------------------------------------------------- 44 | # Replace NaN/Inf with specified numerical values. 45 | 46 | try: 47 | nan_to_num = torch.nan_to_num # 1.8.0a0 48 | except AttributeError: 49 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin 50 | assert isinstance(input, torch.Tensor) 51 | if posinf is None: 52 | posinf = torch.finfo(input.dtype).max 53 | if neginf is None: 54 | neginf = torch.finfo(input.dtype).min 55 | assert nan == 0 56 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 57 | 58 | #---------------------------------------------------------------------------- 59 | # Symbolic assert. 60 | 61 | try: 62 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access 63 | except AttributeError: 64 | symbolic_assert = torch.Assert # 1.7.0 65 | 66 | #---------------------------------------------------------------------------- 67 | # Context manager to temporarily suppress known warnings in torch.jit.trace(). 68 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 69 | 70 | @contextlib.contextmanager 71 | def suppress_tracer_warnings(): 72 | flt = ('ignore', None, torch.jit.TracerWarning, None, 0) 73 | warnings.filters.insert(0, flt) 74 | yield 75 | warnings.filters.remove(flt) 76 | 77 | #---------------------------------------------------------------------------- 78 | # Assert that the shape of a tensor matches the given list of integers. 79 | # None indicates that the size of a dimension is allowed to vary. 80 | # Performs symbolic assertion when used in torch.jit.trace(). 
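# For example (an illustrative call, not part of the original file):
#   misc.assert_shape(x, [None, 512, 4, 4])
# accepts any batch size in dimension 0, but raises an AssertionError if x is not
# 4-dimensional or if the remaining sizes differ from 512, 4, 4.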
81 | 82 | def assert_shape(tensor, ref_shape): 83 | if tensor.ndim != len(ref_shape): 84 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') 85 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): 86 | if ref_size is None: 87 | pass 88 | elif isinstance(ref_size, torch.Tensor): 89 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 90 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') 91 | elif isinstance(size, torch.Tensor): 92 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 93 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') 94 | elif size != ref_size: 95 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') 96 | 97 | #---------------------------------------------------------------------------- 98 | # Function decorator that calls torch.autograd.profiler.record_function(). 99 | 100 | def profiled_function(fn): 101 | def decorator(*args, **kwargs): 102 | with torch.autograd.profiler.record_function(fn.__name__): 103 | return fn(*args, **kwargs) 104 | decorator.__name__ = fn.__name__ 105 | return decorator 106 | 107 | #---------------------------------------------------------------------------- 108 | # Sampler for torch.utils.data.DataLoader that loops over the dataset 109 | # indefinitely, shuffling items as it goes. 110 | 111 | class InfiniteSampler(torch.utils.data.Sampler): 112 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): 113 | assert len(dataset) > 0 114 | assert num_replicas > 0 115 | assert 0 <= rank < num_replicas 116 | assert 0 <= window_size <= 1 117 | super().__init__(dataset) 118 | self.dataset = dataset 119 | self.rank = rank 120 | self.num_replicas = num_replicas 121 | self.shuffle = shuffle 122 | self.seed = seed 123 | self.window_size = window_size 124 | 125 | def __iter__(self): 126 | order = np.arange(len(self.dataset)) 127 | rnd = None 128 | window = 0 129 | if self.shuffle: 130 | rnd = np.random.RandomState(self.seed) 131 | rnd.shuffle(order) 132 | window = int(np.rint(order.size * self.window_size)) 133 | 134 | idx = 0 135 | while True: 136 | i = idx % order.size 137 | if idx % self.num_replicas == self.rank: 138 | yield order[i] 139 | if window >= 2: 140 | j = (i - rnd.randint(window)) % order.size 141 | order[i], order[j] = order[j], order[i] 142 | idx += 1 143 | 144 | #---------------------------------------------------------------------------- 145 | # Utilities for operating with torch.nn.Module parameters and buffers. 
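# For example (illustrative): copy_params_and_buffers(src_module, dst_module,
# require_all=True) copies every named parameter and buffer of src_module into the
# matching tensor of dst_module in place, preserving the destination's requires_grad
# flags, and asserts that every destination name is present in the source.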
146 | 147 | def params_and_buffers(module): 148 | assert isinstance(module, torch.nn.Module) 149 | return list(module.parameters()) + list(module.buffers()) 150 | 151 | def named_params_and_buffers(module): 152 | assert isinstance(module, torch.nn.Module) 153 | return list(module.named_parameters()) + list(module.named_buffers()) 154 | 155 | def copy_params_and_buffers(src_module, dst_module, require_all=False): 156 | assert isinstance(src_module, torch.nn.Module) 157 | assert isinstance(dst_module, torch.nn.Module) 158 | src_tensors = dict(named_params_and_buffers(src_module)) 159 | for name, tensor in named_params_and_buffers(dst_module): 160 | assert (name in src_tensors) or (not require_all) 161 | if name in src_tensors: 162 | tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) 163 | 164 | #---------------------------------------------------------------------------- 165 | # Context manager for easily enabling/disabling DistributedDataParallel 166 | # synchronization. 167 | 168 | @contextlib.contextmanager 169 | def ddp_sync(module, sync): 170 | assert isinstance(module, torch.nn.Module) 171 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 172 | yield 173 | else: 174 | with module.no_sync(): 175 | yield 176 | 177 | #---------------------------------------------------------------------------- 178 | # Check DistributedDataParallel consistency across processes. 179 | 180 | def check_ddp_consistency(module, ignore_regex=None): 181 | assert isinstance(module, torch.nn.Module) 182 | for name, tensor in named_params_and_buffers(module): 183 | fullname = type(module).__name__ + '.' + name 184 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): 185 | continue 186 | tensor = tensor.detach() 187 | if tensor.is_floating_point(): 188 | tensor = nan_to_num(tensor) 189 | other = tensor.clone() 190 | torch.distributed.broadcast(tensor=other, src=0) 191 | assert (tensor == other).all(), fullname 192 | 193 | #---------------------------------------------------------------------------- 194 | # Print summary table of module hierarchy. 195 | 196 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): 197 | assert isinstance(module, torch.nn.Module) 198 | assert not isinstance(module, torch.jit.ScriptModule) 199 | assert isinstance(inputs, (tuple, list)) 200 | 201 | # Register hooks. 202 | entries = [] 203 | nesting = [0] 204 | def pre_hook(_mod, _inputs): 205 | nesting[0] += 1 206 | def post_hook(mod, _inputs, outputs): 207 | nesting[0] -= 1 208 | if nesting[0] <= max_nesting: 209 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 210 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 211 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) 212 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 213 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 214 | 215 | # Run module. 216 | outputs = module(*inputs) 217 | for hook in hooks: 218 | hook.remove() 219 | 220 | # Identify unique outputs, parameters, and buffers. 
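# Each parameter, buffer, and output tensor is attributed to the first entry that
# reports it, so tensors shared by several submodules are counted only once in the
# totals printed below.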
221 | tensors_seen = set() 222 | for e in entries: 223 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] 224 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 225 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 226 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} 227 | 228 | # Filter out redundant entries. 229 | if skip_redundant: 230 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] 231 | 232 | # Construct table. 233 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 234 | rows += [['---'] * len(rows[0])] 235 | param_total = 0 236 | buffer_total = 0 237 | submodule_names = {mod: name for name, mod in module.named_modules()} 238 | for e in entries: 239 | name = '' if e.mod is module else submodule_names[e.mod] 240 | param_size = sum(t.numel() for t in e.unique_params) 241 | buffer_size = sum(t.numel() for t in e.unique_buffers) 242 | output_shapes = [str(list(t.shape)) for t in e.outputs] 243 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 244 | rows += [[ 245 | name + (':0' if len(e.outputs) >= 2 else ''), 246 | str(param_size) if param_size else '-', 247 | str(buffer_size) if buffer_size else '-', 248 | (output_shapes + ['-'])[0], 249 | (output_dtypes + ['-'])[0], 250 | ]] 251 | for idx in range(1, len(e.outputs)): 252 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 253 | param_total += param_size 254 | buffer_total += buffer_size 255 | rows += [['---'] * len(rows[0])] 256 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 257 | 258 | # Print table. 259 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 260 | print() 261 | for row in rows: 262 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) 263 | print() 264 | return outputs 265 | 266 | #---------------------------------------------------------------------------- 267 | -------------------------------------------------------------------------------- /sketch/torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /sketch/torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 
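// The angle-bracket arguments of the three #include directives below are missing
// from this listing; in NVIDIA's upstream bias_act.cpp they are <torch/extension.h>,
// <ATen/cuda/CUDAContext.h>, and <c10/cuda/CUDAGuard.h>, followed by "bias_act.h".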
8 | 9 | #include 10 | #include 11 | #include 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments. 35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 
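// Each block processes p.loopX consecutive chunks of blockDim.x elements (one
// element per thread per chunk), so with blockSize = 4 * 32 = 128 the grid size
// computed below covers all p.sizeX elements of x in a single launch.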
84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /sketch/torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include "bias_act.h" 11 | 12 | //------------------------------------------------------------------------ 13 | // Helpers. 14 | 15 | template struct InternalType; 16 | template <> struct InternalType { typedef double scalar_t; }; 17 | template <> struct InternalType { typedef float scalar_t; }; 18 | template <> struct InternalType { typedef float scalar_t; }; 19 | 20 | //------------------------------------------------------------------------ 21 | // CUDA kernel. 22 | 23 | template 24 | __global__ void bias_act_kernel(bias_act_kernel_params p) 25 | { 26 | typedef typename InternalType::scalar_t scalar_t; 27 | int G = p.grad; 28 | scalar_t alpha = (scalar_t)p.alpha; 29 | scalar_t gain = (scalar_t)p.gain; 30 | scalar_t clamp = (scalar_t)p.clamp; 31 | scalar_t one = (scalar_t)1; 32 | scalar_t two = (scalar_t)2; 33 | scalar_t expRange = (scalar_t)80; 34 | scalar_t halfExpRange = (scalar_t)40; 35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 37 | 38 | // Loop over elements. 39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 41 | { 42 | // Load. 43 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 48 | scalar_t yy = (gain != 0) ? yref / gain : 0; 49 | scalar_t y = 0; 50 | 51 | // Apply bias. 52 | ((G == 0) ? x : xref) += b; 53 | 54 | // linear 55 | if (A == 1) 56 | { 57 | if (G == 0) y = x; 58 | if (G == 1) y = x; 59 | } 60 | 61 | // relu 62 | if (A == 2) 63 | { 64 | if (G == 0) y = (x > 0) ? x : 0; 65 | if (G == 1) y = (yy > 0) ? x : 0; 66 | } 67 | 68 | // lrelu 69 | if (A == 3) 70 | { 71 | if (G == 0) y = (x > 0) ? x : x * alpha; 72 | if (G == 1) y = (yy > 0) ? x : x * alpha; 73 | } 74 | 75 | // tanh 76 | if (A == 4) 77 | { 78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? 
one : (c - d) / (c + d); } 79 | if (G == 1) y = x * (one - yy * yy); 80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 81 | } 82 | 83 | // sigmoid 84 | if (A == 5) 85 | { 86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); 87 | if (G == 1) y = x * yy * (one - yy); 88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 89 | } 90 | 91 | // elu 92 | if (A == 6) 93 | { 94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 97 | } 98 | 99 | // selu 100 | if (A == 7) 101 | { 102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); 104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); 105 | } 106 | 107 | // softplus 108 | if (A == 8) 109 | { 110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); 111 | if (G == 1) y = x * (one - exp(-yy)); 112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 113 | } 114 | 115 | // swish 116 | if (A == 9) 117 | { 118 | if (G == 0) 119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 120 | else 121 | { 122 | scalar_t c = exp(xref); 123 | scalar_t d = c + one; 124 | if (G == 1) 125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 126 | else 127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 129 | } 130 | } 131 | 132 | // Apply gain. 133 | y *= gain * dy; 134 | 135 | // Clamp. 136 | if (clamp >= 0) 137 | { 138 | if (G == 0) 139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; 140 | else 141 | y = (yref > -clamp & yref < clamp) ? y : 0; 142 | } 143 | 144 | // Store. 145 | ((T*)p.y)[xi] = (T)y; 146 | } 147 | } 148 | 149 | //------------------------------------------------------------------------ 150 | // CUDA kernel selection. 151 | 152 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p) 153 | { 154 | if (p.act == 1) return (void*)bias_act_kernel; 155 | if (p.act == 2) return (void*)bias_act_kernel; 156 | if (p.act == 3) return (void*)bias_act_kernel; 157 | if (p.act == 4) return (void*)bias_act_kernel; 158 | if (p.act == 5) return (void*)bias_act_kernel; 159 | if (p.act == 6) return (void*)bias_act_kernel; 160 | if (p.act == 7) return (void*)bias_act_kernel; 161 | if (p.act == 8) return (void*)bias_act_kernel; 162 | if (p.act == 9) return (void*)bias_act_kernel; 163 | return NULL; 164 | } 165 | 166 | //------------------------------------------------------------------------ 167 | // Template specializations. 168 | 169 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 170 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 171 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 172 | 173 | //------------------------------------------------------------------------ 174 | -------------------------------------------------------------------------------- /sketch/torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. 
Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /sketch/torch_utils/ops/bias_act.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom PyTorch ops for efficient bias and activation.""" 10 | 11 | import os 12 | import numpy as np 13 | import torch 14 | import dnnlib 15 | 16 | from .. import custom_ops 17 | from .. 
import misc 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | activation_funcs = { 22 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), 23 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), 24 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), 25 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), 26 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), 27 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), 28 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), 29 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), 30 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), 31 | } 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | _plugin = None 36 | _null_tensor = torch.empty([0]) 37 | 38 | def _init(): 39 | global _plugin 40 | if _plugin is None: 41 | _plugin = custom_ops.get_plugin( 42 | module_name='bias_act_plugin', 43 | sources=['bias_act.cpp', 'bias_act.cu'], 44 | headers=['bias_act.h'], 45 | source_dir=os.path.dirname(__file__), 46 | extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'], 47 | ) 48 | return True 49 | 50 | #---------------------------------------------------------------------------- 51 | 52 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): 53 | r"""Fused bias and activation function. 54 | 55 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`, 56 | and scales the result by `gain`. Each of the steps is optional. In most cases, 57 | the fused op is considerably more efficient than performing the same calculation 58 | using standard PyTorch ops. It supports first and second order gradients, 59 | but not third order gradients. 60 | 61 | Args: 62 | x: Input activation tensor. Can be of any shape. 63 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 64 | as `x`. The shape must be known, and it must match the dimension of `x` 65 | corresponding to `dim`. 66 | dim: The dimension in `x` corresponding to the elements of `b`. 67 | The value of `dim` is ignored if `b` is not specified. 68 | act: Name of the activation function to evaluate, or `"linear"` to disable. 69 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. 70 | See `activation_funcs` for a full list. `None` is not allowed. 71 | alpha: Shape parameter for the activation function, or `None` to use the default. 72 | gain: Scaling factor for the output tensor, or `None` to use default. 73 | See `activation_funcs` for the default scaling of each activation function. 74 | If unsure, consider specifying 1. 
75 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable 76 | the clamping (default). 77 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 78 | 79 | Returns: 80 | Tensor of the same shape and datatype as `x`. 81 | """ 82 | assert isinstance(x, torch.Tensor) 83 | assert impl in ['ref', 'cuda'] 84 | if impl == 'cuda' and x.device.type == 'cuda' and _init(): 85 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) 86 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) 87 | 88 | #---------------------------------------------------------------------------- 89 | 90 | @misc.profiled_function 91 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): 92 | """Slow reference implementation of `bias_act()` using standard TensorFlow ops. 93 | """ 94 | assert isinstance(x, torch.Tensor) 95 | assert clamp is None or clamp >= 0 96 | spec = activation_funcs[act] 97 | alpha = float(alpha if alpha is not None else spec.def_alpha) 98 | gain = float(gain if gain is not None else spec.def_gain) 99 | clamp = float(clamp if clamp is not None else -1) 100 | 101 | # Add bias. 102 | if b is not None: 103 | assert isinstance(b, torch.Tensor) and b.ndim == 1 104 | assert 0 <= dim < x.ndim 105 | assert b.shape[0] == x.shape[dim] 106 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) 107 | 108 | # Evaluate activation function. 109 | alpha = float(alpha) 110 | x = spec.func(x, alpha=alpha) 111 | 112 | # Scale by gain. 113 | gain = float(gain) 114 | if gain != 1: 115 | x = x * gain 116 | 117 | # Clamp. 118 | if clamp >= 0: 119 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type 120 | return x 121 | 122 | #---------------------------------------------------------------------------- 123 | 124 | _bias_act_cuda_cache = dict() 125 | 126 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): 127 | """Fast CUDA implementation of `bias_act()` using custom ops. 128 | """ 129 | # Parse arguments. 130 | assert clamp is None or clamp >= 0 131 | spec = activation_funcs[act] 132 | alpha = float(alpha if alpha is not None else spec.def_alpha) 133 | gain = float(gain if gain is not None else spec.def_gain) 134 | clamp = float(clamp if clamp is not None else -1) 135 | 136 | # Lookup from cache. 137 | key = (dim, act, alpha, gain, clamp) 138 | if key in _bias_act_cuda_cache: 139 | return _bias_act_cuda_cache[key] 140 | 141 | # Forward op. 
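# BiasActCuda evaluates the fused op through the compiled plugin with grad=0 and
# saves x, b, and/or y as required by the activation's gradient formula
# (`spec.ref`, `spec.has_2nd_grad`). First-order gradients are delegated to
# BiasActCudaGrad below, which calls the plugin with grad=1 and, for activations
# that declare `has_2nd_grad`, with grad=2 in its own backward pass.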
142 | class BiasActCuda(torch.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, x, b): # pylint: disable=arguments-differ 145 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format 146 | x = x.contiguous(memory_format=ctx.memory_format) 147 | b = b.contiguous() if b is not None else _null_tensor 148 | y = x 149 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: 150 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) 151 | ctx.save_for_backward( 152 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 153 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 154 | y if 'y' in spec.ref else _null_tensor) 155 | return y 156 | 157 | @staticmethod 158 | def backward(ctx, dy): # pylint: disable=arguments-differ 159 | dy = dy.contiguous(memory_format=ctx.memory_format) 160 | x, b, y = ctx.saved_tensors 161 | dx = None 162 | db = None 163 | 164 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 165 | dx = dy 166 | if act != 'linear' or gain != 1 or clamp >= 0: 167 | dx = BiasActCudaGrad.apply(dy, x, b, y) 168 | 169 | if ctx.needs_input_grad[1]: 170 | db = dx.sum([i for i in range(dx.ndim) if i != dim]) 171 | 172 | return dx, db 173 | 174 | # Backward op. 175 | class BiasActCudaGrad(torch.autograd.Function): 176 | @staticmethod 177 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ 178 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format 179 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) 180 | ctx.save_for_backward( 181 | dy if spec.has_2nd_grad else _null_tensor, 182 | x, b, y) 183 | return dx 184 | 185 | @staticmethod 186 | def backward(ctx, d_dx): # pylint: disable=arguments-differ 187 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format) 188 | dy, x, b, y = ctx.saved_tensors 189 | d_dy = None 190 | d_x = None 191 | d_b = None 192 | d_y = None 193 | 194 | if ctx.needs_input_grad[0]: 195 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) 196 | 197 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): 198 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) 199 | 200 | if spec.has_2nd_grad and ctx.needs_input_grad[2]: 201 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) 202 | 203 | return d_dy, d_x, d_b, d_y 204 | 205 | # Add to cache. 206 | _bias_act_cuda_cache[key] = BiasActCuda 207 | return BiasActCuda 208 | 209 | #---------------------------------------------------------------------------- 210 | -------------------------------------------------------------------------------- /sketch/torch_utils/ops/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
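# A minimal usage sketch (illustrative, not part of the original file): the module
# is switched on once by the training script and then used as a drop-in replacement
# for the standard functional ops; no_weight_gradients() (defined below) skips the
# weight gradients during passes that only need input gradients. The tensors x, w,
# and b here are hypothetical.
#
#   from torch_utils.ops import conv2d_gradfix
#   conv2d_gradfix.enabled = True
#   y = conv2d_gradfix.conv2d(x, w, bias=b, padding=1)
#   with conv2d_gradfix.no_weight_gradients():
#       (grad_x,) = torch.autograd.grad(y.sum(), x, create_graph=True)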
8 | 9 | """Custom replacement for `torch.nn.functional.conv2d` that supports 10 | arbitrarily high order gradients with zero performance penalty.""" 11 | 12 | import contextlib 13 | import torch 14 | from pkg_resources import parse_version 15 | 16 | # pylint: disable=redefined-builtin 17 | # pylint: disable=arguments-differ 18 | # pylint: disable=protected-access 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | enabled = False # Enable the custom op by setting this to true. 23 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 24 | _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 25 | 26 | @contextlib.contextmanager 27 | def no_weight_gradients(disable=True): 28 | global weight_gradients_disabled 29 | old = weight_gradients_disabled 30 | if disable: 31 | weight_gradients_disabled = True 32 | yield 33 | weight_gradients_disabled = old 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 38 | if _should_use_custom_op(input): 39 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) 40 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 41 | 42 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): 43 | if _should_use_custom_op(input): 44 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) 45 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _should_use_custom_op(input): 50 | assert isinstance(input, torch.Tensor) 51 | if (not enabled) or (not torch.backends.cudnn.enabled): 52 | return False 53 | if _use_pytorch_1_11_api: 54 | # The work-around code doesn't work on PyTorch 1.11.0 onwards 55 | return False 56 | if input.device.type != 'cuda': 57 | return False 58 | return True 59 | 60 | def _tuple_of_ints(xs, ndim): 61 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 62 | assert len(xs) == ndim 63 | assert all(isinstance(x, int) for x in xs) 64 | return xs 65 | 66 | #---------------------------------------------------------------------------- 67 | 68 | _conv2d_gradfix_cache = dict() 69 | _null_tensor = torch.empty([0]) 70 | 71 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): 72 | # Parse arguments. 73 | ndim = 2 74 | weight_shape = tuple(weight_shape) 75 | stride = _tuple_of_ints(stride, ndim) 76 | padding = _tuple_of_ints(padding, ndim) 77 | output_padding = _tuple_of_ints(output_padding, ndim) 78 | dilation = _tuple_of_ints(dilation, ndim) 79 | 80 | # Lookup from cache. 81 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 82 | if key in _conv2d_gradfix_cache: 83 | return _conv2d_gradfix_cache[key] 84 | 85 | # Validate arguments. 
86 | assert groups >= 1 87 | assert len(weight_shape) == ndim + 2 88 | assert all(stride[i] >= 1 for i in range(ndim)) 89 | assert all(padding[i] >= 0 for i in range(ndim)) 90 | assert all(dilation[i] >= 0 for i in range(ndim)) 91 | if not transpose: 92 | assert all(output_padding[i] == 0 for i in range(ndim)) 93 | else: # transpose 94 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) 95 | 96 | # Helpers. 97 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) 98 | def calc_output_padding(input_shape, output_shape): 99 | if transpose: 100 | return [0, 0] 101 | return [ 102 | input_shape[i + 2] 103 | - (output_shape[i + 2] - 1) * stride[i] 104 | - (1 - 2 * padding[i]) 105 | - dilation[i] * (weight_shape[i + 2] - 1) 106 | for i in range(ndim) 107 | ] 108 | 109 | # Forward & backward. 110 | class Conv2d(torch.autograd.Function): 111 | @staticmethod 112 | def forward(ctx, input, weight, bias): 113 | assert weight.shape == weight_shape 114 | ctx.save_for_backward( 115 | input if weight.requires_grad else _null_tensor, 116 | weight if input.requires_grad else _null_tensor, 117 | ) 118 | ctx.input_shape = input.shape 119 | 120 | # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere). 121 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0): 122 | a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1]) 123 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1) 124 | c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2) 125 | c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1) 126 | c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) 127 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) 128 | 129 | # General case => cuDNN. 130 | if transpose: 131 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) 132 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 133 | 134 | @staticmethod 135 | def backward(ctx, grad_output): 136 | input, weight = ctx.saved_tensors 137 | input_shape = ctx.input_shape 138 | grad_input = None 139 | grad_weight = None 140 | grad_bias = None 141 | 142 | if ctx.needs_input_grad[0]: 143 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape) 144 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) 145 | grad_input = op.apply(grad_output, weight, None) 146 | assert grad_input.shape == input_shape 147 | 148 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 149 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 150 | assert grad_weight.shape == weight_shape 151 | 152 | if ctx.needs_input_grad[2]: 153 | grad_bias = grad_output.sum([0, 2, 3]) 154 | 155 | return grad_input, grad_weight, grad_bias 156 | 157 | # Gradient with respect to the weights. 
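# Conv2dGradWeight computes the gradient with respect to the weights. For stride-1
# 1x1 kernels it falls back to a plain batched matmul (cuBLAS); otherwise it
# dispatches to cuDNN's backward-weight kernel through torch._C._jit_get_operation.
# Its own backward() re-expresses the second-order terms through Conv2d and another
# _conv2d_gradfix op, which is what makes arbitrarily high-order gradients possible.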
158 | class Conv2dGradWeight(torch.autograd.Function): 159 | @staticmethod 160 | def forward(ctx, grad_output, input): 161 | ctx.save_for_backward( 162 | grad_output if input.requires_grad else _null_tensor, 163 | input if grad_output.requires_grad else _null_tensor, 164 | ) 165 | ctx.grad_output_shape = grad_output.shape 166 | ctx.input_shape = input.shape 167 | 168 | # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere). 169 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0): 170 | a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) 171 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) 172 | c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape) 173 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) 174 | 175 | # General case => cuDNN. 176 | name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight' 177 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] 178 | return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) 179 | 180 | @staticmethod 181 | def backward(ctx, grad2_grad_weight): 182 | grad_output, input = ctx.saved_tensors 183 | grad_output_shape = ctx.grad_output_shape 184 | input_shape = ctx.input_shape 185 | grad2_grad_output = None 186 | grad2_input = None 187 | 188 | if ctx.needs_input_grad[0]: 189 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) 190 | assert grad2_grad_output.shape == grad_output_shape 191 | 192 | if ctx.needs_input_grad[1]: 193 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape) 194 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) 195 | grad2_input = op.apply(grad_output, grad2_grad_weight, None) 196 | assert grad2_input.shape == input_shape 197 | 198 | return grad2_grad_output, grad2_input 199 | 200 | _conv2d_gradfix_cache[key] = Conv2d 201 | return Conv2d 202 | 203 | #---------------------------------------------------------------------------- 204 | -------------------------------------------------------------------------------- /sketch/torch_utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """2D convolution with optional up/downsampling.""" 10 | 11 | import torch 12 | 13 | from .. import misc 14 | from . import conv2d_gradfix 15 | from . 
import upfirdn2d 16 | from .upfirdn2d import _parse_padding 17 | from .upfirdn2d import _get_filter_size 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def _get_weight_shape(w): 22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 23 | shape = [int(sz) for sz in w.shape] 24 | misc.assert_shape(w, shape) 25 | return shape 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 31 | """ 32 | _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w) 33 | 34 | # Flip weight if requested. 35 | # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 36 | if not flip_weight and (kw > 1 or kh > 1): 37 | w = w.flip([2, 3]) 38 | 39 | # Execute using conv2d_gradfix. 40 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 41 | return op(x, w, stride=stride, padding=padding, groups=groups) 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | @misc.profiled_function 46 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 47 | r"""2D convolution with optional up/downsampling. 48 | 49 | Padding is performed only once at the beginning, not between the operations. 50 | 51 | Args: 52 | x: Input tensor of shape 53 | `[batch_size, in_channels, in_height, in_width]`. 54 | w: Weight tensor of shape 55 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 56 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 57 | calling upfirdn2d.setup_filter(). None = identity (default). 58 | up: Integer upsampling factor (default: 1). 59 | down: Integer downsampling factor (default: 1). 60 | padding: Padding with respect to the upsampled image. Can be a single number 61 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 62 | (default: 0). 63 | groups: Split input channels into N groups (default: 1). 64 | flip_weight: False = convolution, True = correlation (default: True). 65 | flip_filter: False = convolution, True = correlation (default: False). 66 | 67 | Returns: 68 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 69 | """ 70 | # Validate arguments. 71 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 72 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 73 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 74 | assert isinstance(up, int) and (up >= 1) 75 | assert isinstance(down, int) and (down >= 1) 76 | assert isinstance(groups, int) and (groups >= 1) 77 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 78 | fw, fh = _get_filter_size(f) 79 | px0, px1, py0, py1 = _parse_padding(padding) 80 | 81 | # Adjust padding to account for up/downsampling. 82 | if up > 1: 83 | px0 += (fw + up - 1) // 2 84 | px1 += (fw - up) // 2 85 | py0 += (fh + up - 1) // 2 86 | py1 += (fh - up) // 2 87 | if down > 1: 88 | px0 += (fw - down + 1) // 2 89 | px1 += (fw - down) // 2 90 | py0 += (fh - down + 1) // 2 91 | py1 += (fh - down) // 2 92 | 93 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 
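# Because a 1x1 kernel has no spatial footprint, the pointwise convolution and the
# per-channel FIR resampling commute; downsampling first simply leaves the 1x1
# convolution with fewer pixels to process.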
94 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 95 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 96 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 97 | return x 98 | 99 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 100 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 101 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 102 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 103 | return x 104 | 105 | # Fast path: downsampling only => use strided convolution. 106 | if down > 1 and up == 1: 107 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 108 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 109 | return x 110 | 111 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 112 | if up > 1: 113 | if groups == 1: 114 | w = w.transpose(0, 1) 115 | else: 116 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 117 | w = w.transpose(1, 2) 118 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 119 | px0 -= kw - 1 120 | px1 -= kw - up 121 | py0 -= kh - 1 122 | py1 -= kh - up 123 | pxt = max(min(-px0, -px1), 0) 124 | pyt = max(min(-py0, -py1), 0) 125 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) 126 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) 127 | if down > 1: 128 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 129 | return x 130 | 131 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 132 | if up == 1 and down == 1: 133 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 134 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) 135 | 136 | # Fallback: Generic reference implementation. 137 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 138 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 139 | if down > 1: 140 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 141 | return x 142 | 143 | #---------------------------------------------------------------------------- 144 | -------------------------------------------------------------------------------- /sketch/torch_utils/ops/filtered_lrelu.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct filtered_lrelu_kernel_params 15 | { 16 | // These parameters decide which kernel to use. 
17 | int up; // upsampling ratio (1, 2, 4) 18 | int down; // downsampling ratio (1, 2, 4) 19 | int2 fuShape; // [size, 1] | [size, size] 20 | int2 fdShape; // [size, 1] | [size, size] 21 | 22 | int _dummy; // Alignment. 23 | 24 | // Rest of the parameters. 25 | const void* x; // Input tensor. 26 | void* y; // Output tensor. 27 | const void* b; // Bias tensor. 28 | unsigned char* s; // Sign tensor in/out. NULL if unused. 29 | const float* fu; // Upsampling filter. 30 | const float* fd; // Downsampling filter. 31 | 32 | int2 pad0; // Left/top padding. 33 | float gain; // Additional gain factor. 34 | float slope; // Leaky ReLU slope on negative side. 35 | float clamp; // Clamp after nonlinearity. 36 | int flip; // Filter kernel flip for gradient computation. 37 | 38 | int tilesXdim; // Original number of horizontal output tiles. 39 | int tilesXrep; // Number of horizontal tiles per CTA. 40 | int blockZofs; // Block z offset to support large minibatch, channel dimensions. 41 | 42 | int4 xShape; // [width, height, channel, batch] 43 | int4 yShape; // [width, height, channel, batch] 44 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused. 45 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 46 | int swLimit; // Active width of sign tensor in bytes. 47 | 48 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes. 49 | longlong4 yStride; // 50 | int64_t bStride; // 51 | longlong3 fuStride; // 52 | longlong3 fdStride; // 53 | }; 54 | 55 | struct filtered_lrelu_act_kernel_params 56 | { 57 | void* x; // Input/output, modified in-place. 58 | unsigned char* s; // Sign tensor in/out. NULL if unused. 59 | 60 | float gain; // Additional gain factor. 61 | float slope; // Leaky ReLU slope on negative side. 62 | float clamp; // Clamp after nonlinearity. 63 | 64 | int4 xShape; // [width, height, channel, batch] 65 | longlong4 xStride; // Input/output tensor strides, same order as in shape. 66 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused. 67 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 68 | }; 69 | 70 | //------------------------------------------------------------------------ 71 | // CUDA kernel specialization. 72 | 73 | struct filtered_lrelu_kernel_spec 74 | { 75 | void* setup; // Function for filter kernel setup. 76 | void* exec; // Function for main operation. 77 | int2 tileOut; // Width/height of launch tile. 78 | int numWarps; // Number of warps per thread block, determines launch block size. 79 | int xrep; // For processing multiple horizontal tiles per thread block. 80 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. 81 | }; 82 | 83 | //------------------------------------------------------------------------ 84 | // CUDA kernel selection. 85 | 86 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 87 | template void* choose_filtered_lrelu_act_kernel(void); 88 | template cudaError_t copy_filters(cudaStream_t stream); 89 | 90 | //------------------------------------------------------------------------ 91 | -------------------------------------------------------------------------------- /sketch/torch_utils/ops/filtered_lrelu_ns.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for no signs mode (no gradients required). 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /sketch/torch_utils/ops/filtered_lrelu_rd.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign read mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /sketch/torch_utils/ops/filtered_lrelu_wr.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign write mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /sketch/torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
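# A minimal usage sketch (illustrative, not part of the original file):
#
#   from torch_utils.ops import fma
#   out = fma.fma(a, b, c)   # a * b + c, with broadcasting handled by _unbroadcast() in backward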
8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /sketch/torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import torch 15 | from pkg_resources import parse_version 16 | 17 | # pylint: disable=redefined-builtin 18 | # pylint: disable=arguments-differ 19 | # pylint: disable=protected-access 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | enabled = False # Enable the custom op by setting this to true. 
24 | _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | def grid_sample(input, grid): 29 | if _should_use_custom_op(): 30 | return _GridSample2dForward.apply(input, grid) 31 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def _should_use_custom_op(): 36 | return enabled 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | class _GridSample2dForward(torch.autograd.Function): 41 | @staticmethod 42 | def forward(ctx, input, grid): 43 | assert input.ndim == 4 44 | assert grid.ndim == 4 45 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 46 | ctx.save_for_backward(input, grid) 47 | return output 48 | 49 | @staticmethod 50 | def backward(ctx, grad_output): 51 | input, grid = ctx.saved_tensors 52 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 53 | return grad_input, grad_grid 54 | 55 | #---------------------------------------------------------------------------- 56 | 57 | class _GridSample2dBackward(torch.autograd.Function): 58 | @staticmethod 59 | def forward(ctx, grad_output, input, grid): 60 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 61 | if _use_pytorch_1_11_api: 62 | output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2]) 63 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask) 64 | else: 65 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 66 | ctx.save_for_backward(grid) 67 | return grad_input, grad_grid 68 | 69 | @staticmethod 70 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 71 | _ = grad2_grad_grid # unused 72 | grid, = ctx.saved_tensors 73 | grad2_grad_output = None 74 | grad2_input = None 75 | grad2_grid = None 76 | 77 | if ctx.needs_input_grad[0]: 78 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 79 | 80 | assert not ctx.needs_input_grad[2] 81 | return grad2_grad_output, grad2_input, grad2_grid 82 | 83 | #---------------------------------------------------------------------------- 84 | -------------------------------------------------------------------------------- /sketch/torch_utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments. 
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.numel() > 0, "x has zero size"); 25 | TORCH_CHECK(f.numel() > 0, "f has zero size"); 26 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 27 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 28 | TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); 29 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 30 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 31 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 32 | 33 | // Create output tensor. 34 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 35 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 36 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 37 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 38 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 39 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 40 | TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); 41 | 42 | // Initialize CUDA kernel parameters. 43 | upfirdn2d_kernel_params p; 44 | p.x = x.data_ptr(); 45 | p.f = f.data_ptr(); 46 | p.y = y.data_ptr(); 47 | p.up = make_int2(upx, upy); 48 | p.down = make_int2(downx, downy); 49 | p.pad0 = make_int2(padx0, pady0); 50 | p.flip = (flip) ? 1 : 0; 51 | p.gain = gain; 52 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 53 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 54 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 55 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 56 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 57 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 58 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 59 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 60 | 61 | // Choose CUDA kernel. 62 | upfirdn2d_kernel_spec spec; 63 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 64 | { 65 | spec = choose_upfirdn2d_kernel(p); 66 | }); 67 | 68 | // Set looping options. 69 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 70 | p.loopMinor = spec.loopMinor; 71 | p.loopX = spec.loopX; 72 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 73 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 74 | 75 | // Compute grid size. 
76 | dim3 blockSize, gridSize; 77 | if (spec.tileOutW < 0) // large 78 | { 79 | blockSize = dim3(4, 32, 1); 80 | gridSize = dim3( 81 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 82 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 83 | p.launchMajor); 84 | } 85 | else // small 86 | { 87 | blockSize = dim3(256, 1, 1); 88 | gridSize = dim3( 89 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 90 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 91 | p.launchMajor); 92 | } 93 | 94 | // Launch CUDA kernel. 95 | void* args[] = {&p}; 96 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 97 | return y; 98 | } 99 | 100 | //------------------------------------------------------------------------ 101 | 102 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 103 | { 104 | m.def("upfirdn2d", &upfirdn2d); 105 | } 106 | 107 | //------------------------------------------------------------------------ 108 | -------------------------------------------------------------------------------- /sketch/torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /sketch/torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for pickling Python code alongside other data. 10 | 11 | The pickled code is automatically imported into a separate Python module 12 | during unpickling. This way, any previously exported pickles will remain 13 | usable even if the original code is no longer available, or if the current 14 | version of the code is not consistent with what was originally pickled.""" 15 | 16 | import sys 17 | import pickle 18 | import io 19 | import inspect 20 | import copy 21 | import uuid 22 | import types 23 | import dnnlib 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | _version = 6 # internal version number 28 | _decorators = set() # {decorator_class, ...} 29 | _import_hooks = [] # [hook_function, ...] 30 | _module_to_src_dict = dict() # {module: src, ...} 31 | _src_to_module_dict = dict() # {src: module, ...} 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def persistent_class(orig_class): 36 | r"""Class decorator that extends a given class to save its source code 37 | when pickled. 38 | 39 | Example: 40 | 41 | from torch_utils import persistence 42 | 43 | @persistence.persistent_class 44 | class MyNetwork(torch.nn.Module): 45 | def __init__(self, num_inputs, num_outputs): 46 | super().__init__() 47 | self.fc = MyLayer(num_inputs, num_outputs) 48 | ... 49 | 50 | @persistence.persistent_class 51 | class MyLayer(torch.nn.Module): 52 | ... 53 | 54 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 55 | source code alongside other internal state (e.g., parameters, buffers, 56 | and submodules). This way, any previously exported pickle will remain 57 | usable even if the class definitions have been modified or are no 58 | longer available. 59 | 60 | The decorator saves the source code of the entire Python module 61 | containing the decorated class. It does *not* save the source code of 62 | any imported modules. Thus, the imported modules must be available 63 | during unpickling, also including `torch_utils.persistence` itself. 64 | 65 | It is ok to call functions defined in the same module from the 66 | decorated class. However, if the decorated class depends on other 67 | classes defined in the same module, they must be decorated as well. 68 | This is illustrated in the above example in the case of `MyLayer`. 69 | 70 | It is also possible to employ the decorator just-in-time before 71 | calling the constructor. For example: 72 | 73 | cls = MyLayer 74 | if want_to_make_it_persistent: 75 | cls = persistence.persistent_class(cls) 76 | layer = cls(num_inputs, num_outputs) 77 | 78 | As an additional feature, the decorator also keeps track of the 79 | arguments that were used to construct each instance of the decorated 80 | class. The arguments can be queried via `obj.init_args` and 81 | `obj.init_kwargs`, and they are automatically pickled alongside other 82 | object state. 
A typical use case is to first unpickle a previous 83 | instance of a persistent class, and then upgrade it to use the latest 84 | version of the source code: 85 | 86 | with open('old_pickle.pkl', 'rb') as f: 87 | old_net = pickle.load(f) 88 | new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs) 89 | misc.copy_params_and_buffers(old_net, new_net, require_all=True) 90 | """ 91 | assert isinstance(orig_class, type) 92 | if is_persistent(orig_class): 93 | return orig_class 94 | 95 | assert orig_class.__module__ in sys.modules 96 | orig_module = sys.modules[orig_class.__module__] 97 | orig_module_src = _module_to_src(orig_module) 98 | 99 | class Decorator(orig_class): 100 | _orig_module_src = orig_module_src 101 | _orig_class_name = orig_class.__name__ 102 | 103 | def __init__(self, *args, **kwargs): 104 | super().__init__(*args, **kwargs) 105 | self._init_args = copy.deepcopy(args) 106 | self._init_kwargs = copy.deepcopy(kwargs) 107 | assert orig_class.__name__ in orig_module.__dict__ 108 | _check_pickleable(self.__reduce__()) 109 | 110 | @property 111 | def init_args(self): 112 | return copy.deepcopy(self._init_args) 113 | 114 | @property 115 | def init_kwargs(self): 116 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) 117 | 118 | def __reduce__(self): 119 | fields = list(super().__reduce__()) 120 | fields += [None] * max(3 - len(fields), 0) 121 | if fields[0] is not _reconstruct_persistent_obj: 122 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) 123 | fields[0] = _reconstruct_persistent_obj # reconstruct func 124 | fields[1] = (meta,) # reconstruct args 125 | fields[2] = None # state dict 126 | return tuple(fields) 127 | 128 | Decorator.__name__ = orig_class.__name__ 129 | _decorators.add(Decorator) 130 | return Decorator 131 | 132 | #---------------------------------------------------------------------------- 133 | 134 | def is_persistent(obj): 135 | r"""Test whether the given object or class is persistent, i.e., 136 | whether it will save its source code when pickled. 137 | """ 138 | try: 139 | if obj in _decorators: 140 | return True 141 | except TypeError: 142 | pass 143 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck 144 | 145 | #---------------------------------------------------------------------------- 146 | 147 | def import_hook(hook): 148 | r"""Register an import hook that is called whenever a persistent object 149 | is being unpickled. A typical use case is to patch the pickled source 150 | code to avoid errors and inconsistencies when the API of some imported 151 | module has changed. 152 | 153 | The hook should have the following signature: 154 | 155 | hook(meta) -> modified meta 156 | 157 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 158 | 159 | type: Type of the persistent object, e.g. `'class'`. 160 | version: Internal version number of `torch_utils.persistence`. 161 | module_src Original source code of the Python module. 162 | class_name: Class name in the original Python module. 163 | state: Internal state of the object. 164 | 165 | Example: 166 | 167 | @persistence.import_hook 168 | def wreck_my_network(meta): 169 | if meta.class_name == 'MyNetwork': 170 | print('MyNetwork is being imported. 
I will wreck it!') 171 | meta.module_src = meta.module_src.replace("True", "False") 172 | return meta 173 | """ 174 | assert callable(hook) 175 | _import_hooks.append(hook) 176 | 177 | #---------------------------------------------------------------------------- 178 | 179 | def _reconstruct_persistent_obj(meta): 180 | r"""Hook that is called internally by the `pickle` module to unpickle 181 | a persistent object. 182 | """ 183 | meta = dnnlib.EasyDict(meta) 184 | meta.state = dnnlib.EasyDict(meta.state) 185 | for hook in _import_hooks: 186 | meta = hook(meta) 187 | assert meta is not None 188 | 189 | assert meta.version == _version 190 | module = _src_to_module(meta.module_src) 191 | 192 | assert meta.type == 'class' 193 | orig_class = module.__dict__[meta.class_name] 194 | decorator_class = persistent_class(orig_class) 195 | obj = decorator_class.__new__(decorator_class) 196 | 197 | setstate = getattr(obj, '__setstate__', None) 198 | if callable(setstate): 199 | setstate(meta.state) # pylint: disable=not-callable 200 | else: 201 | obj.__dict__.update(meta.state) 202 | return obj 203 | 204 | #---------------------------------------------------------------------------- 205 | 206 | def _module_to_src(module): 207 | r"""Query the source code of a given Python module. 208 | """ 209 | src = _module_to_src_dict.get(module, None) 210 | if src is None: 211 | src = inspect.getsource(module) 212 | _module_to_src_dict[module] = src 213 | _src_to_module_dict[src] = module 214 | return src 215 | 216 | def _src_to_module(src): 217 | r"""Get or create a Python module for the given source code. 218 | """ 219 | module = _src_to_module_dict.get(src, None) 220 | if module is None: 221 | module_name = "_imported_module_" + uuid.uuid4().hex 222 | module = types.ModuleType(module_name) 223 | sys.modules[module_name] = module 224 | _module_to_src_dict[module] = src 225 | _src_to_module_dict[src] = module 226 | exec(src, module.__dict__) # pylint: disable=exec-used 227 | return module 228 | 229 | #---------------------------------------------------------------------------- 230 | 231 | def _check_pickleable(obj): 232 | r"""Check that the given object is pickleable, raising an exception if 233 | it is not. This function is expected to be considerably more efficient 234 | than actually pickling the object. 235 | """ 236 | def recurse(obj): 237 | if isinstance(obj, (list, tuple, set)): 238 | return [recurse(x) for x in obj] 239 | if isinstance(obj, dict): 240 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 241 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 242 | return None # Python primitive types are pickleable. 243 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: 244 | return None # NumPy arrays and PyTorch tensors are pickleable. 245 | if is_persistent(obj): 246 | return None # Persistent objects are pickleable, by virtue of the constructor check. 247 | return obj 248 | with io.BytesIO() as f: 249 | pickle.dump(recurse(obj), f) 250 | 251 | #---------------------------------------------------------------------------- 252 | -------------------------------------------------------------------------------- /sketch/torch_utils/training_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for reporting and collecting training statistics across 10 | multiple processes and devices. The interface is designed to minimize 11 | synchronization overhead as well as the amount of boilerplate in user 12 | code.""" 13 | 14 | import re 15 | import numpy as np 16 | import torch 17 | import dnnlib 18 | 19 | from . import misc 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 24 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 25 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 26 | _rank = 0 # Rank of the current process. 27 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 28 | _sync_called = False # Has _sync() been called yet? 29 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 30 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def init_multiprocessing(rank, sync_device): 35 | r"""Initializes `torch_utils.training_stats` for collecting statistics 36 | across multiple processes. 37 | 38 | This function must be called after 39 | `torch.distributed.init_process_group()` and before `Collector.update()`. 40 | The call is not necessary if multi-process collection is not needed. 41 | 42 | Args: 43 | rank: Rank of the current process. 44 | sync_device: PyTorch device to use for inter-process 45 | communication, or None to disable multi-process 46 | collection. Typically `torch.device('cuda', rank)`. 47 | """ 48 | global _rank, _sync_device 49 | assert not _sync_called 50 | _rank = rank 51 | _sync_device = sync_device 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | @misc.profiled_function 56 | def report(name, value): 57 | r"""Broadcasts the given set of scalars to all interested instances of 58 | `Collector`, across device and process boundaries. 59 | 60 | This function is expected to be extremely cheap and can be safely 61 | called from anywhere in the training loop, loss function, or inside a 62 | `torch.nn.Module`. 63 | 64 | Warning: The current implementation expects the set of unique names to 65 | be consistent across processes. Please make sure that `report()` is 66 | called at least once for each unique name by each process, and in the 67 | same order. If a given process has no scalars to broadcast, it can do 68 | `report(name, [])` (empty list). 69 | 70 | Args: 71 | name: Arbitrary string specifying the name of the statistic. 72 | Averages are accumulated separately for each unique name. 73 | value: Arbitrary set of scalars. Can be a list, tuple, 74 | NumPy array, PyTorch tensor, or Python scalar. 75 | 76 | Returns: 77 | The same `value` that was passed in. 
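
    Example (a minimal sketch; `loss_fn`, `fake_images`, and `tick` are
    illustrative names, not part of this module):

        report('Loss/G/loss', loss_fn(fake_images))   # any scalar(s)
        report0('Progress/tick', tick)                # broadcast by rank 0 only
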
78 | """ 79 | if name not in _counters: 80 | _counters[name] = dict() 81 | 82 | elems = torch.as_tensor(value) 83 | if elems.numel() == 0: 84 | return value 85 | 86 | elems = elems.detach().flatten().to(_reduce_dtype) 87 | moments = torch.stack([ 88 | torch.ones_like(elems).sum(), 89 | elems.sum(), 90 | elems.square().sum(), 91 | ]) 92 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 93 | moments = moments.to(_counter_dtype) 94 | 95 | device = moments.device 96 | if device not in _counters[name]: 97 | _counters[name][device] = torch.zeros_like(moments) 98 | _counters[name][device].add_(moments) 99 | return value 100 | 101 | #---------------------------------------------------------------------------- 102 | 103 | def report0(name, value): 104 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 105 | but ignores any scalars provided by the other processes. 106 | See `report()` for further details. 107 | """ 108 | report(name, value if _rank == 0 else []) 109 | return value 110 | 111 | #---------------------------------------------------------------------------- 112 | 113 | class Collector: 114 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 115 | computes their long-term averages (mean and standard deviation) over 116 | user-defined periods of time. 117 | 118 | The averages are first collected into internal counters that are not 119 | directly visible to the user. They are then copied to the user-visible 120 | state as a result of calling `update()` and can then be queried using 121 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 122 | internal counters for the next round, so that the user-visible state 123 | effectively reflects averages collected between the last two calls to 124 | `update()`. 125 | 126 | Args: 127 | regex: Regular expression defining which statistics to 128 | collect. The default is to collect everything. 129 | keep_previous: Whether to retain the previous averages if no 130 | scalars were collected on a given round 131 | (default: True). 132 | """ 133 | def __init__(self, regex='.*', keep_previous=True): 134 | self._regex = re.compile(regex) 135 | self._keep_previous = keep_previous 136 | self._cumulative = dict() 137 | self._moments = dict() 138 | self.update() 139 | self._moments.clear() 140 | 141 | def names(self): 142 | r"""Returns the names of all statistics broadcasted so far that 143 | match the regular expression specified at construction time. 144 | """ 145 | return [name for name in _counters if self._regex.fullmatch(name)] 146 | 147 | def update(self): 148 | r"""Copies current values of the internal counters to the 149 | user-visible state and resets them for the next round. 150 | 151 | If `keep_previous=True` was specified at construction time, the 152 | operation is skipped for statistics that have received no scalars 153 | since the last update, retaining their previous averages. 154 | 155 | This method performs a number of GPU-to-CPU transfers and one 156 | `torch.distributed.all_reduce()`. It is intended to be called 157 | periodically in the main training loop, typically once every 158 | N training steps. 
159 | """ 160 | if not self._keep_previous: 161 | self._moments.clear() 162 | for name, cumulative in _sync(self.names()): 163 | if name not in self._cumulative: 164 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 165 | delta = cumulative - self._cumulative[name] 166 | self._cumulative[name].copy_(cumulative) 167 | if float(delta[0]) != 0: 168 | self._moments[name] = delta 169 | 170 | def _get_delta(self, name): 171 | r"""Returns the raw moments that were accumulated for the given 172 | statistic between the last two calls to `update()`, or zero if 173 | no scalars were collected. 174 | """ 175 | assert self._regex.fullmatch(name) 176 | if name not in self._moments: 177 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 178 | return self._moments[name] 179 | 180 | def num(self, name): 181 | r"""Returns the number of scalars that were accumulated for the given 182 | statistic between the last two calls to `update()`, or zero if 183 | no scalars were collected. 184 | """ 185 | delta = self._get_delta(name) 186 | return int(delta[0]) 187 | 188 | def mean(self, name): 189 | r"""Returns the mean of the scalars that were accumulated for the 190 | given statistic between the last two calls to `update()`, or NaN if 191 | no scalars were collected. 192 | """ 193 | delta = self._get_delta(name) 194 | if int(delta[0]) == 0: 195 | return float('nan') 196 | return float(delta[1] / delta[0]) 197 | 198 | def std(self, name): 199 | r"""Returns the standard deviation of the scalars that were 200 | accumulated for the given statistic between the last two calls to 201 | `update()`, or NaN if no scalars were collected. 202 | """ 203 | delta = self._get_delta(name) 204 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): 205 | return float('nan') 206 | if int(delta[0]) == 1: 207 | return float(0) 208 | mean = float(delta[1] / delta[0]) 209 | raw_var = float(delta[2] / delta[0]) 210 | return np.sqrt(max(raw_var - np.square(mean), 0)) 211 | 212 | def as_dict(self): 213 | r"""Returns the averages accumulated between the last two calls to 214 | `update()` as an `dnnlib.EasyDict`. The contents are as follows: 215 | 216 | dnnlib.EasyDict( 217 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), 218 | ... 219 | ) 220 | """ 221 | stats = dnnlib.EasyDict() 222 | for name in self.names(): 223 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) 224 | return stats 225 | 226 | def __getitem__(self, name): 227 | r"""Convenience getter. 228 | `collector[name]` is a synonym for `collector.mean(name)`. 229 | """ 230 | return self.mean(name) 231 | 232 | #---------------------------------------------------------------------------- 233 | 234 | def _sync(names): 235 | r"""Synchronize the global cumulative counters across devices and 236 | processes. Called internally by `Collector.update()`. 237 | """ 238 | if len(names) == 0: 239 | return [] 240 | global _sync_called 241 | _sync_called = True 242 | 243 | # Collect deltas within current rank. 244 | deltas = [] 245 | device = _sync_device if _sync_device is not None else torch.device('cpu') 246 | for name in names: 247 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) 248 | for counter in _counters[name].values(): 249 | delta.add_(counter.to(device)) 250 | counter.copy_(torch.zeros_like(counter)) 251 | deltas.append(delta) 252 | deltas = torch.stack(deltas) 253 | 254 | # Sum deltas across ranks. 
255 | if _sync_device is not None: 256 | torch.distributed.all_reduce(deltas) 257 | 258 | # Update cumulative values. 259 | deltas = deltas.cpu() 260 | for idx, name in enumerate(names): 261 | if name not in _cumulative: 262 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 263 | _cumulative[name].add_(deltas[idx]) 264 | 265 | # Return name-value pairs. 266 | return [(name, _cumulative[name]) for name in names] 267 | 268 | #---------------------------------------------------------------------------- 269 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2021 NVIDIA Corporation. All rights reserved. 3 | Licensed under The MIT License (MIT) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | """ -------------------------------------------------------------------------------- /utils/alignment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import PIL 3 | import PIL.Image 4 | import scipy 5 | import scipy.ndimage 6 | import dlib 7 | 8 | 9 | def get_landmark(filepath, predictor): 10 | """get landmark with dlib 11 | :return: np.array shape=(68, 2) 12 | """ 13 | detector = dlib.get_frontal_face_detector() 14 | 15 | img = dlib.load_rgb_image(filepath) 16 | dets = detector(img, 1) 17 | 18 | for k, d in enumerate(dets): 19 | shape = predictor(img, d) 20 | 21 | t = list(shape.parts()) 22 | a = [] 23 | for tt in t: 24 | a.append([tt.x, tt.y]) 25 | lm = np.array(a) 26 | return lm 27 | 28 | 29 | def align_face(filepath, predictor): 30 | """ 31 | :param filepath: str 32 | :return: PIL Image 33 | """ 34 | 35 | lm = get_landmark(filepath, predictor) 36 | 37 | lm_chin = lm[0: 17] # left-right 38 | lm_eyebrow_left = lm[17: 22] # left-right 39 | lm_eyebrow_right = lm[22: 27] # left-right 40 | lm_nose = lm[27: 31] # top-down 41 | lm_nostrils = lm[31: 36] # top-down 42 | lm_eye_left = lm[36: 42] # left-clockwise 43 | lm_eye_right = lm[42: 48] # left-clockwise 44 | lm_mouth_outer = lm[48: 60] # left-clockwise 45 | lm_mouth_inner = lm[60: 68] # left-clockwise 46 | 47 | # Calculate auxiliary vectors. 
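    # (Mean eye positions and the outer mouth corners computed below define the
    # oriented crop rectangle chosen in the next step.)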
48 | eye_left = np.mean(lm_eye_left, axis=0) 49 | eye_right = np.mean(lm_eye_right, axis=0) 50 | eye_avg = (eye_left + eye_right) * 0.5 51 | eye_to_eye = eye_right - eye_left 52 | mouth_left = lm_mouth_outer[0] 53 | mouth_right = lm_mouth_outer[6] 54 | mouth_avg = (mouth_left + mouth_right) * 0.5 55 | eye_to_mouth = mouth_avg - eye_avg 56 | 57 | # Choose oriented crop rectangle. 58 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] 59 | x /= np.hypot(*x) 60 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) 61 | y = np.flipud(x) * [-1, 1] 62 | c = eye_avg + eye_to_mouth * 0.1 63 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) 64 | qsize = np.hypot(*x) * 2 65 | 66 | # read image 67 | img = PIL.Image.open(filepath) 68 | 69 | output_size = 256 70 | transform_size = 256 71 | enable_padding = True 72 | 73 | # Shrink. 74 | shrink = int(np.floor(qsize / output_size * 0.5)) 75 | if shrink > 1: 76 | rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) 77 | img = img.resize(rsize, PIL.Image.ANTIALIAS) 78 | quad /= shrink 79 | qsize /= shrink 80 | 81 | # Crop. 82 | border = max(int(np.rint(qsize * 0.1)), 3) 83 | crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), 84 | int(np.ceil(max(quad[:, 1])))) 85 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), 86 | min(crop[3] + border, img.size[1])) 87 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: 88 | img = img.crop(crop) 89 | quad -= crop[0:2] 90 | 91 | # Pad. 92 | pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), 93 | int(np.ceil(max(quad[:, 1])))) 94 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), 95 | max(pad[3] - img.size[1] + border, 0)) 96 | if enable_padding and max(pad) > border - 4: 97 | pad = np.maximum(pad, int(np.rint(qsize * 0.3))) 98 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') 99 | h, w, _ = img.shape 100 | y, x, _ = np.ogrid[:h, :w, :1] 101 | mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), 102 | 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3])) 103 | blur = qsize * 0.02 104 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) 105 | img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) 106 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') 107 | quad += pad[:2] 108 | 109 | # Transform. 110 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR) 111 | if output_size < transform_size: 112 | img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) 113 | 114 | # Return aligned image. 
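    # (With output_size == transform_size == 256 above, the final resize branch is
    # a no-op and the result is a 256x256 PIL.Image.)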
115 | return img -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | # Log images 6 | def log_input_image(x, opts): 7 | return tensor2im(x) 8 | 9 | 10 | def tensor2im(var): 11 | # var shape: (3, H, W) 12 | var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy() 13 | var = ((var + 1) / 2) 14 | var[var < 0] = 0 15 | var[var > 1] = 1 16 | var = var * 255 17 | return Image.fromarray(var.astype('uint8')) 18 | 19 | 20 | def vis_faces(log_hooks): 21 | display_count = len(log_hooks) 22 | fig = plt.figure(figsize=(8, 4 * display_count)) 23 | gs = fig.add_gridspec(display_count, 3) 24 | for i in range(display_count): 25 | hooks_dict = log_hooks[i] 26 | fig.add_subplot(gs[i, 0]) 27 | if 'diff_input' in hooks_dict: 28 | vis_faces_with_id(hooks_dict, fig, gs, i) 29 | else: 30 | vis_faces_no_id(hooks_dict, fig, gs, i) 31 | plt.tight_layout() 32 | return fig 33 | 34 | 35 | def vis_faces_with_id(hooks_dict, fig, gs, i): 36 | plt.imshow(hooks_dict['input_face']) 37 | plt.title('Input\nOut Sim={:.2f}'.format(float(hooks_dict['diff_input']))) 38 | fig.add_subplot(gs[i, 1]) 39 | plt.imshow(hooks_dict['target_face']) 40 | plt.title('Target\nIn={:.2f}, Out={:.2f}'.format(float(hooks_dict['diff_views']), 41 | float(hooks_dict['diff_target']))) 42 | fig.add_subplot(gs[i, 2]) 43 | plt.imshow(hooks_dict['output_face']) 44 | plt.title('Output\n Target Sim={:.2f}'.format(float(hooks_dict['diff_target']))) 45 | 46 | 47 | def vis_faces_no_id(hooks_dict, fig, gs, i): 48 | plt.imshow(hooks_dict['input_face'], cmap="gray") 49 | plt.title('Input') 50 | fig.add_subplot(gs[i, 1]) 51 | plt.imshow(hooks_dict['target_face']) 52 | plt.title('Target') 53 | fig.add_subplot(gs[i, 2]) 54 | plt.imshow(hooks_dict['output_face']) 55 | plt.title('Output') -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adopted from pix2pixHD: 3 | https://github.com/NVIDIA/pix2pixHD/blob/master/data/image_folder.py 4 | """ 5 | import os 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | class MultiResolutiontrainData(Dataset): 10 | 11 | def __init__(self, X_data, y_data): 12 | """ 13 | x_data : list(features) 14 | y_data : sketch 15 | """ 16 | 17 | X_data_ = [] 18 | for x_data in X_data: 19 | tmp = [torch.from_numpy(x)[0] for x in x_data] 20 | X_data_.append(tmp) 21 | self.X_data = X_data_ # List( List (Torch.tensor)) 22 | self.y_data = y_data # Torch.tensor: shape (h,w) 23 | 24 | #print("[dataloader] y_data shape: ", self.y_data[0].shape) 25 | 26 | assert len(self.X_data) == len(self.y_data), "dataset len does not match" 27 | 28 | def __getitem__(self, index): 29 | return self.X_data[index], self.y_data[index] 30 | 31 | def __len__(self): 32 | return len(self.X_data) 33 | 34 | IMG_EXTENSIONS = [ 35 | '.jpg', '.JPG', '.jpeg', '.JPEG', 36 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' 37 | ] 38 | 39 | 40 | def is_image_file(filename): 41 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 42 | 43 | 44 | def make_dataset(dir): 45 | images = [] 46 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 47 | for root, _, fnames in sorted(os.walk(dir)): 48 | for fname in fnames: 49 | if is_image_file(fname): 50 | path = os.path.join(root, fname) 
51 | images.append(path) 52 | return images -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2021 NVIDIA Corporation. All rights reserved. 3 | Licensed under The MIT License (MIT) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | """ 22 | 23 | import pickle 24 | 25 | import torch 26 | from torch import distributed as dist 27 | 28 | 29 | def get_rank(): 30 | if not dist.is_available(): 31 | return 0 32 | 33 | if not dist.is_initialized(): 34 | return 0 35 | 36 | return dist.get_rank() 37 | 38 | 39 | def synchronize(): 40 | if not dist.is_available(): 41 | return 42 | 43 | if not dist.is_initialized(): 44 | return 45 | 46 | world_size = dist.get_world_size() 47 | 48 | if world_size == 1: 49 | return 50 | 51 | dist.barrier() 52 | 53 | 54 | def get_world_size(): 55 | if not dist.is_available(): 56 | return 1 57 | 58 | if not dist.is_initialized(): 59 | return 1 60 | 61 | return dist.get_world_size() 62 | 63 | 64 | def reduce_sum(tensor): 65 | if not dist.is_available(): 66 | return tensor 67 | 68 | if not dist.is_initialized(): 69 | return tensor 70 | 71 | tensor = tensor.clone() 72 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 73 | 74 | return tensor 75 | 76 | 77 | def gather_grad(params): 78 | world_size = get_world_size() 79 | 80 | if world_size == 1: 81 | return 82 | 83 | for param in params: 84 | if param.grad is not None: 85 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 86 | param.grad.data.div_(world_size) 87 | 88 | 89 | def all_gather(data): 90 | world_size = get_world_size() 91 | 92 | if world_size == 1: 93 | return [data] 94 | 95 | buffer = pickle.dumps(data) 96 | storage = torch.ByteStorage.from_buffer(buffer) 97 | tensor = torch.ByteTensor(storage).to('cuda') 98 | 99 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 100 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 101 | dist.all_gather(size_list, local_size) 102 | size_list = [int(size.item()) for size in size_list] 103 | max_size = max(size_list) 104 | 105 | tensor_list = [] 106 | for _ in size_list: 107 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 108 | 109 | if local_size != max_size: 110 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 111 | tensor = torch.cat((tensor, padding), 0) 112 | 113 | dist.all_gather(tensor_list, tensor) 114 | 
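    # Unpickle each gathered buffer, trimming the padding so every rank's payload
    # is restored to its original size.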
115 | data_list = [] 116 | 117 | for size, tensor in zip(size_list, tensor_list): 118 | buffer = tensor.cpu().numpy().tobytes()[:size] 119 | data_list.append(pickle.loads(buffer)) 120 | 121 | return data_list 122 | 123 | 124 | def reduce_loss_dict(loss_dict): 125 | world_size = get_world_size() 126 | 127 | if world_size < 2: 128 | return loss_dict 129 | 130 | with torch.no_grad(): 131 | keys = [] 132 | losses = [] 133 | 134 | for k in sorted(loss_dict.keys()): 135 | keys.append(k) 136 | losses.append(loss_dict[k]) 137 | 138 | losses = torch.stack(losses, 0) 139 | dist.reduce(losses, dst=0) 140 | 141 | if dist.get_rank() == 0: 142 | losses /= world_size 143 | 144 | reduced_losses = {k: v for k, v in zip(keys, losses)} 145 | 146 | return reduced_losses 147 | -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from models.psp import pSp 4 | from models.encoders.psp_encoders import Encoder4Editing 5 | 6 | 7 | def setup_model(checkpoint_path, device='cuda'): 8 | ckpt = torch.load(checkpoint_path, map_location='cpu') 9 | opts = ckpt['opts'] 10 | 11 | opts['checkpoint_path'] = checkpoint_path 12 | opts['device'] = device 13 | opts = argparse.Namespace(**opts) 14 | 15 | net = pSp(opts) 16 | net.eval() 17 | net = net.to(device) 18 | return net, opts 19 | 20 | 21 | def load_e4e_standalone(checkpoint_path, device='cuda'): 22 | ckpt = torch.load(checkpoint_path, map_location='cpu') 23 | opts = argparse.Namespace(**ckpt['opts']) 24 | e4e = Encoder4Editing(50, 'ir_se', opts) 25 | e4e_dict = {k.replace('encoder.', ''): v for k, v in ckpt['state_dict'].items() if k.startswith('encoder.')} 26 | e4e.load_state_dict(e4e_dict,strict = False) 27 | e4e.eval() 28 | e4e = e4e.to(device) 29 | latent_avg = ckpt['latent_avg'].to(device) 30 | 31 | def add_latent_avg(model, inputs, outputs): 32 | return outputs + latent_avg.repeat(outputs.shape[0], 1, 1) 33 | 34 | e4e.register_forward_hook(add_latent_avg) 35 | return e4e -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2021 NVIDIA Corporation. All rights reserved. 3 | Licensed under The MIT License (MIT) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
21 | """ 22 | 23 | import torch 24 | from PIL import Image 25 | import numpy as np 26 | from torch import nn 27 | import math 28 | from torch.backends import cudnn 29 | import random 30 | 31 | class Interpolate(nn.Module): 32 | def __init__(self, size, mode): 33 | super(Interpolate, self).__init__() 34 | self.interp = nn.functional.interpolate 35 | self.size = size 36 | self.mode = mode 37 | 38 | def forward(self, x): 39 | x = self.interp(x, size=self.size, mode=self.mode, align_corners=False) 40 | return x 41 | 42 | 43 | 44 | def multi_acc(y_pred, y_test): 45 | y_pred_softmax = torch.log_softmax(y_pred, dim=1) 46 | _, y_pred_tags = torch.max(y_pred_softmax, dim=1) 47 | 48 | correct_pred = (y_pred_tags == y_test).float() 49 | acc = correct_pred.sum() / len(correct_pred) 50 | 51 | acc = acc * 100 52 | 53 | return acc 54 | 55 | 56 | def oht_to_scalar(y_pred): 57 | y_pred_softmax = torch.log_softmax(y_pred, dim=1) 58 | _, y_pred_tags = torch.max(y_pred_softmax, dim=1) 59 | 60 | return y_pred_tags 61 | 62 | def oht_to_scalar_binary(y_pred): 63 | y_pred_tags=[] 64 | for i in range(len(y_pred)): 65 | for pixel in y_pred[i]: 66 | if(pixel>=0.5): 67 | y_pred_tags.append(1) 68 | else: 69 | y_pred_tags.append(0) 70 | y_pred_tags=torch.FloatTensor(y_pred_tags) 71 | return y_pred_tags 72 | 73 | def oht_to_scalar_regression(y_pred): 74 | y_pred_tags=(y_pred*255).type(torch.int) 75 | return y_pred_tags 76 | 77 | def latent_to_image(g_all, latents, stylegan_version, use_style_latents=False, 78 | style_latents=None, process_out=True, return_stylegan_latent=False, dim=512): 79 | '''Given a input latent code, generate corresponding image and concatenated feature maps''' 80 | 81 | device = torch.device('cuda') 82 | if use_style_latents: 83 | style_latents=latents 84 | else : 85 | if stylegan_version==1 : 86 | style_latents = g_all.truncation(g_all.g_mapping(latents)) 87 | style_latents = style_latents.clone() 88 | 89 | elif stylegan_version==2 : 90 | label = torch.zeros([1, 0 ],device=device) 91 | style_latents=g_all.g_mapping(latents,label,truncation_psi=0.7) 92 | 93 | if return_stylegan_latent: 94 | return style_latents 95 | 96 | 97 | img_list, affine_layers = g_all.g_synthesis(style_latents) 98 | 99 | 100 | number_feautre = 0 101 | 102 | affine_layers_upsamples=[] 103 | for item in affine_layers: 104 | the_item = item.detach().cpu().numpy() 105 | affine_layers_upsamples.append(the_item) 106 | 107 | if process_out: 108 | img_list = img_list.cpu().detach().numpy() 109 | img_list = process_image(img_list) 110 | img_list = np.transpose(img_list, (0, 2, 3, 1)).astype(np.uint8) 111 | 112 | return img_list, affine_layers_upsamples 113 | 114 | def in_size(value,imsize=1024): 115 | if value>imsize-129: 116 | value =imsize-129 117 | if value<0: 118 | value=0 119 | return value 120 | 121 | def process_image(images): 122 | drange = [-1, 1] 123 | scale = 255 / (drange[1] - drange[0]) 124 | images = images * scale + (0.5 - drange[0] * scale) 125 | 126 | images = images.astype(int) 127 | images[images > 255] = 255 128 | images[images < 0] = 0 129 | 130 | return images.astype(int) 131 | 132 | def colorize_mask(mask, palette): 133 | # mask: numpy array of the mask 134 | 135 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 136 | new_mask.putpalette(palette) 137 | return np.array(new_mask.convert('RGB')) 138 | 139 | 140 | def get_label_stas(data_loader): 141 | count_dict = {} 142 | for i in range(data_loader.__len__()): 143 | x, y = data_loader.__getitem__(i) 144 | if int(y.item()) not in count_dict: 145 | 
count_dict[int(y.item())] = 1 146 | else: 147 | count_dict[int(y.item())] += 1 148 | 149 | return count_dict 150 | 151 | def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): 152 | lr_ramp = min(1, (1 - t) / rampdown) 153 | lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) 154 | lr_ramp = lr_ramp * min(1, t / rampup) 155 | 156 | return initial_lr * lr_ramp 157 | 158 | def set_seed(seed): 159 | cudnn.benchmark = False # if benchmark=True, deterministic will be False 160 | cudnn.deterministic = True 161 | np.random.seed(seed) 162 | random.seed(seed) 163 | torch.manual_seed(seed) 164 | torch.cuda.manual_seed(seed) 165 | torch.cuda.manual_seed_all(seed) --------------------------------------------------------------------------------
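A quick illustration of the learning-rate schedule implemented by `get_lr()` in
`utils/utils.py`: a linear warm-up over the first `rampup` fraction of training
followed by a cosine ramp-down over the last `rampdown` fraction. A minimal
sketch using the default arguments (the 0.1 base rate and the sampled values of
`t` are illustrative, not taken from the repo):

    from utils.utils import get_lr

    initial_lr = 0.1
    for t in [0.0, 0.05, 0.5, 0.75, 0.9, 1.0]:
        # t is the normalized training progress in [0, 1]
        print(f"t={t:.2f}  lr={get_lr(t, initial_lr):.4f}")

    # With rampup=0.05 and rampdown=0.25 (the defaults), the rate rises linearly
    # to initial_lr by t=0.05, stays flat until t=0.75, then decays to 0 along a
    # half-cosine.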