├── .DS_Store ├── data ├── .DS_Store ├── CelebAMask-HQ-Sample.zip └── hed_edge_256-Sample.zip ├── doc ├── .DS_Store └── coverpage.jpg ├── .gitattributes ├── op ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── fused_act.cpython-37.pyc ├── fused_bias_act.cpp ├── upfirdn2d.cpp ├── fused_bias_act_kernel.cu ├── fused_act.py ├── upfirdn2d.py ├── conv2d_gradfix.py └── upfirdn2d_kernel.cu ├── LICENSE-sbarratt ├── LICENSE ├── LICENSE-rosinality ├── LICENSE-eriklindernoren ├── LICENSE-NVIDIA ├── README.md ├── dataset.py ├── non_leaking.py ├── train.py └── model.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yan98/S2FGAN/HEAD/.DS_Store -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yan98/S2FGAN/HEAD/data/.DS_Store -------------------------------------------------------------------------------- /doc/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yan98/S2FGAN/HEAD/doc/.DS_Store -------------------------------------------------------------------------------- /doc/coverpage.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yan98/S2FGAN/HEAD/doc/coverpage.jpg -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /data/CelebAMask-HQ-Sample.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yan98/S2FGAN/HEAD/data/CelebAMask-HQ-Sample.zip -------------------------------------------------------------------------------- /data/hed_edge_256-Sample.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yan98/S2FGAN/HEAD/data/hed_edge_256-Sample.zip -------------------------------------------------------------------------------- /op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /op/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yan98/S2FGAN/HEAD/op/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /op/__pycache__/fused_act.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yan98/S2FGAN/HEAD/op/__pycache__/fused_act.cpython-37.pyc -------------------------------------------------------------------------------- /op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define 
CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /LICENSE-sbarratt: -------------------------------------------------------------------------------- 1 | Copyright 2017 Shane T. Barratt 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Yan Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LICENSE-rosinality: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /LICENSE-eriklindernoren: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Erik Linder-Norén 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 
1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, bias, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output, empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | if bias: 39 | grad_bias = grad_input.sum(dim).detach() 40 | 41 | else: 42 | grad_bias = empty 43 | 44 | return grad_input, grad_bias 45 | 46 | @staticmethod 47 | def backward(ctx, gradgrad_input, gradgrad_bias): 48 | out, = ctx.saved_tensors 49 | gradgrad_out = fused.fused_bias_act( 50 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 51 | ) 52 | 53 | return gradgrad_out, None, None, None, None 54 | 55 | 56 | class FusedLeakyReLUFunction(Function): 57 | @staticmethod 58 | def forward(ctx, input, bias, negative_slope, scale): 59 | empty = input.new_empty(0) 60 | 61 | ctx.bias = bias is not None 62 | 63 | if bias is None: 64 | bias = empty 65 | 66 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 67 | ctx.save_for_backward(out) 68 | ctx.negative_slope = negative_slope 69 | ctx.scale = scale 70 | 71 | return out 72 | 73 | @staticmethod 74 | def backward(ctx, grad_output): 75 | out, = ctx.saved_tensors 76 | 77 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 78 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale 79 | ) 80 | 81 | if not ctx.bias: 82 | grad_bias = None 83 | 84 | return grad_input, grad_bias, None, None 85 | 86 | 87 | class FusedLeakyReLU(nn.Module): 88 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): 89 | super().__init__() 90 | 91 | if bias: 92 | self.bias = nn.Parameter(torch.zeros(channel)) 93 | 94 | else: 95 | self.bias = None 96 | 97 | self.negative_slope = negative_slope 98 | self.scale = scale 99 | 100 | def forward(self, input): 101 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 102 | 103 | 104 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, 
scale=2 ** 0.5): 105 | if input.device.type == "cpu": 106 | if bias is not None: 107 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 108 | return ( 109 | F.leaky_relu( 110 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 111 | ) 112 | * scale 113 | ) 114 | 115 | else: 116 | return F.leaky_relu(input, negative_slope=0.2) * scale 117 | 118 | else: 119 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 120 | -------------------------------------------------------------------------------- /LICENSE-NVIDIA: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | 3 | 4 | Nvidia Source Code License-NC 5 | 6 | ======================================================================= 7 | 8 | 1. Definitions 9 | 10 | "Licensor" means any person or entity that distributes its Work. 11 | 12 | "Software" means the original work of authorship made available under 13 | this License. 14 | 15 | "Work" means the Software and any additions to or derivative works of 16 | the Software that are made available under this License. 17 | 18 | "Nvidia Processors" means any central processing unit (CPU), graphics 19 | processing unit (GPU), field-programmable gate array (FPGA), 20 | application-specific integrated circuit (ASIC) or any combination 21 | thereof designed, made, sold, or provided by Nvidia or its affiliates. 22 | 23 | The terms "reproduce," "reproduction," "derivative works," and 24 | "distribution" have the meaning as provided under U.S. copyright law; 25 | provided, however, that for the purposes of this License, derivative 26 | works shall not include works that remain separable from, or merely 27 | link (or bind by name) to the interfaces of, the Work. 28 | 29 | Works, including the Software, are "made available" under this License 30 | by including in or with the Work either (a) a copyright notice 31 | referencing the applicability of this License to the Work, or (b) a 32 | copy of this License. 33 | 34 | 2. License Grants 35 | 36 | 2.1 Copyright Grant. Subject to the terms and conditions of this 37 | License, each Licensor grants to you a perpetual, worldwide, 38 | non-exclusive, royalty-free, copyright license to reproduce, 39 | prepare derivative works of, publicly display, publicly perform, 40 | sublicense and distribute its Work and any resulting derivative 41 | works in any form. 42 | 43 | 3. Limitations 44 | 45 | 3.1 Redistribution. You may reproduce or distribute the Work only 46 | if (a) you do so under this License, (b) you include a complete 47 | copy of this License with your distribution, and (c) you retain 48 | without modification any copyright, patent, trademark, or 49 | attribution notices that are present in the Work. 50 | 51 | 3.2 Derivative Works. You may specify that additional or different 52 | terms apply to the use, reproduction, and distribution of your 53 | derivative works of the Work ("Your Terms") only if (a) Your Terms 54 | provide that the use limitation in Section 3.3 applies to your 55 | derivative works, and (b) you identify the specific derivative 56 | works that are subject to Your Terms. Notwithstanding Your Terms, 57 | this License (including the redistribution requirements in Section 58 | 3.1) will continue to apply to the Work itself. 59 | 60 | 3.3 Use Limitation. The Work and any derivative works thereof only 61 | may be used or intended for use non-commercially. 
The Work or 62 | derivative works thereof may be used or intended for use by Nvidia 63 | or its affiliates commercially or non-commercially. As used herein, 64 | "non-commercially" means for research or evaluation purposes only. 65 | 66 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 67 | against any Licensor (including any claim, cross-claim or 68 | counterclaim in a lawsuit) to enforce any patents that you allege 69 | are infringed by any Work, then your rights under this License from 70 | such Licensor (including the grants in Sections 2.1 and 2.2) will 71 | terminate immediately. 72 | 73 | 3.5 Trademarks. This License does not grant any rights to use any 74 | Licensor's or its affiliates' names, logos, or trademarks, except 75 | as necessary to reproduce the notices described in this License. 76 | 77 | 3.6 Termination. If you violate any term of this License, then your 78 | rights under this License (including the grants in Sections 2.1 and 79 | 2.2) will terminate immediately. 80 | 81 | 4. Disclaimer of Warranty. 82 | 83 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 84 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 85 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 86 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 87 | THIS LICENSE. 88 | 89 | 5. Limitation of Liability. 90 | 91 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 92 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 93 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 94 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 95 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 96 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 97 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 98 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 99 | THE POSSIBILITY OF SUCH DAMAGES. 100 | 101 | ======================================================================= -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # S2FGAN-pytorch Implementation 2 | ![](doc/coverpage.jpg) 3 | 4 | ## Dependency 5 | * python 3.7.4 6 | * numpy 1.18.1 7 | * Pillow 7.0.0 8 | * opencv-python 4.2.0.32 9 | * torch 1.5.1 10 | * torchvision 0.5.0 11 | * albumentations 0.4.6 12 | * cudnn 7.6.5 13 | * CUDA 10.1 14 | 15 | At least one GPU is needed. Please install the libraries with CUDA and C++ support on a Linux system. 16 | 17 | ## Dataset 18 | * Obtain the [CelebAMask-HQ dataset](https://github.com/switchablenorms/CelebAMask-HQ). 19 | * Download the pretrained HED model from https://github.com/s9xie/hed. 20 | * Extract sketches (HED edges) using the scripts in https://www.pyimagesearch.com/2019/03/04/holistically-nested-edge-detection-with-opencv-and-deep-learning/; a condensed extraction example is given after the Notice section below. 21 | * Post-process the sketches using the method of Isola et al. (Image-to-Image Translation with Conditional Adversarial Networks). Note: MATLAB is required. Use `PostprocessHED.m` from their GitHub repository: https://github.com/phillipi/pix2pix/tree/master/scripts/edges 22 | * Zip the post-processed sketches. 23 | 24 | ## Notice 25 | * S2FGAN has only been tested at 256x256 resolution. 26 | * The current implementation is refactored from our original training code. If you find any implementation error, please do not hesitate to contact us.
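
For reference, the HED edge maps can be extracted with OpenCV's DNN module, following the tutorial linked in the Dataset section. The snippet below is only a condensed sketch of that tutorial, not the exact preprocessing script used for S2FGAN; the file names `deploy.prototxt`, `hed_pretrained_bsds.caffemodel`, `face.jpg`, and `edge.jpg` are placeholders.

```python
import cv2

# Placeholder paths: the pretrained HED release (https://github.com/s9xie/hed)
# ships a Caffe prototxt and a weight file.
PROTO, WEIGHTS = "deploy.prototxt", "hed_pretrained_bsds.caffemodel"


class CropLayer(object):
    # HED's Caffe graph contains Crop layers that OpenCV does not implement,
    # so a custom layer that center-crops the first input to the spatial size
    # of the second one is registered (as in the linked tutorial).
    def __init__(self, params, blobs):
        self.xstart = self.ystart = self.xend = self.yend = 0

    def getMemoryShapes(self, inputs):
        input_shape, target_shape = inputs[0], inputs[1]
        batch, channels = input_shape[0], input_shape[1]
        height, width = target_shape[2], target_shape[3]
        self.ystart = (input_shape[2] - target_shape[2]) // 2
        self.xstart = (input_shape[3] - target_shape[3]) // 2
        self.yend = self.ystart + height
        self.xend = self.xstart + width
        return [[batch, channels, height, width]]

    def forward(self, inputs):
        return [inputs[0][:, :, self.ystart:self.yend, self.xstart:self.xend]]


cv2.dnn_registerLayer("Crop", CropLayer)
net = cv2.dnn.readNetFromCaffe(PROTO, WEIGHTS)

image = cv2.imread("face.jpg")
image = cv2.resize(image, (256, 256))  # S2FGAN is tested at 256x256
blob = cv2.dnn.blobFromImage(
    image, scalefactor=1.0, size=(256, 256),
    mean=(104.00698793, 116.66876762, 122.67891434),
    swapRB=False, crop=False,
)
net.setInput(blob)
hed = net.forward()[0, 0]  # edge probabilities in [0, 1], shape (256, 256)
cv2.imwrite("edge.jpg", (255 * hed).astype("uint8"))
```

The resulting edge maps are then post-processed with `PostprocessHED.m` and zipped as described in the Dataset section.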
27 | 28 | ## Train S2FGAN 29 | 30 | * The validation images will be saved in the `sample` folder, the model checkpoints will be saved in `checkpoint`, and the training log will be written to `log.txt`. 31 | 32 | * For training, please run `train.py` while setting the parameters properly. 33 | 34 | ```bash 35 | python train.py --help 36 | 37 | --iter #total training iterations 38 | --batch #batch size 39 | --r1 #weight of the r1 regularization 40 | --d_reg_every #interval for applying r1 regularization to the discriminator 41 | --lr #learning rate 42 | --augment #apply discriminator augmentation 43 | --augment_p #probability of applying discriminator augmentation. 0 = use adaptive augmentation 44 | --ada_target #target augmentation probability for adaptive augmentation 45 | --ada_length #target duration to reach the augmentation probability for adaptive augmentation 46 | --ada_every #probability update interval of the adaptive augmentation 47 | --img_height #image height 48 | --img_width #image width 49 | --NumberOfImage #the number of images in the zip 50 | --imageZip #input image zip 51 | --hedEdgeZip #HED sketch zip 52 | --hedEdgePath #path of the sketches inside the zip, e.g. hed_edge_256 53 | --imagePath #path of the images inside the zip 54 | --TORCH_HOME #the directory storing pretrained PyTorch models; "None" loads the pretrained model from the default directory 55 | --label_path #attribute annotation text file of CelebAMask-HQ 56 | --selected_attrs #selected attributes for the CelebAMask-HQ dataset 57 | --ATMDTT #attributes to manipulate at testing time 58 | --model_type #0 - S2F-DIS, 1 - S2F-DEC 59 | ``` 60 | 61 | * Train on S2F-DIS 62 | 63 | ```bash 64 | python3 train.py --model_type 0 #Please set the data paths properly. 65 | ``` 66 | 67 | * Train on S2F-DEC 68 | 69 | ```bash 70 | python3 train.py --model_type 1 #Please set the data paths properly. 71 | ``` 72 | 73 | ## Code for Related Work 74 | * AttGAN: https://github.com/elvisyjlin/AttGAN-PyTorch. 75 | * STGAN: https://github.com/bluestyle97/STGAN-pytorch. 76 | * Pix2PixHD: https://github.com/NVIDIA/pix2pixHD 77 | * DFP: https://github.com/LiYuhangUSTC/Sketch2Face 78 | * DFD: https://github.com/IGLICT/DeepFaceDrawing-Jittor 79 | * DPS: https://github.com/VITA-Group/DeepPS 80 | 81 | ## Evaluation metrics 82 | * FID Score: https://github.com/mseitzer/pytorch-fid. 83 | * IS Score: https://github.com/sbarratt/inception-score-pytorch. 84 | * Evaluation Classifier: https://github.com/csmliu/STGAN. 85 | * Simulate Badly Drawn Sketches: https://github.com/VITA-Group/DeepPS/blob/master/src/roughSketchSyn.py 86 | 87 | ## Todo 88 | 89 | - [ ] Upload pretrained checkpoints 90 | - [ ] Upload testing script 91 | 92 | If you urgently need the checkpoints, please drop [me](mailto:yan.yang@anu.edu.au?subject=[GitHub]S2FGAN) an email. 93 | 94 | ## License 95 | The Equalized layer, Modulated layer, PixelNorm and CUDA kernels are from the official StyleGAN2; a minimal sketch of two of these layers is given below. For more details, please refer to the repository: https://github.com/NVlabs/stylegan2 96 | 97 | Thanks to Rosinality's StyleGAN2 PyTorch implementation. S2FGAN is built on this template: https://github.com/rosinality/stylegan2-pytorch.
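
For readers who do not want to open the StyleGAN2 sources, the two simplest of these building blocks look roughly as follows. This is only an illustrative sketch consistent with the reference implementations credited above, not necessarily the exact layers defined in this repository's `model.py`.

```python
import math
import torch
from torch import nn
from torch.nn import functional as F


class PixelNorm(nn.Module):
    # Normalizes each feature vector (per spatial location) to unit average
    # magnitude, as used in the StyleGAN mapping network.
    def forward(self, input):
        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)


class EqualLinear(nn.Module):
    # Equalized learning rate: weights are stored at unit variance and
    # rescaled by 1/sqrt(fan_in) at runtime, so every layer sees a similar
    # effective learning rate under Adam.
    def __init__(self, in_dim, out_dim, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
        self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
        self.scale = 1 / math.sqrt(in_dim)

    def forward(self, input):
        return F.linear(input, self.weight * self.scale, bias=self.bias)
```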
98 | 99 | The dataloader is based on eriklindernoren's repostiories: https://github.com/eriklindernoren/PyTorch-GAN 100 | 101 | The AttGAN can be find in https://github.com/elvisyjlin/AttGAN-PyTorch 102 | 103 | Data prefetcher is based on the implementation from NVIDIA Apex: https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py#L256 104 | 105 | The HED detector loading and Crop layer implementation is from Rosebrock: https://www.pyimagesearch.com/2019/03/04/holistically-nested-edge-detection-with-opencv-and-deep-learning/ 106 | 107 | ## Demo Video for Attribute Editing - Click to Play 108 | [![IMAGE ALT TEXT](http://img.youtube.com/vi/nd3Gq2lV_Do/0.jpg)](http://www.youtube.com/watch?v=nd3Gq2lV_Do "S2FGAN Demo") 109 | 110 | ## Citation 111 | If you find [S2FGAN](https://arxiv.org/abs/2011.14785) useful in your research work, please consider citing: 112 | ```bibtex 113 | @ARTICLE{s2fgan, 114 | author = {Yang, Yan and Hossain, Md Zakir and Gedeon, Tom and Rahman, Shafin}, 115 | year = {2020}, 116 | month = {11}, 117 | pages = {}, 118 | title = {S2FGAN: Semantically Aware Interactive Sketch-to-Face Translation} 119 | } 120 | ``` 121 | 122 | 123 | -------------------------------------------------------------------------------- /op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | from collections import abc 2 | import os 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | upfirdn2d_op = load( 12 | "upfirdn2d", 13 | sources=[ 14 | os.path.join(module_path, "upfirdn2d.cpp"), 15 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class UpFirDn2dBackward(Function): 21 | @staticmethod 22 | def forward( 23 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 24 | ): 25 | 26 | up_x, up_y = up 27 | down_x, down_y = down 28 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 29 | 30 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 31 | 32 | grad_input = upfirdn2d_op.upfirdn2d( 33 | grad_output, 34 | grad_kernel, 35 | down_x, 36 | down_y, 37 | up_x, 38 | up_y, 39 | g_pad_x0, 40 | g_pad_x1, 41 | g_pad_y0, 42 | g_pad_y1, 43 | ) 44 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 45 | 46 | ctx.save_for_backward(kernel) 47 | 48 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 49 | 50 | ctx.up_x = up_x 51 | ctx.up_y = up_y 52 | ctx.down_x = down_x 53 | ctx.down_y = down_y 54 | ctx.pad_x0 = pad_x0 55 | ctx.pad_x1 = pad_x1 56 | ctx.pad_y0 = pad_y0 57 | ctx.pad_y1 = pad_y1 58 | ctx.in_size = in_size 59 | ctx.out_size = out_size 60 | 61 | return grad_input 62 | 63 | @staticmethod 64 | def backward(ctx, gradgrad_input): 65 | kernel, = ctx.saved_tensors 66 | 67 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 68 | 69 | gradgrad_out = upfirdn2d_op.upfirdn2d( 70 | gradgrad_input, 71 | kernel, 72 | ctx.up_x, 73 | ctx.up_y, 74 | ctx.down_x, 75 | ctx.down_y, 76 | ctx.pad_x0, 77 | ctx.pad_x1, 78 | ctx.pad_y0, 79 | ctx.pad_y1, 80 | ) 81 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 82 | gradgrad_out = gradgrad_out.view( 83 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 84 | ) 85 | 86 | return gradgrad_out, None, None, None, None, None, None, None, None 87 | 88 | 89 | class UpFirDn2d(Function): 90 | 
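# Autograd wrapper around the compiled upfirdn2d CUDA op: forward performs upsample -> FIR filter -> downsample in a single pass, and backward re-applies the same op with the flipped kernel, swapped up/down factors and the gradient padding computed in forward (see UpFirDn2dBackward above).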
@staticmethod 91 | def forward(ctx, input, kernel, up, down, pad): 92 | up_x, up_y = up 93 | down_x, down_y = down 94 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 95 | 96 | kernel_h, kernel_w = kernel.shape 97 | batch, channel, in_h, in_w = input.shape 98 | ctx.in_size = input.shape 99 | 100 | input = input.reshape(-1, in_h, in_w, 1) 101 | 102 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 103 | 104 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 105 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 106 | ctx.out_size = (out_h, out_w) 107 | 108 | ctx.up = (up_x, up_y) 109 | ctx.down = (down_x, down_y) 110 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 111 | 112 | g_pad_x0 = kernel_w - pad_x0 - 1 113 | g_pad_y0 = kernel_h - pad_y0 - 1 114 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 115 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 116 | 117 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 118 | 119 | out = upfirdn2d_op.upfirdn2d( 120 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 121 | ) 122 | # out = out.view(major, out_h, out_w, minor) 123 | out = out.view(-1, channel, out_h, out_w) 124 | 125 | return out 126 | 127 | @staticmethod 128 | def backward(ctx, grad_output): 129 | kernel, grad_kernel = ctx.saved_tensors 130 | 131 | grad_input = None 132 | 133 | if ctx.needs_input_grad[0]: 134 | grad_input = UpFirDn2dBackward.apply( 135 | grad_output, 136 | kernel, 137 | grad_kernel, 138 | ctx.up, 139 | ctx.down, 140 | ctx.pad, 141 | ctx.g_pad, 142 | ctx.in_size, 143 | ctx.out_size, 144 | ) 145 | 146 | return grad_input, None, None, None, None 147 | 148 | 149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 150 | if not isinstance(up, abc.Iterable): 151 | up = (up, up) 152 | 153 | if not isinstance(down, abc.Iterable): 154 | down = (down, down) 155 | 156 | if len(pad) == 2: 157 | pad = (pad[0], pad[1], pad[0], pad[1]) 158 | 159 | if input.device.type == "cpu": 160 | out = upfirdn2d_native(input, kernel, *up, *down, *pad) 161 | 162 | else: 163 | out = UpFirDn2d.apply(input, kernel, up, down, pad) 164 | 165 | return out 166 | 167 | 168 | def upfirdn2d_native( 169 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 170 | ): 171 | _, channel, in_h, in_w = input.shape 172 | input = input.reshape(-1, in_h, in_w, 1) 173 | 174 | _, in_h, in_w, minor = input.shape 175 | kernel_h, kernel_w = kernel.shape 176 | 177 | out = input.view(-1, in_h, 1, in_w, 1, minor) 178 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 179 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 180 | 181 | out = F.pad( 182 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 183 | ) 184 | out = out[ 185 | :, 186 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 187 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 188 | :, 189 | ] 190 | 191 | out = out.permute(0, 3, 1, 2) 192 | out = out.reshape( 193 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 194 | ) 195 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 196 | out = F.conv2d(out, w) 197 | out = out.reshape( 198 | -1, 199 | minor, 200 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 201 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 202 | ) 203 | out = out.permute(0, 2, 3, 1) 204 | out = out[:, ::down_y, ::down_x, :] 205 | 206 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 207 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w 
+ down_x) // down_x 208 | 209 | return out.view(-1, channel, out_h, out_w) 210 | -------------------------------------------------------------------------------- /op/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import warnings 3 | 4 | import torch 5 | from torch import autograd 6 | from torch.nn import functional as F 7 | 8 | enabled = True 9 | weight_gradients_disabled = False 10 | 11 | 12 | @contextlib.contextmanager 13 | def no_weight_gradients(): 14 | global weight_gradients_disabled 15 | 16 | old = weight_gradients_disabled 17 | weight_gradients_disabled = True 18 | yield 19 | weight_gradients_disabled = old 20 | 21 | 22 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 23 | if could_use_op(input): 24 | return conv2d_gradfix( 25 | transpose=False, 26 | weight_shape=weight.shape, 27 | stride=stride, 28 | padding=padding, 29 | output_padding=0, 30 | dilation=dilation, 31 | groups=groups, 32 | ).apply(input, weight, bias) 33 | 34 | return F.conv2d( 35 | input=input, 36 | weight=weight, 37 | bias=bias, 38 | stride=stride, 39 | padding=padding, 40 | dilation=dilation, 41 | groups=groups, 42 | ) 43 | 44 | 45 | def conv_transpose2d( 46 | input, 47 | weight, 48 | bias=None, 49 | stride=1, 50 | padding=0, 51 | output_padding=0, 52 | groups=1, 53 | dilation=1, 54 | ): 55 | if could_use_op(input): 56 | return conv2d_gradfix( 57 | transpose=True, 58 | weight_shape=weight.shape, 59 | stride=stride, 60 | padding=padding, 61 | output_padding=output_padding, 62 | groups=groups, 63 | dilation=dilation, 64 | ).apply(input, weight, bias) 65 | 66 | return F.conv_transpose2d( 67 | input=input, 68 | weight=weight, 69 | bias=bias, 70 | stride=stride, 71 | padding=padding, 72 | output_padding=output_padding, 73 | dilation=dilation, 74 | groups=groups, 75 | ) 76 | 77 | 78 | def could_use_op(input): 79 | if (not enabled) or (not torch.backends.cudnn.enabled): 80 | return False 81 | 82 | if input.device.type != "cuda": 83 | return False 84 | 85 | if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]): 86 | return True 87 | 88 | warnings.warn( 89 | f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()." 
90 | ) 91 | 92 | return False 93 | 94 | 95 | def ensure_tuple(xs, ndim): 96 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 97 | 98 | return xs 99 | 100 | 101 | conv2d_gradfix_cache = dict() 102 | 103 | 104 | def conv2d_gradfix( 105 | transpose, weight_shape, stride, padding, output_padding, dilation, groups 106 | ): 107 | ndim = 2 108 | weight_shape = tuple(weight_shape) 109 | stride = ensure_tuple(stride, ndim) 110 | padding = ensure_tuple(padding, ndim) 111 | output_padding = ensure_tuple(output_padding, ndim) 112 | dilation = ensure_tuple(dilation, ndim) 113 | 114 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 115 | if key in conv2d_gradfix_cache: 116 | return conv2d_gradfix_cache[key] 117 | 118 | common_kwargs = dict( 119 | stride=stride, padding=padding, dilation=dilation, groups=groups 120 | ) 121 | 122 | def calc_output_padding(input_shape, output_shape): 123 | if transpose: 124 | return [0, 0] 125 | 126 | return [ 127 | input_shape[i + 2] 128 | - (output_shape[i + 2] - 1) * stride[i] 129 | - (1 - 2 * padding[i]) 130 | - dilation[i] * (weight_shape[i + 2] - 1) 131 | for i in range(ndim) 132 | ] 133 | 134 | class Conv2d(autograd.Function): 135 | @staticmethod 136 | def forward(ctx, input, weight, bias): 137 | if not transpose: 138 | out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 139 | 140 | else: 141 | out = F.conv_transpose2d( 142 | input=input, 143 | weight=weight, 144 | bias=bias, 145 | output_padding=output_padding, 146 | **common_kwargs, 147 | ) 148 | 149 | ctx.save_for_backward(input, weight) 150 | 151 | return out 152 | 153 | @staticmethod 154 | def backward(ctx, grad_output): 155 | input, weight = ctx.saved_tensors 156 | grad_input, grad_weight, grad_bias = None, None, None 157 | 158 | if ctx.needs_input_grad[0]: 159 | p = calc_output_padding( 160 | input_shape=input.shape, output_shape=grad_output.shape 161 | ) 162 | grad_input = conv2d_gradfix( 163 | transpose=(not transpose), 164 | weight_shape=weight_shape, 165 | output_padding=p, 166 | **common_kwargs, 167 | ).apply(grad_output, weight, None) 168 | 169 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 170 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 171 | 172 | if ctx.needs_input_grad[2]: 173 | grad_bias = grad_output.sum((0, 2, 3)) 174 | 175 | return grad_input, grad_weight, grad_bias 176 | 177 | class Conv2dGradWeight(autograd.Function): 178 | @staticmethod 179 | def forward(ctx, grad_output, input): 180 | op = torch._C._jit_get_operation( 181 | "aten::cudnn_convolution_backward_weight" 182 | if not transpose 183 | else "aten::cudnn_convolution_transpose_backward_weight" 184 | ) 185 | flags = [ 186 | torch.backends.cudnn.benchmark, 187 | torch.backends.cudnn.deterministic, 188 | torch.backends.cudnn.allow_tf32, 189 | ] 190 | grad_weight = op( 191 | weight_shape, 192 | grad_output, 193 | input, 194 | padding, 195 | stride, 196 | dilation, 197 | groups, 198 | *flags, 199 | ) 200 | ctx.save_for_backward(grad_output, input) 201 | 202 | return grad_weight 203 | 204 | @staticmethod 205 | def backward(ctx, grad_grad_weight): 206 | grad_output, input = ctx.saved_tensors 207 | grad_grad_output, grad_grad_input = None, None 208 | 209 | if ctx.needs_input_grad[0]: 210 | grad_grad_output = Conv2d.apply(input, grad_grad_weight, None) 211 | 212 | if ctx.needs_input_grad[1]: 213 | p = calc_output_padding( 214 | input_shape=input.shape, output_shape=grad_output.shape 215 | ) 216 | grad_grad_input = conv2d_gradfix( 217 | 
transpose=(not transpose), 218 | weight_shape=weight_shape, 219 | output_padding=p, 220 | **common_kwargs, 221 | ).apply(grad_output, grad_grad_weight, None) 222 | 223 | return grad_grad_output, grad_grad_input 224 | 225 | conv2d_gradfix_cache[key] = Conv2d 226 | 227 | return Conv2d 228 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | This module is the concrete implementation of pytorch dataset. 5 | The module structure is the following: 6 | - extract_zip is used to read the image from the zip into memory 7 | - read_image_from_zip parse the file to PIL image 8 | - CelebDataset is the wrapper of CelebAMASK-HQ dataset. 9 | """ 10 | 11 | import torch.utils.data as data 12 | from PIL import Image 13 | from io import BytesIO 14 | import albumentations as A 15 | import numpy as np 16 | import torch 17 | import zipfile 18 | 19 | def extract_zip(input_zip): 20 | ''' 21 | Parameters 22 | ---------- 23 | input_zip : zipfile, the zipfile need to be read in memory 24 | Returns 25 | ------- 26 | dict: a dictionary maps the path to ".jpg" or ".png" image in the zipfile 27 | ''' 28 | 29 | input_zip=zipfile.ZipFile(input_zip) 30 | return {name: input_zip.read(name) for name in input_zip.namelist() if name.endswith(".jpg") or name.endswith(".png")} 31 | 32 | def read_image_from_zip(file,path,height = None,width = None): 33 | """ 34 | Parameters 35 | ---------- 36 | file: zipfile, the zipfile need to be read 37 | path: str, the path to read in the file. 38 | height: int, the height of the image desired 39 | width: int, the width of the image desired. 40 | 41 | Returns 42 | ------- 43 | img: a PIL image with desired height and width 44 | """ 45 | 46 | img = Image.open(BytesIO(file[path])) 47 | 48 | if height != None and width != None: 49 | img = img.resize((height,width)) 50 | 51 | return img 52 | 53 | class CeleDataset(data.Dataset): 54 | 55 | ''' 56 | The pytorch dataset wrapper for the CelebAMASK-HQ dataset. 57 | 58 | ''' 59 | 60 | def __init__(self,params,train = True): 61 | 62 | """ 63 | Return, None 64 | Parameters 65 | ---------- 66 | params: A parser file which contains the parameters for the class. 67 | train: boolean, decide if the class is used for trainning. 68 | 69 | Returns 70 | ------- 71 | None 72 | """ 73 | 74 | global selected_attrs,label_path 75 | 76 | selected_attrs = params.selected_attrs 77 | label_path = params.label_path 78 | 79 | self.params = params 80 | self.image_zip = extract_zip(params.imageZip) 81 | self.hedZip = extract_zip(params.hedEdgeZip) 82 | self.indexToPath = self.generate_path(train) 83 | self.att = self.get_annotations() 84 | self.train = train 85 | self.aug = A.Compose({ 86 | A.RandomSizedCrop(min_max_height = (int(self.params.img_height * 0.8),self.params.img_height),height = self.params.img_height,width = self.params.img_width, p = 0.5), 87 | A.HorizontalFlip(p=0.5) 88 | }) 89 | 90 | def get_annotations(self): 91 | """ 92 | Return, A dict contains the attributes of interest. 93 | Parameters 94 | ---------- 95 | None 96 | 97 | Returns 98 | ------- 99 | annotations, dict, read the selected attributes, and store it in the annoations. 
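Keys are the CelebAMask-HQ file names with a ".png" extension; each value is a 0/1 list ordered according to selected_attrs.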
100 | """ 101 | 102 | annotations = {} 103 | lines = [line.rstrip() for line in open(label_path, "r")] 104 | self.label_names = lines[1].split() 105 | for _, line in enumerate(lines[2:]): 106 | filename, *values = line.split() 107 | labels = [] 108 | for attr in selected_attrs: 109 | idx = self.label_names.index(attr) 110 | labels.append((1 if (values[idx] == "1") else 0)) 111 | annotations[filename.replace(".jpg",".png")] = labels 112 | return annotations 113 | 114 | def generate_path(self,train): 115 | 116 | """ 117 | Return, A dict that mapps integers to files. 118 | Parameters 119 | ---------- 120 | train, bool, decide which files to read. Training and testing will lead reading diffirent files 121 | 122 | Returns 123 | ------- 124 | selected_index_ToPath, dict, the dictionary contains the mapping of integer and files 125 | """ 126 | 127 | indexToPath = dict() 128 | index = 0 129 | for file in range(self.params.NumOfImage): 130 | file = str(file) 131 | file += ".png" 132 | indexToPath[index] = [ 133 | self.params.imagePath + "/" + file.replace(".png",".jpg"), 134 | file 135 | ] 136 | index += 1 137 | 138 | selected_indexToPath = dict() 139 | new_index = 0 140 | for k, value in indexToPath.items(): 141 | 142 | if not train: 143 | if k % 20 == 0: 144 | selected_indexToPath[new_index] = value 145 | new_index+=1 146 | else: 147 | if k % 20 != 0: 148 | selected_indexToPath[new_index] = value 149 | new_index+=1 150 | 151 | return selected_indexToPath 152 | 153 | 154 | def __getitem__(self, index): 155 | 156 | """ 157 | Return, sketch,img,label 158 | Parameters 159 | ---------- 160 | index: int, the index of the file need to be read 161 | 162 | Returns 163 | ------- 164 | sketch : pytorch float tensor, input sketch 165 | img : pytorch float tensor, the ground truth image corresponds to sketch. 166 | labels : pytorch float tensor, the attributes of the img 167 | """ 168 | 169 | #get path for image and sketch 170 | 171 | image_path, sketch_path = self.indexToPath[index] 172 | 173 | #read image into numpy array 174 | img = read_image_from_zip(self.image_zip,image_path,self.params.img_height,self.params.img_width) 175 | img = np.array(img) 176 | 177 | #read sketch into numpy array 178 | hed_edge = read_image_from_zip(self.hedZip,self.params.hedEdgePath + "/" + sketch_path.replace(".png",".jpg"),self.params.img_height,self.params.img_width) 179 | sketch = np.array(hed_edge) 180 | 181 | #augment sketch and image if in the training mode 182 | if self.train: 183 | augmented = self.aug(image = img,mask = sketch) 184 | img = augmented['image'] 185 | sketch = augmented['mask'] 186 | 187 | img = torch.FloatTensor(img).permute(2,0,1) 188 | sketch = torch.FloatTensor(sketch).unsqueeze(2).permute(2,0,1) 189 | 190 | #read labels into pytorch float tensor. 
191 | 192 | label = self.att[sketch_path] 193 | label = torch.FloatTensor(np.array(label)) 194 | 195 | return sketch,img,label 196 | 197 | def __len__(self): 198 | 199 | """ 200 | Return, the number of ground truth images in the files 201 | Parameters 202 | ---------- 203 | None 204 | 205 | Returns 206 | ------- 207 | The number of ground truth images in the files 208 | """ 209 | 210 | return len(self.indexToPath) -------------------------------------------------------------------------------- /non_leaking.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | from op import upfirdn2d 7 | 8 | 9 | SYM6 = ( 10 | 0.015404109327027373, 11 | 0.0034907120842174702, 12 | -0.11799011114819057, 13 | -0.048311742585633, 14 | 0.4910559419267466, 15 | 0.787641141030194, 16 | 0.3379294217276218, 17 | -0.07263752278646252, 18 | -0.021060292512300564, 19 | 0.04472490177066578, 20 | 0.0017677118642428036, 21 | -0.007800708325034148, 22 | ) 23 | 24 | 25 | def translate_mat(t_x, t_y): 26 | batch = t_x.shape[0] 27 | 28 | mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1) 29 | translate = torch.stack((t_x, t_y), 1) 30 | mat[:, :2, 2] = translate 31 | 32 | return mat 33 | 34 | 35 | def rotate_mat(theta): 36 | batch = theta.shape[0] 37 | 38 | mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1) 39 | sin_t = torch.sin(theta) 40 | cos_t = torch.cos(theta) 41 | rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2) 42 | mat[:, :2, :2] = rot 43 | 44 | return mat 45 | 46 | 47 | def scale_mat(s_x, s_y): 48 | batch = s_x.shape[0] 49 | 50 | mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1) 51 | mat[:, 0, 0] = s_x 52 | mat[:, 1, 1] = s_y 53 | 54 | return mat 55 | 56 | 57 | def translate3d_mat(t_x, t_y, t_z): 58 | batch = t_x.shape[0] 59 | 60 | mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 61 | translate = torch.stack((t_x, t_y, t_z), 1) 62 | mat[:, :3, 3] = translate 63 | 64 | return mat 65 | 66 | 67 | def rotate3d_mat(axis, theta): 68 | batch = theta.shape[0] 69 | 70 | u_x, u_y, u_z = axis 71 | 72 | eye = torch.eye(3).unsqueeze(0) 73 | cross = torch.tensor([(0, -u_z, u_y), (u_z, 0, -u_x), (-u_y, u_x, 0)]).unsqueeze(0) 74 | outer = torch.tensor(axis) 75 | outer = (outer.unsqueeze(1) * outer).unsqueeze(0) 76 | 77 | sin_t = torch.sin(theta).view(-1, 1, 1) 78 | cos_t = torch.cos(theta).view(-1, 1, 1) 79 | 80 | rot = cos_t * eye + sin_t * cross + (1 - cos_t) * outer 81 | 82 | eye_4 = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 83 | eye_4[:, :3, :3] = rot 84 | 85 | return eye_4 86 | 87 | 88 | def scale3d_mat(s_x, s_y, s_z): 89 | batch = s_x.shape[0] 90 | 91 | mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 92 | mat[:, 0, 0] = s_x 93 | mat[:, 1, 1] = s_y 94 | mat[:, 2, 2] = s_z 95 | 96 | return mat 97 | 98 | 99 | def luma_flip_mat(axis, i): 100 | batch = i.shape[0] 101 | 102 | eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 103 | axis = torch.tensor(axis + (0,)) 104 | flip = 2 * torch.ger(axis, axis) * i.view(-1, 1, 1) 105 | 106 | return eye - flip 107 | 108 | 109 | def saturation_mat(axis, i): 110 | batch = i.shape[0] 111 | 112 | eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 113 | axis = torch.tensor(axis + (0,)) 114 | axis = torch.ger(axis, axis) 115 | saturate = axis + (eye - axis) * i.view(-1, 1, 1) 116 | 117 | return saturate 118 | 119 | 120 | def lognormal_sample(size, mean=0, std=1): 121 | return torch.empty(size).log_normal_(mean=mean, std=std) 122 | 123 | 124 | def 
category_sample(size, categories): 125 | category = torch.tensor(categories) 126 | sample = torch.randint(high=len(categories), size=(size,)) 127 | 128 | return category[sample] 129 | 130 | 131 | def uniform_sample(size, low, high): 132 | return torch.empty(size).uniform_(low, high) 133 | 134 | 135 | def normal_sample(size, mean=0, std=1): 136 | return torch.empty(size).normal_(mean, std) 137 | 138 | 139 | def bernoulli_sample(size, p): 140 | return torch.empty(size).bernoulli_(p) 141 | 142 | 143 | def random_mat_apply(p, transform, prev, eye): 144 | size = transform.shape[0] 145 | select = bernoulli_sample(size, p).view(size, 1, 1) 146 | select_transform = select * transform + (1 - select) * eye 147 | 148 | return select_transform @ prev 149 | 150 | 151 | def sample_affine(p, size, height, width): 152 | G = torch.eye(3).unsqueeze(0).repeat(size, 1, 1) 153 | eye = G 154 | 155 | # flip 156 | param = category_sample(size, (0, 1)) 157 | Gc = scale_mat(1 - 2.0 * param, torch.ones(size)) 158 | G = random_mat_apply(p, Gc, G, eye) 159 | # print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n') 160 | 161 | # 90 rotate 162 | param = category_sample(size, (0, 3)) 163 | Gc = rotate_mat(-math.pi / 2 * param) 164 | G = random_mat_apply(p, Gc, G, eye) 165 | # print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n') 166 | 167 | # integer translate 168 | param = uniform_sample(size, -0.125, 0.125) 169 | param_height = torch.round(param * height) / height 170 | param_width = torch.round(param * width) / width 171 | Gc = translate_mat(param_width, param_height) 172 | G = random_mat_apply(p, Gc, G, eye) 173 | # print('integer translate', G, translate_mat(param_width, param_height), sep='\n') 174 | 175 | # isotropic scale 176 | param = lognormal_sample(size, std=0.2 * math.log(2)) 177 | Gc = scale_mat(param, param) 178 | G = random_mat_apply(p, Gc, G, eye) 179 | # print('isotropic scale', G, scale_mat(param, param), sep='\n') 180 | 181 | p_rot = 1 - math.sqrt(1 - p) 182 | 183 | # pre-rotate 184 | param = uniform_sample(size, -math.pi, math.pi) 185 | Gc = rotate_mat(-param) 186 | G = random_mat_apply(p_rot, Gc, G, eye) 187 | # print('pre-rotate', G, rotate_mat(-param), sep='\n') 188 | 189 | # anisotropic scale 190 | param = lognormal_sample(size, std=0.2 * math.log(2)) 191 | Gc = scale_mat(param, 1 / param) 192 | G = random_mat_apply(p, Gc, G, eye) 193 | # print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n') 194 | 195 | # post-rotate 196 | param = uniform_sample(size, -math.pi, math.pi) 197 | Gc = rotate_mat(-param) 198 | G = random_mat_apply(p_rot, Gc, G, eye) 199 | # print('post-rotate', G, rotate_mat(-param), sep='\n') 200 | 201 | # fractional translate 202 | param = normal_sample(size, std=0.125) 203 | Gc = translate_mat(param, param) 204 | G = random_mat_apply(p, Gc, G, eye) 205 | # print('fractional translate', G, translate_mat(param, param), sep='\n') 206 | 207 | return G 208 | 209 | 210 | def sample_color(p, size): 211 | C = torch.eye(4).unsqueeze(0).repeat(size, 1, 1) 212 | eye = C 213 | axis_val = 1 / math.sqrt(3) 214 | axis = (axis_val, axis_val, axis_val) 215 | 216 | # brightness 217 | param = normal_sample(size, std=0.2) 218 | Cc = translate3d_mat(param, param, param) 219 | C = random_mat_apply(p, Cc, C, eye) 220 | 221 | # contrast 222 | param = lognormal_sample(size, std=0.5 * math.log(2)) 223 | Cc = scale3d_mat(param, param, param) 224 | C = random_mat_apply(p, Cc, C, eye) 225 | 226 | # luma flip 227 | param = category_sample(size, (0, 1)) 228 | Cc = 
luma_flip_mat(axis, param) 229 | C = random_mat_apply(p, Cc, C, eye) 230 | 231 | # hue rotation 232 | param = uniform_sample(size, -math.pi, math.pi) 233 | Cc = rotate3d_mat(axis, param) 234 | C = random_mat_apply(p, Cc, C, eye) 235 | 236 | # saturation 237 | param = lognormal_sample(size, std=1 * math.log(2)) 238 | Cc = saturation_mat(axis, param) 239 | C = random_mat_apply(p, Cc, C, eye) 240 | 241 | return C 242 | 243 | 244 | def make_grid(shape, x0, x1, y0, y1, device): 245 | n, c, h, w = shape 246 | grid = torch.empty(n, h, w, 3, device=device) 247 | grid[:, :, :, 0] = torch.linspace(x0, x1, w, device=device) 248 | grid[:, :, :, 1] = torch.linspace(y0, y1, h, device=device).unsqueeze(-1) 249 | grid[:, :, :, 2] = 1 250 | 251 | return grid 252 | 253 | 254 | def affine_grid(grid, mat): 255 | n, h, w, _ = grid.shape 256 | return (grid.view(n, h * w, 3) @ mat.transpose(1, 2)).view(n, h, w, 2) 257 | 258 | 259 | def get_padding(G, height, width): 260 | extreme = ( 261 | G[:, :2, :] 262 | @ torch.tensor([(-1.0, -1, 1), (-1, 1, 1), (1, -1, 1), (1, 1, 1)]).t() 263 | ) 264 | 265 | size = torch.tensor((width, height)) 266 | 267 | pad_low = ( 268 | ((extreme.min(-1).values + 1) * size) 269 | .clamp(max=0) 270 | .abs() 271 | .ceil() 272 | .max(0) 273 | .values.to(torch.int64) 274 | .tolist() 275 | ) 276 | pad_high = ( 277 | (extreme.max(-1).values * size - size) 278 | .clamp(min=0) 279 | .ceil() 280 | .max(0) 281 | .values.to(torch.int64) 282 | .tolist() 283 | ) 284 | 285 | return pad_low[0], pad_high[0], pad_low[1], pad_high[1] 286 | 287 | 288 | def try_sample_affine_and_pad(img, p, pad_k, G=None): 289 | batch, _, height, width = img.shape 290 | 291 | G_try = G 292 | 293 | while True: 294 | if G is None: 295 | G_try = sample_affine(p, batch, height, width) 296 | 297 | pad_x1, pad_x2, pad_y1, pad_y2 = get_padding( 298 | torch.inverse(G_try), height, width 299 | ) 300 | 301 | try: 302 | img_pad = F.pad( 303 | img, 304 | (pad_x1 + pad_k, pad_x2 + pad_k, pad_y1 + pad_k, pad_y2 + pad_k), 305 | mode="reflect", 306 | ) 307 | 308 | except RuntimeError: 309 | continue 310 | 311 | break 312 | 313 | return img_pad, G_try, (pad_x1, pad_x2, pad_y1, pad_y2) 314 | 315 | 316 | def random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6): 317 | kernel = antialiasing_kernel 318 | len_k = len(kernel) 319 | pad_k = (len_k + 1) // 2 320 | 321 | kernel = torch.as_tensor(kernel) 322 | kernel = torch.ger(kernel, kernel).to(img) 323 | kernel_flip = torch.flip(kernel, (0, 1)) 324 | 325 | img_pad, G, (pad_x1, pad_x2, pad_y1, pad_y2) = try_sample_affine_and_pad( 326 | img, p, pad_k, G 327 | ) 328 | 329 | p_ux1 = pad_x1 330 | p_ux2 = pad_x2 + 1 331 | p_uy1 = pad_y1 332 | p_uy2 = pad_y2 + 1 333 | w_p = img_pad.shape[3] - len_k + 1 334 | h_p = img_pad.shape[2] - len_k + 1 335 | h_o = img.shape[2] 336 | w_o = img.shape[3] 337 | 338 | img_2x = upfirdn2d(img_pad, kernel_flip, up=2) 339 | 340 | grid = make_grid( 341 | img_2x.shape, 342 | -2 * p_ux1 / w_o - 1, 343 | 2 * (w_p - p_ux1) / w_o - 1, 344 | -2 * p_uy1 / h_o - 1, 345 | 2 * (h_p - p_uy1) / h_o - 1, 346 | device=img_2x.device, 347 | ).to(img_2x) 348 | grid = affine_grid(grid, torch.inverse(G)[:, :2, :].to(img_2x)) 349 | grid = grid * torch.tensor( 350 | [w_o / w_p, h_o / h_p], device=grid.device 351 | ) + torch.tensor( 352 | [(w_o + 2 * p_ux1) / w_p - 1, (h_o + 2 * p_uy1) / h_p - 1], device=grid.device 353 | ) 354 | 355 | img_affine = F.grid_sample( 356 | img_2x, grid, mode="bilinear", align_corners=False, padding_mode="zeros" 357 | ) 358 | 359 | img_down = 
upfirdn2d(img_affine, kernel, down=2) 360 | 361 | end_y = -pad_y2 - 1 362 | if end_y == 0: 363 | end_y = img_down.shape[2] 364 | 365 | end_x = -pad_x2 - 1 366 | if end_x == 0: 367 | end_x = img_down.shape[3] 368 | 369 | img = img_down[:, :, pad_y1:end_y, pad_x1:end_x] 370 | 371 | return img, G 372 | 373 | 374 | def apply_color(img, mat): 375 | batch = img.shape[0] 376 | img = img.permute(0, 2, 3, 1) 377 | mat_mul = mat[:, :3, :3].transpose(1, 2).view(batch, 1, 3, 3) 378 | mat_add = mat[:, :3, 3].view(batch, 1, 1, 3) 379 | img = img @ mat_mul + mat_add 380 | img = img.permute(0, 3, 1, 2) 381 | 382 | return img 383 | 384 | 385 | def random_apply_color(img, p, C=None): 386 | if C is None: 387 | C = sample_color(p, img.shape[0]) 388 | 389 | img = apply_color(img, C.to(img)) 390 | 391 | return img, C 392 | 393 | 394 | def augment(img, p, transform_matrix=(None, None)): 395 | img, G = random_apply_affine(img, p, transform_matrix[0]) 396 | img, C = random_apply_color(img, p, transform_matrix[1]) 397 | 398 | return img, (G, C) 399 | -------------------------------------------------------------------------------- /op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t *x_p = 80 | &input[((major_idx * p.in_h + 
in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * p.minor_dim; 86 | int k_py = -p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast(*x_p) * static_cast(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | } 100 | 101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z * p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 
| int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = pad_x1; 234 | p.pad_y0 = pad_y0; 235 | p.pad_y1 = pad_y1; 236 | 237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 238 | p.down_y; 239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 240 | p.down_x; 241 | 242 | auto out = 243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 244 | 245 | int mode = -1; 246 | 247 | int tile_out_h = -1; 248 | int tile_out_w = -1; 249 | 250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 251 | p.kernel_h <= 4 && p.kernel_w <= 4) { 252 | mode = 1; 253 | tile_out_h = 16; 254 | tile_out_w = 64; 255 | } 256 | 257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 258 | p.kernel_h <= 3 && p.kernel_w <= 3) { 259 | mode = 2; 260 | tile_out_h = 16; 261 | tile_out_w = 64; 262 | } 263 | 264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 4 && p.kernel_w <= 4) { 266 | mode = 3; 267 | tile_out_h = 16; 268 | tile_out_w = 64; 269 | } 270 | 271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 272 | p.kernel_h <= 2 && p.kernel_w <= 2) { 273 | mode = 4; 274 | tile_out_h = 16; 275 | tile_out_w = 64; 276 | } 277 | 278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 279 | p.kernel_h <= 4 && p.kernel_w <= 4) { 280 | mode = 5; 281 | tile_out_h = 8; 282 | tile_out_w = 32; 283 | } 284 | 285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 286 | p.kernel_h <= 2 && p.kernel_w <= 2) { 287 | mode = 6; 288 | tile_out_h = 8; 289 | tile_out_w = 32; 290 | } 291 | 292 | dim3 block_size; 293 | dim3 grid_size; 294 | 295 | if (tile_out_h > 0 && tile_out_w > 0) { 296 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 297 | p.loop_x = 1; 298 | block_size = dim3(32 * 8, 1, 1); 299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 301 | (p.major_dim - 1) / p.loop_major + 1); 302 | } else { 303 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 304 | p.loop_x = 4; 305 | block_size = dim3(4, 32, 1); 306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, 307 | 
(p.out_w - 1) / (p.loop_x * block_size.y) + 1, 308 | (p.major_dim - 1) / p.loop_major + 1); 309 | } 310 | 311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 312 | switch (mode) { 313 | case 1: 314 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64> 315 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 316 | x.data_ptr<scalar_t>(), 317 | k.data_ptr<scalar_t>(), p); 318 | 319 | break; 320 | 321 | case 2: 322 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64> 323 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 324 | x.data_ptr<scalar_t>(), 325 | k.data_ptr<scalar_t>(), p); 326 | 327 | break; 328 | 329 | case 3: 330 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64> 331 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 332 | x.data_ptr<scalar_t>(), 333 | k.data_ptr<scalar_t>(), p); 334 | 335 | break; 336 | 337 | case 4: 338 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64> 339 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 340 | x.data_ptr<scalar_t>(), 341 | k.data_ptr<scalar_t>(), p); 342 | 343 | break; 344 | 345 | case 5: 346 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32> 347 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 348 | x.data_ptr<scalar_t>(), 349 | k.data_ptr<scalar_t>(), p); 350 | 351 | break; 352 | 353 | case 6: 354 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32> 355 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 356 | x.data_ptr<scalar_t>(), 357 | k.data_ptr<scalar_t>(), p); 358 | 359 | break; 360 | 361 | default: 362 | upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>( 363 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), 364 | k.data_ptr<scalar_t>(), p); 365 | } 366 | }); 367 | 368 | return out; 369 | } -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module is used to train S2FGAN. 3 | The module structure is the following: 4 | - A print function used to record the training log 5 | - A parser used to read the parameters from users. 6 | - Set TORCH_HOME, which PyTorch uses to store pretrained models 7 | - Initialize S2FGAN and its optimizers 8 | - A data_prefetcher is used to load the inputs to CUDA during training. 9 | - A train function used to run the training loop. 10 | The training logs will be stored in log.txt 11 | """ 12 | 13 | import os 14 | import argparse 15 | import numpy as np 16 | import torch 17 | from torch import nn 18 | from torch.nn import functional as F 19 | from torch.utils import data 20 | from torchvision import utils 21 | from model import Model as S2FGAN 22 | from dataset import CeleDataset 23 | from non_leaking import augment 24 | import time 25 | import datetime 26 | import torch.backends.cudnn as cudnn 27 | 28 | #Speed up training 29 | cudnn.benchmark = True 30 | 31 | #write the parameters and training logs of S2FGAN to log.txt 32 | def print(x): 33 | with open("log.txt","a") as f: 34 | f.write(str(x) + "\n") 35 | 36 | def accumulate(model1, model2, decay=0.999): 37 | """ 38 | Return None 39 | Parameters 40 | ---------- 41 | model1 : pytorch model 42 | model2 : pytorch model 43 | decay : float, default 0.999, the decay of the exponential moving average used to update model1 44 | 45 | Returns 46 | ------- 47 | None 48 | Update model1's parameters with an exponential moving average of model2's parameters. 49 | """ 50 | par1 = dict(model1.named_parameters()) 51 | par2 = dict(model2.named_parameters()) 52 | 53 | for k in par1.keys(): 54 | par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay) 55 | 56 | class data_prefetcher(): 57 | ''' 58 | A wrapper around a dataloader that copies batches to CUDA asynchronously during S2FGAN training. 59 | ''' 60 | def __init__(self, loader): 61 | """ 62 | Return None 63 | Parameters 64 | ---------- 65 | loader : pytorch data loader. 
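        Example (illustrative sketch only; it mirrors how sample_data below consumes
        this class, and assumes `loader` yields tuples of CPU tensors):

            pref = data_prefetcher(loader)
            batch = pref.next()              # tensors are already on the GPU
            while batch is not None:
                sketch, img, label = batch
                # ... training step ...
                batch = pref.next()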
66 | 67 | Returns 68 | ------- 69 | None 70 | 71 | Initialize cuda stream and preload the data when intialize the classes 72 | """ 73 | self.loader = iter(loader) 74 | self.stream = torch.cuda.Stream() 75 | self.preload() 76 | 77 | def preload(self): 78 | """ 79 | Return None 80 | Parameters 81 | ---------- 82 | None 83 | 84 | Returns 85 | ------- 86 | None 87 | load the data to cuda and process data using process function. Here is concurrent happens. 88 | """ 89 | try: 90 | self.next_input = next(self.loader) 91 | except StopIteration: 92 | self.next_input = None 93 | return 94 | with torch.cuda.stream(self.stream): 95 | self.next_input = [i.cuda(non_blocking=True) for i in self.next_input] 96 | 97 | def next(self): 98 | """ 99 | Return None 100 | Parameters 101 | ---------- 102 | None 103 | 104 | Returns 105 | ------- 106 | None 107 | Synchronise the stream, return preloaded data, and load data for next batch. 108 | """ 109 | torch.cuda.current_stream().wait_stream(self.stream) 110 | input = self.next_input 111 | self.preload() 112 | return input 113 | 114 | def sample_data(loader,device): 115 | """ 116 | Return normalized sketch, normalized images and label 117 | Parameters 118 | ---------- 119 | loader : pytorch loader 120 | device : cuda device name 121 | 122 | Returns 123 | ------- 124 | sketch : noramlised X. 125 | img : normalised img. 126 | labels : same 127 | """ 128 | while True: 129 | pref = data_prefetcher(loader) 130 | data = pref.next() 131 | while data is not None: 132 | [sketch,img,label] = data 133 | sketch = (sketch - 255 * 0.5) / (255 * 0.5) 134 | img = (img - 255 * 0.5) / (255 * 0.5) 135 | label = label 136 | data = pref.next() 137 | yield [sketch,img,label] 138 | 139 | 140 | def train(args, dataloader_train,dataloader_val, models, g_optim, d_optim, device): 141 | 142 | """ 143 | Return normalized sketch, normalized images and label 144 | Parameters 145 | ---------- 146 | args : args for S2FGAN 147 | dataloader_train : dataloader for training 148 | dataloader_val : dataloader for evaluation 149 | models : S2FGAN models 150 | g_optim : generator optimizer 151 | d_optim : discriminator optimizer 152 | device : cuda device 153 | 154 | Returns 155 | ------- 156 | None 157 | A trained S2FGAN. 158 | """ 159 | 160 | [model,model_ema] = models 161 | 162 | #speed data loading and process data 163 | loader = sample_data(dataloader_train, device) 164 | loader_val = sample_data(dataloader_val,device) 165 | 166 | print("Trianing start") 167 | 168 | loss_dict = {} 169 | 170 | model_module = model.module 171 | 172 | #intialize paramters for adaptive discriminator agumentation. 173 | accum = 0.5 ** (32 / (10 * 1000)) 174 | ada_augment = torch.tensor([0.0, 0.0], device=device) 175 | ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 176 | ada_aug_step = args.ada_target / args.ada_length 177 | r_t_stat = 0 178 | 179 | #Record starting time, to estimate the time need for training. 
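    #Descriptive sketch (comments only, no extra logic) of the adaptive-augmentation bookkeeping
    #used further down: accum = 0.5 ** (32 / (10 * 1000)) gives the model_ema moving average a
    #half-life of roughly 10k images at batch size 32, and ada_aug_step = ada_target / ada_length
    #means the augmentation probability can move by at most ada_target over ada_length collected
    #real predictions. The update itself is the usual ADA heuristic:
    #    r_t = mean(sign(D(real)))                       # how often D is confident on reals
    #    ada_aug_p += ada_aug_step * n_pred if r_t > ada_target else -ada_aug_step * n_pred
    #    ada_aug_p = min(1, max(0, ada_aug_p))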
180 | start_time = time.time() 181 | 182 | for idx in range(args.iter): 183 | i = idx + args.start_iter 184 | 185 | if i > args.iter: 186 | print("Done!") 187 | 188 | break 189 | 190 | sketch,img,label = next(loader) 191 | 192 | #samples attribute shiting vector 193 | sampled_ratio = torch.FloatTensor(np.random.uniform(-4,4, (sketch.size(0), c_dim))).to(device) 194 | sampled_mask = torch.FloatTensor(np.random.randint(0,2, (sketch.size(0), 1)) * 1.0).to(device) 195 | sampled_ratio = sampled_ratio * sampled_mask 196 | target_ratio = (label * 2 - 1) + sampled_ratio 197 | target_mask = target_ratio >= 0 198 | 199 | #create domain label for sketch and img 200 | domain_sketch = torch.zeros((sketch.size(0),1)).type(torch.FloatTensor).to(device) 201 | domain_img = torch.ones((img.size(0),1)).type(torch.FloatTensor).to(device) 202 | 203 | 204 | fake_img_pred, real_img_pred, bce = model(img,sketch,sampled_ratio,label,target_mask, ada_aug_p = ada_aug_p, train_discriminator = True) 205 | 206 | d_loss = F.softplus(-real_img_pred).mean() + F.softplus(fake_img_pred).mean() + bce.mean() 207 | 208 | loss_dict["d_loss"] = d_loss 209 | d_optim.zero_grad() 210 | d_loss.backward() 211 | d_optim.step() 212 | 213 | for real_pred in [real_img_pred]: 214 | if args.augment and args.augment_p == 0: 215 | ada_augment_data = torch.tensor( 216 | (torch.sign(real_pred).sum().item(), real_pred.shape[0]), device=device 217 | ) 218 | ada_augment += ada_augment_data 219 | 220 | if ada_augment[1] > 255: 221 | pred_signs, n_pred = ada_augment.tolist() 222 | r_t_stat = pred_signs / n_pred 223 | 224 | if r_t_stat > args.ada_target: 225 | sign = 1 226 | 227 | else: 228 | sign = -1 229 | ada_aug_p += sign * ada_aug_step * n_pred 230 | ada_aug_p = min(1, max(0, ada_aug_p)) 231 | ada_augment.mul_(0) 232 | 233 | d_regularize = i % args.d_reg_every == 0 234 | if d_regularize: 235 | img.requires_grad = True 236 | 237 | r1_loss = model(img, d_regularize = True) 238 | r1_loss = r1_loss.mean() 239 | 240 | d_optim.zero_grad() 241 | (args.r1 / 2 * r1_loss * args.d_reg_every).backward() 242 | d_optim.step() 243 | 244 | loss_dict["r1"] = r1_loss 245 | 246 | img.requires_grad = False 247 | 248 | #samples attribute shiting vector 249 | sampled_ratio = torch.FloatTensor(np.random.uniform(-4,4, (sketch.size(0), c_dim))).to(device) 250 | sampled_mask = torch.FloatTensor(np.random.randint(0,2, (sketch.size(0), 1)) * 1.0).to(device) 251 | sampled_ratio = sampled_ratio * sampled_mask 252 | target_ratio = (label * 2 - 1) + sampled_ratio 253 | target_mask = target_ratio >= 0 254 | 255 | 256 | g_loss = model(img,sketch,sampled_ratio,label,target_mask, domain_img,domain_sketch, ada_aug_p = ada_aug_p,train_generator = True) 257 | g_loss = g_loss.mean() 258 | 259 | loss_dict["g_loss"] = g_loss 260 | 261 | g_optim.zero_grad() 262 | g_loss.backward() 263 | g_optim.step() 264 | 265 | accumulate(model_ema, model_module, accum) 266 | 267 | loss_reduced = loss_dict 268 | 269 | d_loss = loss_reduced["d_loss"].item() 270 | g_loss = loss_reduced["g_loss"].item() 271 | r1 = loss_reduced["r1"].item() 272 | 273 | #Print log 274 | if i % 10 == 0: 275 | # Determine approximate time left 276 | batches_done = idx 277 | batches_left = args.iter - batches_done 278 | time_left = datetime.timedelta(seconds=batches_left * (time.time() - start_time) / (batches_done + 1)) 279 | 280 | print( 281 | ( 282 | f"Epoch[{idx}/{args.iter}]; augment: {ada_aug_p:.4f}; " 283 | f"d_loss: {d_loss:.4f}; g_loss: {g_loss:.4f}; r1: {r1:.4f}; ETA: {time_left}" 284 | ) 285 | ) 286 | 287 | #sample 
images 288 | if i % 400 == 0: 289 | sketch,img,label = next(loader_val) 290 | with torch.no_grad(): 291 | samples = None 292 | for e, j,l in zip(sketch,img,torch.cat((LABELS,label[-2:]))): 293 | d = e.view(1,args.img_height,args.img_width).repeat(3,1,1) 294 | e = e.view(1,1,256,256).repeat(13,1,1,1) 295 | l = l.view(1,12).repeat(13,1) * SCALE 296 | k,im = model_ema(j.view(1,3,256,256),sketch = e,sampled_ratio = l,generate = True) 297 | im = torch.cat([x for x in im],-1) 298 | sample = torch.cat((d,k.view(3,256,256),j,im),-1).unsqueeze(0) 299 | samples = sample if samples is None else torch.cat((samples,sample),-2) 300 | 301 | utils.save_image( 302 | samples, 303 | f"sample/{str(i).zfill(6)}.png", 304 | nrow= 16, 305 | normalize=True, 306 | range=(-1, 1), 307 | ) 308 | 309 | # Save model checkpoints 310 | if i % 10000 == 0: 311 | torch.save( 312 | { 313 | "model" :model_module.state_dict(), 314 | "model_ema":model_ema.state_dict() 315 | }, 316 | f"checkpoint/{str(i).zfill(6)}.pt", 317 | ) 318 | 319 | 320 | if __name__ == "__main__": 321 | device = "cuda" 322 | 323 | parser = argparse.ArgumentParser(description="S2FGAN trainer") 324 | 325 | parser.add_argument( 326 | "--iter", 327 | type=int, 328 | default=100, 329 | help="total training iterations" 330 | ) 331 | parser.add_argument( 332 | "--batch", 333 | type=int, 334 | default = 4, 335 | help="batch sizes" 336 | ) 337 | 338 | parser.add_argument( 339 | "--r1", 340 | type=float, 341 | default=1, 342 | help="weight of the r1 regularization" 343 | ) 344 | 345 | parser.add_argument( 346 | "--d_reg_every", 347 | type=int, 348 | default=16, 349 | help="interval of the applying r1 regularization", 350 | ) 351 | 352 | parser.add_argument( 353 | "--lr", 354 | type=float, 355 | default=0.002, 356 | help="learning rate" 357 | ) 358 | 359 | parser.add_argument( 360 | "--augment", 361 | type=bool, 362 | default=True, 363 | help="apply discriminator augmentation" 364 | ) 365 | 366 | parser.add_argument( 367 | "--augment_p", 368 | type=float, 369 | default=0, 370 | help="probability of applying augmentation. 
0 = use adaptive augmentation", 371 | ) 372 | parser.add_argument( 373 | "--ada_target", 374 | type=float, 375 | default=0.6, 376 | help="target augmentation probability for adaptive augmentation", 377 | ) 378 | 379 | parser.add_argument( 380 | "--ada_length", 381 | type=int, 382 | default=500 * 1000, 383 | help="target duraing to reach augmentation probability for adaptive augmentation", 384 | ) 385 | parser.add_argument( 386 | "--ada_every", 387 | type=int, 388 | default=256, 389 | help="probability update interval of the adaptive augmentation", 390 | ) 391 | 392 | parser.add_argument( 393 | "--img_height", 394 | type=int, 395 | default=256, 396 | help="size of image height" 397 | ) 398 | 399 | parser.add_argument( 400 | "--img_width", 401 | type=int, 402 | default=256, 403 | help="size of image width" 404 | ) 405 | parser.add_argument( 406 | "--NumOfImage", 407 | type=int, 408 | default= 10, 409 | help = "number of images in the zip" 410 | ) 411 | 412 | parser.add_argument( 413 | "--imageZip", 414 | type=str, 415 | default= "data/CelebAMask-HQ-Sample.zip" 416 | ) 417 | 418 | parser.add_argument( 419 | "--hedEdgeZip", 420 | type=str, 421 | default= "data/hed_edge_256-Sample.zip" 422 | ) 423 | 424 | parser.add_argument( 425 | "--hedEdgePath", 426 | type=str, 427 | default= "hed_edge_256-Sample" 428 | ) 429 | 430 | parser.add_argument( 431 | "--imagePath", 432 | type=str, 433 | default= "CelebAMask-HQ-Sample/CelebA-HQ-img" 434 | ) 435 | 436 | parser.add_argument( 437 | "--TORCH_HOME", 438 | type=str, 439 | default="None", 440 | help="where to load/save pytorch pretrained models" 441 | ) 442 | 443 | parser.add_argument( 444 | "--selected_attrs", 445 | type = list, 446 | nargs="+", 447 | help="selected attributes for the CelebAMask-HQ dataset", 448 | default=["Smiling", "Male","No_Beard", "Eyeglasses","Young", "Bangs", "Narrow_Eyes", "Pale_Skin", "Big_Lips","Big_Nose","Mustache","Chubby"], 449 | ) 450 | 451 | parser.add_argument( 452 | "--label_path", 453 | type = str, 454 | default = "data/CelebAMask-HQ-attribute-anno.txt", 455 | help = "attributes annotation text file of CelebAMask-HQ" 456 | ) 457 | 458 | parser.add_argument( 459 | "--ATMDTT", 460 | type = list, 461 | nargs="+", 462 | help="Attributes to manipulate during testing time", 463 | default= 464 | [[1,0,0,0,0,0,0,0,0,0,0,0], 465 | [0,1,0,0,0,0,0,0,0,0,0,0] 466 | ] 467 | ) 468 | 469 | parser.add_argument( 470 | "--model_type", 471 | type = int, 472 | default = 0, 473 | help = "0- S2F-DIS, 1- S2F-DEC" 474 | ) 475 | 476 | args = parser.parse_args() 477 | 478 | args.start_iter = 0 479 | c_dim = len(args.selected_attrs) 480 | 481 | #create folders to store samples and checkpoints 482 | os.makedirs("sample", exist_ok=True) 483 | os.makedirs("checkpoint", exist_ok=True) 484 | 485 | #Set TORCH_HOME to system enviroment. 
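    #Example launch with the bundled sample data (illustrative only; the values shown are the
    #argparse defaults above, adjust them for a real run):
    #  python train.py --iter 100 --batch 4 \
    #      --imageZip data/CelebAMask-HQ-Sample.zip --hedEdgeZip data/hed_edge_256-Sample.zip \
    #      --label_path data/CelebAMask-HQ-attribute-anno.txt --model_type 0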
486 | if args.TORCH_HOME != "None": 487 | os.environ['TORCH_HOME'] = args.TORCH_HOME 488 | 489 | 490 | #Sanity check that a GPU is available 491 | if not torch.cuda.is_available(): 492 | raise SystemExit("GPU Required") 493 | 494 | #initialization 495 | model = S2FGAN(args,c_dim,augment).to(device) 496 | 497 | model_ema = S2FGAN(args,c_dim,augment).to(device) 498 | 499 | accumulate(model_ema, model, 0) 500 | model_ema.eval() 501 | 502 | #get model optimizers 503 | 504 | g_optim = model.g_optim 505 | d_optim = model.d_optim 506 | 507 | model = nn.DataParallel(model) 508 | #initialize dataloaders 509 | 510 | dataset = CeleDataset(args, True) 511 | loader = data.DataLoader( 512 | dataset, 513 | batch_size=args.batch, 514 | num_workers = 4, 515 | drop_last = True 516 | ) 517 | 518 | dataset_val = CeleDataset(args, False) 519 | dataloader_val = torch.utils.data.DataLoader( 520 | dataset_val, 521 | batch_size=len(args.ATMDTT) + 2, 522 | num_workers=4 523 | ) 524 | 525 | 526 | #Initialise the intensity control parameters for demonstration 527 | LABELS = torch.FloatTensor(args.ATMDTT).to(device) 528 | SCALE = torch.FloatTensor([-4.0,-3.0, -2.0,-1.5, -1.0, -0.5,0,0.5, 1.0,1.5,2.0,3.0,4.0]).to(device).view(13,1) 529 | 530 | #start training 531 | train(args,loader,dataloader_val,[model,model_ema], g_optim, d_optim, device) 532 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module is the concrete implementation of S2FGAN. 3 | The module structure is the following: 4 | make_kernel, used to initialise the kernel for blurring images 5 | Blur, a layer used to apply the blur kernel to its input 6 | PixelNorm, a layer used to apply pixel normalization 7 | EqualConv1d, 1d convolution with the equalized learning rate trick 8 | EqualConv2d, 2d convolution with the equalized learning rate trick 9 | EqualLinear, linear layer with the equalized learning rate trick 10 | Embeding, the attribute mapping network. 11 | Encoder, the encoder of S2FGAN. 12 | StyledConv, the upsampling block for the decoder of S2FGAN. 13 | Discriminator, the discriminator of S2FGAN. 14 | VGGPerceptualLoss, the perceptual loss based on VGG19. 
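Example (illustrative only): the components are wired together by the Model class at the
bottom of this file; `args` must be the training argparse namespace (model_type, d_reg_every,
lr, augment, ...) and `augment` is the function from non_leaking.py, exactly as in train.py:

    model = Model(args, c_dim=12, augment=augment)
    g_optim, d_optim = model.g_optim, model.d_optim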
15 | """ 16 | 17 | import math 18 | import torch 19 | import torchvision 20 | from torch import nn 21 | from torch.nn import functional as F 22 | from torch.autograd import Function 23 | from torch.nn.init import normal_ 24 | from torch import autograd, optim 25 | from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix 26 | 27 | 28 | #Pixel Normalization 29 | class PixelNorm(nn.Module): 30 | def __init__(self): 31 | super().__init__() 32 | 33 | def forward(self, input): 34 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) 35 | 36 | #create blur kernel 37 | def make_kernel(k): 38 | k = torch.tensor(k, dtype=torch.float32) 39 | 40 | if k.ndim == 1: 41 | k = k[None, :] * k[:, None] 42 | 43 | k /= k.sum() 44 | 45 | return k 46 | 47 | 48 | #Blur Layer 49 | class Blur(nn.Module): 50 | def __init__(self, kernel, pad, upsample_factor=1): 51 | super().__init__() 52 | 53 | kernel = make_kernel(kernel) 54 | 55 | if upsample_factor > 1: 56 | kernel = kernel * (upsample_factor ** 2) 57 | 58 | self.register_buffer("kernel", kernel) 59 | 60 | self.pad = pad 61 | 62 | def forward(self, input): 63 | out = upfirdn2d(input, self.kernel, pad=self.pad) 64 | 65 | return out 66 | 67 | class Upsample(nn.Module): 68 | def __init__(self, kernel, factor=2): 69 | super().__init__() 70 | 71 | self.factor = factor 72 | kernel = make_kernel(kernel) * (factor ** 2) 73 | self.register_buffer("kernel", kernel) 74 | 75 | p = kernel.shape[0] - factor 76 | 77 | pad0 = (p + 1) // 2 + factor - 1 78 | pad1 = p // 2 79 | 80 | self.pad = (pad0, pad1) 81 | 82 | def forward(self, input): 83 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) 84 | 85 | return out 86 | 87 | #Equlized convlution 2d 88 | class EqualConv2d(nn.Module): 89 | def __init__( 90 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True 91 | ): 92 | """ 93 | Return, None 94 | Parameters 95 | ---------- 96 | in_channels, int, the channels of input 97 | out_channels, int, the channles expanded by the convolution 98 | kernel_size, int, the size of kernel needed. 99 | stride: int, controls the cross correlation during convolution 100 | padding: int, the number of gride used to pad input. 101 | bias: bool, controls adding of learnable biase 102 | Returns 103 | ------- 104 | None 105 | """ 106 | 107 | 108 | super().__init__() 109 | 110 | #intialize weight 111 | self.weight = nn.Parameter( 112 | torch.randn(out_channel, in_channel, kernel_size, kernel_size) 113 | ) 114 | 115 | #calculate sacles for weight 116 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 117 | 118 | self.stride = stride 119 | self.padding = padding 120 | 121 | #create bias 122 | if bias: 123 | self.bias = nn.Parameter(torch.zeros(out_channel)) 124 | 125 | else: 126 | self.bias = None 127 | 128 | def forward(self, input): 129 | """ 130 | Return, the convolutioned x. 
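        Note (descriptive only): the stored weight is sampled from N(0, 1) and rescaled at
        run time by self.scale = 1 / sqrt(in_channel * kernel_size ** 2), i.e. the equalized
        learning rate trick; for example, a 3x3 convolution over 512 input channels uses
        scale = 1 / sqrt(512 * 9) ≈ 0.0147.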
131 | Parameters 132 | ---------- 133 | x: pytorch tensor, used for the input of convolution 134 | Returns 135 | ------- 136 | the convolutioned x 137 | """ 138 | 139 | out = conv2d_gradfix.conv2d( 140 | input, 141 | self.weight * self.scale, 142 | bias=self.bias, 143 | stride=self.stride, 144 | padding=self.padding, 145 | ) 146 | 147 | return out 148 | 149 | def __repr__(self): 150 | return ( 151 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," 152 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" 153 | ) 154 | 155 | 156 | class EqualLinear(nn.Module): 157 | def __init__( 158 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None 159 | ): 160 | """ 161 | Return, None 162 | Parameters 163 | ---------- 164 | in_dim, int, number of features for input 165 | out_dim, int, number of features for output 166 | bias: bool, controls adding of learnable biase 167 | lr_mul: int, the scales of biase 168 | activation: bool, controls the use of leakly relu. 169 | Returns 170 | ------- 171 | None 172 | """ 173 | 174 | super().__init__() 175 | 176 | #initialize weight 177 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 178 | 179 | #create bias 180 | if bias: 181 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 182 | 183 | else: 184 | self.bias = None 185 | 186 | #store activation function 187 | self.activation = activation 188 | 189 | #calculate sacles for weight 190 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 191 | self.lr_mul = lr_mul 192 | 193 | def forward(self, input): 194 | """ 195 | Return, the transformed x. 196 | Parameters 197 | ---------- 198 | x: pytorch tensor, used for the input of linear. 199 | Returns 200 | ------- 201 | the transformed x. 202 | """ 203 | 204 | if self.activation: 205 | out = F.linear(input, self.weight * self.scale) 206 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 207 | 208 | else: 209 | out = F.linear( 210 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 211 | ) 212 | 213 | return out 214 | 215 | def __repr__(self): 216 | return ( 217 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" 218 | ) 219 | 220 | 221 | class ModulatedConv2d(nn.Module): 222 | def __init__( 223 | self, 224 | in_channel, 225 | out_channel, 226 | kernel_size, 227 | style_dim, 228 | demodulate=True, 229 | upsample=False, 230 | downsample=False, 231 | blur_kernel=[1, 3, 3, 1] 232 | ): 233 | """ 234 | Return, None 235 | Parameters 236 | ---------- 237 | in_channels, int, the channels of input 238 | out_channels, int, the channles expanded by the convolution 239 | kernel_size, int, the size of kernel needed. 240 | style_dim, int, dimensionality of attribute latent space. 241 | demodulate, int, decide applying demodulation 242 | upsample, bool, decide if upsample the input 243 | downsample, bool, decide if downsample the input 244 | blur_kernel, [int], the kernel used to blur input. 
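        Note (descriptive sketch of the forward pass below): per sample, the weight is first
        modulated by the style, w' = scale * weight * style, and, when demodulate is True,
        renormalised per output channel by d = rsqrt(sum(w' ** 2 over in_channel and kernel) + 1e-8)
        so that w'' = w' * d, which is the StyleGAN2 weight (de)modulation.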
245 | Returns 246 | ------- 247 | None 248 | """ 249 | 250 | 251 | super().__init__() 252 | 253 | self.eps = 1e-8 254 | self.kernel_size = kernel_size 255 | self.in_channel = in_channel 256 | self.out_channel = out_channel 257 | self.upsample = upsample 258 | self.downsample = downsample 259 | 260 | if upsample: 261 | factor = 2 262 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 263 | pad0 = (p + 1) // 2 + factor - 1 264 | pad1 = p // 2 + 1 265 | 266 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) 267 | 268 | if downsample: 269 | factor = 2 270 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 271 | pad0 = (p + 1) // 2 272 | pad1 = p // 2 273 | 274 | self.blur = Blur(blur_kernel, pad=(pad0, pad1)) 275 | 276 | fan_in = in_channel * kernel_size ** 2 277 | self.scale = 1 / math.sqrt(fan_in) 278 | self.padding = kernel_size // 2 279 | 280 | self.weight = nn.Parameter( 281 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) 282 | ) 283 | 284 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) 285 | 286 | self.demodulate = demodulate 287 | 288 | def forward(self, input, style): 289 | """ 290 | Return, the transformed x. 291 | Parameters 292 | ---------- 293 | x: pytorch tensor. for appearance latent space. 294 | style: pytorch tensor. for attribute editing latent space. 295 | Returns 296 | ------- 297 | the transformed x. 298 | """ 299 | 300 | batch, in_channel, height, width = input.shape 301 | 302 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1) 303 | weight = self.scale * self.weight * style 304 | 305 | if self.demodulate: 306 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 307 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 308 | 309 | weight = weight.view( 310 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size 311 | ) 312 | 313 | if self.upsample: 314 | input = input.view(1, batch * in_channel, height, width) 315 | weight = weight.view( 316 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size 317 | ) 318 | weight = weight.transpose(1, 2).reshape( 319 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size 320 | ) 321 | out = conv2d_gradfix.conv_transpose2d( 322 | input, weight, padding=0, stride=2, groups=batch 323 | ) 324 | _, _, height, width = out.shape 325 | out = out.view(batch, self.out_channel, height, width) 326 | out = self.blur(out) 327 | 328 | elif self.downsample: 329 | input = self.blur(input) 330 | _, _, height, width = input.shape 331 | input = input.view(1, batch * in_channel, height, width) 332 | out = conv2d_gradfix.conv2d( 333 | input, weight, padding=0, stride=2, groups=batch 334 | ) 335 | _, _, height, width = out.shape 336 | out = out.view(batch, self.out_channel, height, width) 337 | 338 | else: 339 | input = input.view(1, batch * in_channel, height, width) 340 | out = conv2d_gradfix.conv2d( 341 | input, weight, padding=self.padding, groups=batch 342 | ) 343 | _, _, height, width = out.shape 344 | out = out.view(batch, self.out_channel, height, width) 345 | 346 | return out 347 | 348 | #trainable input layer for decoder 349 | class ConstantInput(nn.Module): 350 | def __init__(self, channel, size=4): 351 | super().__init__() 352 | 353 | self.input = nn.Parameter(torch.randn(1, channel, size, size)) 354 | 355 | def forward(self, input): 356 | batch = input.shape[0] 357 | out = self.input.repeat(batch, 1, 1, 1) 358 | 359 | return out 360 | 361 | 362 | class StyledConv(nn.Module): 363 | def 
__init__( 364 | self, 365 | in_channel, 366 | out_channel, 367 | kernel_size, 368 | style_dim, 369 | blur_kernel=[1, 3, 3, 1], 370 | demodulate=True, 371 | ): 372 | """ 373 | Return, None 374 | Parameters 375 | ---------- 376 | in_channels, int, the channels of input 377 | out_channels, int, the channles expanded by the convolution 378 | kernel_size, int, the size of kernel needed. 379 | style_dim, int, dimensionality of attribute latent space. 380 | upsample, bool, decide if upsample the input 381 | blur_kernel, [int], the kernel used to blur input. 382 | demoulated, bool, decide applying demodulation 383 | Returns 384 | ------- 385 | None 386 | """ 387 | 388 | super().__init__() 389 | 390 | self.conv1 = ModulatedConv2d( 391 | in_channel, 392 | out_channel, 393 | kernel_size, 394 | style_dim, 395 | upsample=True, 396 | blur_kernel=blur_kernel, 397 | demodulate=demodulate, 398 | ) 399 | 400 | self.activate1 = FusedLeakyReLU(out_channel) 401 | 402 | self.conv2 = ModulatedConv2d( 403 | out_channel, 404 | out_channel, 405 | kernel_size, 406 | style_dim, 407 | upsample=False, 408 | blur_kernel=blur_kernel, 409 | demodulate=demodulate, 410 | ) 411 | 412 | self.activate2 = FusedLeakyReLU(out_channel) 413 | 414 | def forward(self, input, style): 415 | """ 416 | Return, the transformed x. 417 | Parameters 418 | ---------- 419 | x: pytorch tensor. latent code of appearance latent space. 420 | style: pytorch tensor, latent code of attribute editing latent space. 421 | Returns 422 | ------- 423 | x, pytorch tensor, the transformed x. 424 | """ 425 | out = self.conv1(input, style) 426 | out = self.activate1(out) 427 | out = self.conv2(out,style) 428 | out = self.activate2(out) 429 | return out 430 | 431 | 432 | class EqualConv1d(nn.Module): 433 | def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True): 434 | super().__init__() 435 | 436 | """ 437 | Return, None 438 | Parameters 439 | ---------- 440 | in_channels, int, the channels of input 441 | out_channels, int, the channles expanded by the convolution 442 | kernel_size, int, the size of kernel needed. 443 | stride: int, controls the cross correlation during convolution 444 | padding: int, the number of gride used to pad input. 445 | bias: bool, controls adding of learnable biase 446 | Returns 447 | ------- 448 | None 449 | """ 450 | 451 | self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size)) 452 | self.scale = 2 / math.sqrt(in_channel * out_channel * kernel_size) 453 | 454 | self.stride = stride 455 | self.padding = padding 456 | 457 | if bias: 458 | self.bias = nn.Parameter(torch.zeros(out_channel)) 459 | else: 460 | self.bias = None 461 | 462 | def forward(self,x): 463 | """ 464 | Return, the convolutioned x. 
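        Note (descriptive only): unlike EqualConv2d, this layer rescales its weight by
        self.scale = 2 / sqrt(in_channel * out_channel * kernel_size) before calling F.conv1d.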
465 | Parameters 466 | ---------- 467 | x: pytorch tensor, used for the input of convolution 468 | Returns 469 | ------- 470 | the convolutioned x 471 | """ 472 | x = F.conv1d(x, self.weight * self.scale,bias=self.bias, stride=self.stride, padding=self.padding) 473 | return x 474 | 475 | class ToRGB(nn.Module): 476 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): 477 | super().__init__() 478 | 479 | if upsample: 480 | self.upsample = Upsample(blur_kernel) 481 | 482 | self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) 483 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) 484 | 485 | def forward(self, input, style, skip=None): 486 | out = self.conv(input, style) 487 | out = out + self.bias 488 | 489 | if skip is not None: 490 | skip = self.upsample(skip) 491 | 492 | out = out + skip 493 | 494 | return out 495 | 496 | #Block for Attribute Mapping Network 497 | class Modify(nn.Module): 498 | def __init__(self, in_channel): 499 | super().__init__() 500 | self.model = nn.Sequential( 501 | EqualConv1d(in_channel, 64, 3,padding = 1, bias=False), 502 | nn.LeakyReLU(0.2, inplace = True), 503 | EqualConv1d(64, 64, 3,padding = 1, bias=False), 504 | nn.LeakyReLU(0.2, inplace = True), 505 | ) 506 | 507 | self.w = EqualConv1d(64, 64, 3,padding = 1, bias=False) 508 | self.h = EqualConv1d(64, 64, 3,padding = 1, bias=False) 509 | self.n = EqualConv1d(64, 64, 3, padding = 1, bias=False) 510 | 511 | self.skip = EqualConv1d(in_channel, 64, 1, bias=False) 512 | 513 | def forward(self,input): 514 | x = self.model(input) 515 | f = self.w(x) 516 | f = f / (torch.norm(f,p=2,dim = 1,keepdim= True) + 1e-8) 517 | x = self.n(f.bmm(f.permute(0,2,1)).bmm(self.h(x))) 518 | return x + self.skip(input) 519 | 520 | #Attribute Mapping Network 521 | class Embeding(nn.Module): 522 | def __init__(self, c_dim): 523 | super().__init__() 524 | self.directions = nn.Parameter(torch.zeros(1, c_dim, 512)) 525 | self.b1 = Modify(c_dim + 1) 526 | self.b2 = Modify(64) 527 | self.b3 = Modify(64) 528 | self.b4 = Modify(64) 529 | self.b5 = EqualConv1d(64, 1, 1, bias=False) 530 | 531 | def forward(self,x,a, reg = False): 532 | d = self.directions.repeat(a.size(0),1,1) 533 | is_reconstruct = ((a.sum(1, keepdim = True) != 0.0).float()).view(a.size(0),1,1) 534 | d = torch.cat((d * a.view(-1,a.size(1),1),x.view(x.size(0),1,512) * is_reconstruct),1) 535 | d = self.b1(d) 536 | d = self.b2(d) 537 | d = self.b3(d) 538 | d = self.b4(d) 539 | d = self.b5(d).view(-1,512) 540 | if reg: 541 | return d 542 | else: 543 | return x + d 544 | 545 | #encoder 546 | class Encoder(nn.Module): 547 | def __init__(self, in_channels=1, dim=64, n_downsample = 5, max_dim = 512, noise = False): 548 | super().__init__() 549 | 550 | pool_size = { 551 | 32 : 4, 552 | 64 : 3, 553 | 128 : 2, 554 | 256 : 2, 555 | 512 : 1, 556 | } 557 | 558 | self.vision = ConvLayer(in_channels,dim,1) 559 | 560 | conv_layers = [] 561 | linear_layers = [] 562 | # Downsampling 563 | dim_cur = dim 564 | dim_next = dim * 2 565 | for _ in range(n_downsample): 566 | conv_layers += [ 567 | nn.Sequential( 568 | ResBlock(dim_cur,dim_next), 569 | ResBlock(dim_next,dim_next,downsample= False) 570 | ) 571 | ] 572 | 573 | linear_layers += [nn.Sequential( 574 | nn.AdaptiveAvgPool2d(pool_size[dim_next]), 575 | nn.Flatten(), 576 | EqualLinear(dim_next * pool_size[dim_next] ** 2, 512, lr_mul = 0.01, activation="fused_lrelu"), 577 | *[EqualLinear(512, 512, lr_mul = 0.01, activation="fused_lrelu") for _ in range(3)] 578 | ) 579 | ] 580 | 581 | dim_cur 
= dim_next 582 | dim_next = min(max_dim,dim_next * 2) 583 | 584 | self.model = nn.ModuleList(conv_layers) 585 | self.linear = nn.ModuleList(linear_layers) 586 | self.norm = PixelNorm() 587 | extra_dimension = 100 if noise else 0 588 | self.final = nn.Sequential( 589 | EqualLinear(512 + extra_dimension, 512, lr_mul = 0.01, activation="fused_lrelu"), 590 | *[EqualLinear(512, 512, lr_mul = 0.01, activation="fused_lrelu") for _ in range(4)] 591 | ) 592 | 593 | def forward(self, x, noise = None): 594 | 595 | x = self.vision(x) 596 | style = 0 597 | 598 | for index in range(len(self.model)): 599 | x = self.model[index](x) 600 | style += self.linear[index](x) 601 | style = style / (index + 1) 602 | style = self.norm(style) 603 | if noise != None: 604 | noise = self.norm(noise) 605 | style = torch.cat((style,noise),1) 606 | style = self.final(style) 607 | return style 608 | 609 | 610 | #decoder 611 | class Generator(nn.Module): 612 | def __init__( 613 | self, 614 | c_dim, 615 | style_dim = 512, 616 | n_mlp = 8, 617 | channel_multiplier= 1, 618 | blur_kernel=[1, 3, 3, 1], 619 | lr_mlp=0.01, 620 | ): 621 | super().__init__() 622 | 623 | self.channels = { 624 | 4: 512, 625 | 8: 512, 626 | 16: 512, 627 | 32: 512, 628 | 64: 256 * channel_multiplier, 629 | 128: 128 * channel_multiplier, 630 | 256: 64 * channel_multiplier, 631 | 512: 32 * channel_multiplier, 632 | 1024: 16 * channel_multiplier, 633 | } 634 | 635 | self.input = ConstantInput(self.channels[4]) 636 | self.conv1 = ModulatedConv2d( 637 | 512, 638 | 512, 639 | 3, 640 | style_dim, 641 | upsample= False, 642 | blur_kernel=blur_kernel, 643 | demodulate=True, 644 | ) 645 | self.activate1 = FusedLeakyReLU(512) 646 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) 647 | 648 | self.convs = nn.ModuleList([ 649 | StyledConv(512,512,3,style_dim,blur_kernel), #4 - 8 650 | StyledConv(512,512,3,style_dim,blur_kernel), #8 - 16 651 | StyledConv(512,512,3,style_dim,blur_kernel), #16 - 32 652 | StyledConv(512,256 * channel_multiplier,3,style_dim,blur_kernel), #32 - 64 653 | StyledConv(256 * channel_multiplier, 128 * channel_multiplier,3,style_dim,blur_kernel), #64 - 128 654 | StyledConv(128 * channel_multiplier, 64 * channel_multiplier,3,style_dim,blur_kernel), #128 - 256 655 | ]) 656 | 657 | self.to_rgbs = nn.ModuleList([ 658 | ToRGB(512, style_dim), #8 659 | ToRGB(512, style_dim), #16 660 | ToRGB(512, style_dim), #32 661 | ToRGB(256 * channel_multiplier, style_dim), #64 662 | ToRGB(128 * channel_multiplier, style_dim), #128 663 | ToRGB(64 * channel_multiplier, style_dim), #256 664 | ]) 665 | 666 | def forward(self,style): 667 | x = self.input(style) 668 | x = self.conv1(x,style) 669 | x = self.activate1(x) 670 | skip = self.to_rgb1(x,style) 671 | 672 | for index in range(len(self.convs)): 673 | x = self.convs[index](x,style) 674 | skip = self.to_rgbs[index](x,style,skip) 675 | return skip 676 | 677 | 678 | #convolution layer with dowmsample and activation function 679 | class ConvLayer(nn.Sequential): 680 | def __init__( 681 | self, 682 | in_channel, 683 | out_channel, 684 | kernel_size, 685 | downsample=False, 686 | blur_kernel=[1, 3, 3, 1], 687 | bias=True, 688 | activate=True, 689 | ): 690 | layers = [] 691 | 692 | if downsample: 693 | factor = 2 694 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 695 | pad0 = (p + 1) // 2 696 | pad1 = p // 2 697 | 698 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 699 | 700 | stride = 2 701 | self.padding = 0 702 | 703 | else: 704 | stride = 1 705 | self.padding = kernel_size // 2 706 | 707 
| layers.append( 708 | EqualConv2d( 709 | in_channel, 710 | out_channel, 711 | kernel_size, 712 | padding=self.padding, 713 | stride=stride, 714 | bias=bias and not activate, 715 | ) 716 | ) 717 | 718 | if activate: 719 | layers.append(FusedLeakyReLU(out_channel, bias=bias)) 720 | 721 | super().__init__(*layers) 722 | 723 | #residual block 724 | class ResBlock(nn.Module): 725 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], downsample = True): 726 | super().__init__() 727 | 728 | self.conv1 = ConvLayer(in_channel, in_channel, 3) 729 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=downsample) 730 | 731 | self.skip = ConvLayer( 732 | in_channel, out_channel, 1, downsample=downsample, activate=False, bias=False 733 | ) 734 | 735 | def forward(self, input): 736 | out = self.conv1(input) 737 | out = self.conv2(out) 738 | 739 | skip = self.skip(input) 740 | out = (out + skip) / math.sqrt(2) 741 | 742 | return out 743 | 744 | #domain discriminator 745 | class GradReverse(Function): 746 | @staticmethod 747 | def forward(ctx, x, beta = 1.0): 748 | ctx.beta = beta 749 | return x.view_as(x) 750 | 751 | @staticmethod 752 | def backward(ctx, grad_output): 753 | grad_input = grad_output.neg() * ctx.beta 754 | return grad_input, None 755 | 756 | class Linear(nn.Module): 757 | def __init__(self, in_dim, out_dim): 758 | super().__init__() 759 | 760 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim)) 761 | normal_(self.weight, 0, 0.001) 762 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(0)) 763 | self.scale = (1 / math.sqrt(in_dim)) 764 | 765 | def forward(self, input): 766 | out = F.linear(input, self.weight * self.scale, bias=self.bias) 767 | return out 768 | 769 | class Domain_Discriminator(nn.Module): 770 | def __init__(self): 771 | super().__init__() 772 | self.feature = Linear(512, 512) 773 | self.relu = nn.ReLU(inplace = True) 774 | self.fc = Linear(512, 1) 775 | 776 | def forward(self,x): 777 | x = GradReverse.apply(x) 778 | x = self.feature(x) 779 | x = self.relu(x) 780 | x = self.fc(x) 781 | return x 782 | 783 | class Classifier(nn.Module): 784 | def __init__(self,c_dim): 785 | super().__init__() 786 | self.W = nn.Parameter(torch.randn(512, c_dim)) 787 | self.c_dim = c_dim 788 | nn.init.xavier_uniform_(self.W.data, gain=1) 789 | 790 | 791 | def forward(self,x, ortho = False): 792 | self.W_norm = self.W / self.W.norm(dim=0) 793 | if not ortho: 794 | return torch.matmul(x,self.W_norm) 795 | else: 796 | return torch.matmul(x,self.W_norm), nn.L1Loss()(self.W_norm.transpose(1,0).matmul(self.W_norm), torch.diag(torch.ones(self.c_dim,device = x.device))) 797 | 798 | def edit(self, x, a): 799 | self.W_norm = self.W / self.W.norm(dim=0) 800 | d = self.W_norm.view(1,512,-1) 801 | a = a.view(a.size(0),1,-1) 802 | return x + (d * a).sum(-1) 803 | 804 | 805 | #model discriminator 806 | class Discriminator(nn.Module): 807 | def __init__(self, in_channels, c_dim, model_type, channel_multiplier=1, blur_kernel=[1, 3, 3, 1]): 808 | super().__init__() 809 | 810 | self.convs = nn.Sequential( 811 | ConvLayer(in_channels, 64 * channel_multiplier, 1), #256 812 | ResBlock(64 * channel_multiplier, 128 * channel_multiplier), #256 - 128 813 | ResBlock(128 * channel_multiplier, 256 * channel_multiplier), #128 - 64 814 | ResBlock(256 * channel_multiplier, 512), #64 - 32 815 | ResBlock(512, 512), #32 - 16 816 | ResBlock(512, 512), #16 - 8 817 | ResBlock(512, 512) #8 - 4 818 | ) 819 | 820 | self.final_linear = nn.Sequential( 821 | EqualLinear(512 * 4 * 4, 512, 
activation="fused_lrelu"), 822 | EqualLinear(512, 1), 823 | ) 824 | 825 | if model_type == 1: 826 | self.W = nn.Sequential( 827 | EqualLinear(512 * 4 * 4, 512, activation="fused_lrelu"), 828 | EqualLinear(512, c_dim), 829 | ) 830 | 831 | self.model_type = model_type 832 | 833 | def forward(self, input): 834 | out = self.convs(input) 835 | batch, channel, height, width = out.shape 836 | out = out.view(batch, -1) 837 | if self.model_type == 0: 838 | return self.final_linear(out), (out * 0).detach() 839 | else: 840 | return self.final_linear(out), self.W(out) 841 | 842 | 843 | def requires_grad(model, flag=True): 844 | """ 845 | Return None 846 | Parameters 847 | ---------- 848 | model : pytorch model 849 | flag : bool, default true 850 | 851 | Returns 852 | ------- 853 | None 854 | 855 | set requires_grad flag for model 856 | 857 | """ 858 | 859 | for p in model.parameters(): 860 | p.requires_grad = flag 861 | 862 | 863 | #calculate generator loss 864 | def g_nonsaturating_loss(fake_pred): 865 | loss = F.softplus(-fake_pred).mean() 866 | 867 | return loss 868 | 869 | #VGG Perceptual loss 870 | class VGGPerceptualLoss(torch.nn.Module): 871 | def __init__(self): 872 | super().__init__() 873 | blocks = [] 874 | model = torchvision.models.vgg19(pretrained=True) 875 | blocks.append(model.features[:2].eval()) 876 | blocks.append(model.features[2:7].eval()) 877 | blocks.append(model.features[7:12].eval()) 878 | blocks.append(model.features[12:21].eval()) 879 | blocks.append(model.features[21:30].eval()) 880 | blocks = nn.ModuleList(blocks) 881 | self.blocks = torch.nn.ModuleList(blocks) 882 | self.weights = [1/32.0,1.0/16, 1.0/8, 1.0/4, 1.0] 883 | 884 | for p in self.parameters(): 885 | p.requires_grad = False 886 | 887 | def forward(self, input, target): 888 | if input.shape[1] != 3: 889 | input = input.repeat(1, 3, 1, 1) 890 | target = target.repeat(1, 3, 1, 1) 891 | loss = 0.0 892 | x = input 893 | y = target 894 | for i,block in enumerate(self.blocks): 895 | x = block(x) 896 | y = block(y) 897 | loss += torch.nn.functional.l1_loss(x, y) * self.weights[i] 898 | return loss 899 | 900 | #The function is used downsample and binarize the input 901 | def downsample(masks): 902 | masks = F.interpolate(masks,scale_factor= 1/2, mode="bilinear",align_corners=True,recompute_scale_factor=True) 903 | m = masks >= 0 #.5 904 | masks[m] = 1 905 | masks[~m] = 0 906 | return masks 907 | 908 | #calculte r1 loss 909 | def d_r1_loss(real_pred, real_img): 910 | grad_real, = autograd.grad( 911 | outputs=real_pred.sum(), inputs=real_img, create_graph=True 912 | ) 913 | grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean() 914 | 915 | return grad_penalty 916 | 917 | class Model(nn.Module): 918 | def __init__(self, args,c_dim, augment): 919 | super().__init__() 920 | self.args = args 921 | self.encoder_sketch = Encoder(1,128, 5) 922 | self.encoder_img = Encoder(3,64, 6) 923 | self.generator = Generator(c_dim) 924 | self.classifier = Classifier(c_dim) 925 | if args.model_type == 1: 926 | self.edit = Embeding(c_dim) 927 | self.img_discriminator = Discriminator(3,c_dim,args.model_type) 928 | self.domain_discriminator = Domain_Discriminator() 929 | self.vgg = VGGPerceptualLoss() 930 | self.augment = augment 931 | 932 | d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1) 933 | 934 | 935 | if args.model_type == 0: 936 | 937 | self.g_optim = optim.Adam( 938 | [{'params' : list(self.encoder_sketch.parameters()) + list(self.encoder_img.parameters()) + list(self.generator.parameters())}, 939 | 
{'params' : self.classifier.parameters(),"betas": (0.9,0.999), "weight_decay": 0.0005}, 940 | {'params' : list(self.domain_discriminator.parameters()),"betas": (0.9,0.999), "weight_decay": 0.0005} 941 | ], 942 | lr= args.lr, 943 | betas=(0, 0.99) 944 | ) 945 | 946 | self.d_optim = optim.Adam( 947 | self.img_discriminator.parameters(), 948 | lr=args.lr * d_reg_ratio, 949 | betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio), 950 | ) 951 | 952 | else: 953 | self.g_optim = optim.Adam( 954 | [{'params' : list(self.encoder_sketch.parameters()) + list(self.encoder_img.parameters()) + list(self.edit.parameters()) + list(self.generator.parameters())}, 955 | {'params' : list(self.domain_discriminator.parameters()),"betas": (0.9,0.999), "weight_decay": 0.0005} 956 | ], 957 | lr= args.lr, 958 | betas=(0, 0.99), 959 | ) 960 | 961 | self.d_optim = optim.Adam( 962 | [{'params' : self.img_discriminator.parameters()}, 963 | {'params' : self.classifier.parameters(),"betas": (0.9,0.999), "weight_decay": 0.0005}], 964 | lr=args.lr * d_reg_ratio, 965 | betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio) 966 | ) 967 | 968 | def forward(self, img = None,sketch = None,sampled_ratio = None, label = None, target_mask = None, domain_img = None, domain_sketch = None, ada_aug_p = None, noise = None,train_discriminator = False, d_regularize = False, train_generator = False, generate = False): 969 | augment = self.augment 970 | if train_discriminator or d_regularize: 971 | requires_grad(self.encoder_sketch, False) 972 | requires_grad(self.encoder_img, False) 973 | requires_grad(self.generator, False) 974 | requires_grad(self.domain_discriminator, False) 975 | requires_grad(self.img_discriminator, True) 976 | if self.args.model_type == 1: 977 | requires_grad(self.edit, False) 978 | requires_grad(self.classifier, False) 979 | else: 980 | requires_grad(self.classifier, True) 981 | 982 | 983 | else: 984 | requires_grad(self.encoder_sketch, True) 985 | requires_grad(self.encoder_img, True) 986 | requires_grad(self.generator, True) 987 | requires_grad(self.domain_discriminator, True) 988 | requires_grad(self.img_discriminator, False) 989 | if self.args.model_type == 1: 990 | requires_grad(self.edit, True) 991 | requires_grad(self.classifier, True) 992 | else: 993 | requires_grad(self.classifier, False) 994 | 995 | if train_discriminator: 996 | 997 | if self.args.model_type == 0: 998 | img_latent = self.encoder_img(img) 999 | fake_img = self.generator(img_latent) 1000 | 1001 | if self.args.augment: 1002 | real_img_aug, _ = augment(img, ada_aug_p) 1003 | fake_img, _ = augment(fake_img, ada_aug_p) 1004 | else: 1005 | real_img_aug = img 1006 | 1007 | fake_img_pred, _ = self.img_discriminator(fake_img) 1008 | real_img_pred, bce = self.img_discriminator(real_img_aug) 1009 | 1010 | return fake_img_pred, real_img_pred, bce 1011 | else: 1012 | img_latent = self.encoder_img(img) 1013 | img_latent_1 = self.edit(img_latent, sampled_ratio) 1014 | fake_img = self.generator(img_latent_1) 1015 | 1016 | bce = nn.MSELoss()(self.classifier(img_latent), label * 2 - 1) 1017 | 1018 | if self.args.augment: 1019 | real_img_aug, _ = augment(img, ada_aug_p) 1020 | fake_img, _ = augment(fake_img, ada_aug_p) 1021 | else: 1022 | real_img_aug = img 1023 | 1024 | fake_img_pred, _ = self.img_discriminator(fake_img) 1025 | real_img_pred, real_class = self.img_discriminator(real_img_aug) 1026 | 1027 | outer_bce = nn.BCEWithLogitsLoss()(real_class, label) 1028 | 1029 | return fake_img_pred, real_img_pred, bce + outer_bce * 0.0 1030 | 1031 | if d_regularize: 1032 | 
real_pred_img, _ = self.img_discriminator(img) 1033 | r1_loss = d_r1_loss(real_pred_img,img) 1034 | return r1_loss 1035 | 1036 | if train_generator: 1037 | img_latent = self.encoder_img(img) 1038 | sketch_latent = self.encoder_sketch(downsample(sketch)) 1039 | sketch_loss = nn.L1Loss()(sketch_latent, img_latent.detach()) 1040 | reconstruct_img = self.generator(img_latent) 1041 | vgg_loss = self.vgg(reconstruct_img,img) 1042 | reconstruct_loss = nn.L1Loss()(reconstruct_img,img) 1043 | domain_loss = nn.BCEWithLogitsLoss()(self.domain_discriminator(img_latent.detach()), domain_img) + \ 1044 | nn.BCEWithLogitsLoss()(self.domain_discriminator(sketch_latent), domain_sketch) 1045 | 1046 | if self.args.model_type == 0: 1047 | 1048 | bce,orthologoy = self.classifier(img_latent, True) 1049 | bce = nn.MSELoss()(bce, label * 2 - 1) 1050 | 1051 | if self.args.augment: 1052 | reconstruct_img, GC = augment(reconstruct_img, ada_aug_p) 1053 | 1054 | fake_pred_img, _ = self.img_discriminator(reconstruct_img) 1055 | g_loss_img = g_nonsaturating_loss(fake_pred_img) 1056 | 1057 | g_total = sketch_loss * 2.5 +\ 1058 | domain_loss * 0.1 +\ 1059 | vgg_loss * 2.5 +\ 1060 | reconstruct_loss * 2.5 +\ 1061 | g_loss_img +\ 1062 | bce * 0.5 +\ 1063 | orthologoy 1064 | return g_total 1065 | 1066 | else: 1067 | img_latent_1 = self.edit(img_latent, sampled_ratio) 1068 | sketch_latent_1 = self.edit(sketch_latent, sampled_ratio) 1069 | 1070 | fake_img = self.generator(img_latent_1) 1071 | reg = self.edit(img_latent, sampled_ratio * 0.0, reg = True).abs().mean() 1072 | 1073 | latent_reconstruct = (self.edit(self.edit(img_latent.detach(),sampled_ratio), -sampled_ratio) - img_latent.detach()).abs().mean() 1074 | base_score = self.classifier(img_latent).detach() + sampled_ratio 1075 | edit_loss = nn.MSELoss()(self.classifier(img_latent_1), base_score) 1076 | domain_loss = domain_loss + \ 1077 | nn.BCEWithLogitsLoss()(self.domain_discriminator(img_latent_1.detach()), domain_img) + \ 1078 | nn.BCEWithLogitsLoss()(self.domain_discriminator(sketch_latent_1),domain_sketch) 1079 | 1080 | if self.args.augment: 1081 | fake_img, _ = augment(fake_img, ada_aug_p) 1082 | 1083 | fake_pred_img, fake_class = self.img_discriminator(fake_img) 1084 | g_loss_img = g_nonsaturating_loss(fake_pred_img) 1085 | 1086 | outer_edit = nn.BCEWithLogitsLoss()(fake_class,target_mask * 1.0) 1087 | 1088 | g_total = vgg_loss * 2.5 +\ 1089 | reg * 0.1 +\ 1090 | (edit_loss + outer_edit * 0.0) * 1.0 +\ 1091 | reconstruct_loss * 1.0 +\ 1092 | latent_reconstruct * 1.0 +\ 1093 | sketch_loss * 2.5 +\ 1094 | domain_loss * 0.05 +\ 1095 | g_loss_img 1096 | 1097 | return g_total 1098 | 1099 | if generate: 1100 | img = self.encoder_img(img) 1101 | sketch = self.encoder_sketch(downsample(sketch)) 1102 | if self.args.model_type == 0: 1103 | sketch = self.classifier.edit(sketch,sampled_ratio) 1104 | else: 1105 | sketch = self.edit(sketch,sampled_ratio) 1106 | img = self.generator(img) 1107 | sketch = self.generator(sketch) 1108 | return img,sketch 1109 | 1110 | 1111 | 1112 | 1113 | 1114 | 1115 | --------------------------------------------------------------------------------
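A minimal inference sketch (not part of the repository): it assumes a checkpoint written by train.py, the same argparse namespace `args` used for training, and inputs preprocessed as in train.py (scaled to [-1, 1]); the checkpoint path and the tensors img, sketch and ratio below are placeholders.

    import torch
    from model import Model as S2FGAN

    ckpt = torch.load("checkpoint/010000.pt")                  # hypothetical checkpoint path
    model_ema = S2FGAN(args, c_dim=12, augment=None).cuda().eval()
    model_ema.load_state_dict(ckpt["model_ema"])

    with torch.no_grad():
        # img: (1, 3, 256, 256), sketch: (1, 1, 256, 256), both in [-1, 1]
        # ratio: (1, 12) attribute-shift vector, e.g. +2.0 on "Smiling" and 0 elsewhere
        # recon is generated from the photo; edited is generated from the sketch with the shift applied
        recon, edited = model_ema(img, sketch=sketch, sampled_ratio=ratio, generate=True)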