├── .gitignore ├── .vscode └── settings.json ├── README.md ├── benchmark ├── DiffJPEG │ ├── DiffJPEG.py │ ├── LICENSE │ ├── README.md │ ├── diffjpeg.png │ ├── modules │ │ ├── __init__.py │ │ ├── compression.py │ │ └── decompression.py │ ├── requirements.txt │ └── utils.py ├── FFHQ-X_crops128_ncrops1000.npz ├── __init__.py ├── config.py ├── degradations.py ├── eval.py ├── prelude.py └── tasks.py ├── cli.py ├── datasets └── samples │ ├── sample_1.png │ └── sample_2.png ├── pretrained_networks └── put-pretrained-models-here ├── pytorch_fid ├── LICENSE ├── __init__.py ├── __main__.py ├── fid_score.py └── inception.py ├── robust_unsupervised ├── __init__.py ├── io_utils.py ├── loss_function.py ├── optimizer.py ├── prelude.py └── variables.py ├── run.py └── stylegan2_ada ├── LICENSE ├── dnnlib ├── __init__.py ├── legacy.py └── util.py ├── torch_utils ├── __init__.py ├── custom_ops.py ├── misc.py ├── ops │ ├── __init__.py │ ├── bias_act.cpp │ ├── bias_act.cu │ ├── bias_act.h │ ├── bias_act.py │ ├── conv2d_gradfix.py │ ├── conv2d_resample.py │ ├── fma.py │ ├── grid_sample_gradfix.py │ ├── upfirdn2d.cpp │ ├── upfirdn2d.cu │ ├── upfirdn2d.h │ └── upfirdn2d.py ├── persistence.py └── training_stats.py └── training ├── __init__.py └── networks.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | *.pkl 3 | **/__pycache__/ 4 | 5 | /.venv/ 6 | /out/ 7 | 8 | /.envrc 9 | /TODO.md 10 | 11 | /scripts/ 12 | /datasets/FFHQ* 13 | 14 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.analysis.extraPaths": [ 3 | "stylegan2_ada" 4 | ] 5 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Robust Unsupervised StyleGAN Image Restoration 2 | ### 
[[Arxiv]](https://arxiv.org/abs/2302.06733) [[Website]](https://lvsn.github.io/RobustUnsupervised/) 3 | 4 | Code for the paper `Robust Unsupervised StyleGAN Image Restoration` presented at CVPR 2023. 5 | 6 | ## Installation 7 | 8 | 1) First install the same environment as https://github.com/NVlabs/stylegan2-ada-pytorch.git. It is not essential for the custom cuda kernels to compile correctly; they just make things run ~30% faster. 9 | 10 | 2) Run `pip install tyro`. For running the evaluation you will also need to `pip install torchmetrics git+https://github.com/jwblangley/pytorch-fid.git`. 11 | 12 | 3) Download the pretrained StyleGAN model: 13 | ```bash 14 | wget https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl -O pretrained/ffhq.pkl 15 | ``` 16 | ## Restoring images 17 | 18 | To run the tasks presented in the paper, use: 19 | 20 | ```bash 21 | python run.py --dataset_path datasets/samples 22 | ``` 23 | 24 | Some sample images have already been provided in `datasets/samples`. 25 | 26 | ## Other datasets 27 | First, download a pretrained StyleGAN2 generator for your dataset (.pkl), and pass its path to the `--pkl_path` option. 28 | If the resolution of your data is different from 1024 you also need to set it using the `--resolution` option. 29 | This resolution does not need to match the pretrained generator's resolution; for best results pick a high resolution generator even if your images are smaller. 30 | 31 | Finally, on datasets other than faces you may need to scale all learning rates up or down by a constant amount to compensate for the different scale of the latent space. For this you can use the CLI option `--global_lr_scale`. 
32 | 33 | ## Restoring your own degradations 34 | Use the option `--tasks custom`, then find the following code in `run.py` and update it with your degradation function: 35 | 36 | ```python 37 | class YourDegradation: 38 | def degrade_ground_truth(self, x): 39 | """ 40 | The true degradation you are attempting to invert. 41 | This assumes you are testing against clean ground truth images. 42 | """ 43 | raise NotImplementedError 44 | 45 | def degrade_prediction(self, x): 46 | """ 47 | Differentiable approximation to the degradation in question. 48 | Can be identical to the true degradation if it is invertible. 49 | """ 50 | raise NotImplementedError 51 | ``` 52 | If you do not have access to ground truth images, you can open degraded images directly and make `degrade_ground_truth` an identity function. 53 | 54 | ## Evaluation 55 | To run the full evaluation, use: 56 | ``` 57 | python -m benchmark.eval 58 | ``` 59 | Due to random variability the numbers may not match the paper exactly, but you should expect scores to be equal or better on average. For instance: 60 | ``` 61 | XL Upsampling: 21.5 (this repo) vs. 21.3 (paper) 62 | XL Denoising: 17.8 (this repo) vs. 17.9 (paper) 63 | XL Deartifacting: 16.7 (this repo) vs. 18.7 (paper) 64 | XL Inpainting: 14.0 (this repo) vs. 15.0 (paper) 65 | ``` 66 | This codebase embeds the FID code from https://github.com/mseitzer/pytorch-fid; please consider citing them. 
class DiffJPEG(nn.Module):
    """JPEG compression followed by decompression, as a single layer.

    The input is pushed through ``compress_jpeg`` and the resulting
    (y, cb, cr) coefficient tensors are fed straight back into
    ``decompress_jpeg``.
    """

    def __init__(self, k, quantization_table, differentiable=True):
        """Build the compression / decompression pair.

        Inputs:
            k: value handed to ``diff_round`` to build the rounding
                function; only used when ``differentiable`` is true.
            quantization_table: quantization tables, forwarded unchanged
                to both ``compress_jpeg`` and ``decompress_jpeg``.
            differentiable(bool): If true uses the custom differentiable
                rounding function, if false uses standard ``torch.round``.
        """
        super().__init__()
        if differentiable:
            # One rounding callable per channel group, luma first.
            roundings = {
                "rounding_y": diff_round(k),
                "rounding_c": diff_round(k),
            }
        else:
            roundings = {
                "rounding_y": torch.round,
                "rounding_c": torch.round,
            }
        self.compress = compress_jpeg(
            quantization_table=quantization_table, **roundings
        )
        self.decompress = decompress_jpeg(quantization_table=quantization_table)

    def parameters(self, recurse=False):
        """Report no parameters, whatever ``recurse`` is.

        The sub-modules hold their constant tensors as ``nn.Parameter``;
        returning an empty list keeps them away from anything that collects
        parameters through this method (e.g. an optimizer).
        """
        return []

    def forward(self, x):
        """Compress ``x`` and immediately decompress the result."""
        return self.decompress(*self.compress(x))
merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /benchmark/DiffJPEG/README.md: -------------------------------------------------------------------------------- 1 | # DiffJPEG: A PyTorch implementation 2 | 3 | This is a pytorch implementation of differentiable jpeg compression algorithm. This work is based on the discussion in this [paper](https://machine-learning-and-security.github.io/papers/mlsec17_paper_54.pdf). The work relies heavily on the tensorflow implementation in this [repository](https://github.com/rshin/differentiable-jpeg) 4 | 5 | ## Requirements 6 | - Pytorch 1.0.0 7 | - numpy 1.15.4 8 | 9 | ## Use 10 | 11 | DiffJPEG functions as a standard pytorch module/layer. To use, first import the layer and then initialize with the desired parameters: 12 | - differentaible(bool): If true uses custom differentiable rounding function, if false uses standrard torch.round 13 | - quality(float): Quality factor for jpeg compression scheme. 
class rgb_to_ycbcr_jpeg(nn.Module):
    """Converts RGB image to YCbCr
    Input:
        image(tensor): batch x 3 x height x width
    Output:
        result(tensor): batch x height x width x 3
    """

    def __init__(self):
        super(rgb_to_ycbcr_jpeg, self).__init__()
        # Rows hold the Y, Cb and Cr weights. The matrix is transposed so a
        # channels-last (r, g, b) pixel can be right-multiplied by it.
        matrix = np.array(
            [
                [0.299, 0.587, 0.114],
                [-0.168736, -0.331264, 0.5],
                [0.5, -0.418688, -0.081312],
            ],
            dtype=np.float32,
        ).T
        # Offset added after the matrix product: 0 for Y, 128 for Cb and Cr.
        self.shift = nn.Parameter(torch.tensor([0.0, 128.0, 128.0]))
        self.matrix = nn.Parameter(torch.from_numpy(matrix))

    def forward(self, image):
        # Channels last, so that tensordot contracts over the RGB axis.
        image = image.permute(0, 2, 3, 1)
        # The product already has shape batch x height x width x 3. The
        # original followed it with `result.view(image.shape)` whose return
        # value was thrown away, so that statement has been dropped.
        return torch.tensordot(image.float(), self.matrix, dims=1) + self.shift
class y_quantize(nn.Module):
    """Quantization step for the Y channel.

    Input:
        image(tensor): coefficient blocks whose last two dimensions are 8 x 8
        rounding(function): rounding function applied after the division
        quantization_table: indexable; entry 0 holds the 64 table values
    Output:
        image(tensor): same shape as the input
    """

    def __init__(self, rounding, quantization_table):
        super().__init__()
        # Entry 0 is laid out as an 8 x 8 grid and then transposed.
        flat_table = torch.tensor(quantization_table[0])
        self.y_table = flat_table.reshape(8, 8).t()
        self.rounding = rounding

    def forward(self, image):
        # The table is a plain tensor attribute, so it is moved to the
        # input's device on every call.
        table = self.y_table.to(image.device)
        scaled = image.float() / table
        return self.rounding(scaled)
class y_dequantize(nn.Module):
    """Undo the quantization of the Y channel.

    Inputs:
        image(tensor): quantized blocks whose last two dimensions are 8 x 8
        quantization_table: indexable; entry 0 holds the 64 table values
    Outputs:
        image(tensor): same shape as the input
    """

    def __init__(self, quantization_table):
        super().__init__()
        # Same layout as the quantization side: 8 x 8 grid, then transposed.
        flat_table = torch.tensor(quantization_table[0])
        self.y_table = flat_table.reshape(8, 8).t()

    def forward(self, image):
        # Multiply each coefficient by its table entry, with the table moved
        # to the input's device first.
        table = self.y_table.to(image.device)
        return image * table
class idct_8x8(nn.Module):
    """Inverse discrete Cosine Transformation
    Input:
        dcp(tensor): coefficient blocks, last two dimensions 8 x 8
    Output:
        image(tensor): same shape as the input
    """

    def __init__(self):
        super(idct_8x8, self).__init__()
        # Per-coefficient weights: 1/sqrt(2) for index 0, 1 otherwise,
        # combined for both axes with an outer product.
        alpha = np.array([1.0 / np.sqrt(2)] + [1] * 7)
        self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
        # tensor[x, y, u, v] is the cosine basis value linking coefficient
        # (x, y) to output position (u, v).
        tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
        for x, y, u, v in itertools.product(range(8), repeat=4):
            tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos(
                (2 * v + 1) * y * np.pi / 16
            )
        self.tensor = nn.Parameter(torch.from_numpy(tensor).float())

    def forward(self, image):
        image = image * self.alpha
        # Contracting the two 8-sized block axes against the basis leaves the
        # leading dimensions untouched and yields 8 x 8 blocks again, so the
        # result already has the input's shape. The original followed this
        # with `result.reshape(image.shape)` whose return value was thrown
        # away, so that statement has been dropped. The +128 is the constant
        # offset added to every output value.
        return 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
class ycbcr_to_rgb_jpeg(nn.Module):
    """Converts YCbCr image to RGB JPEG
    Input:
        image(tensor): batch x height x width x 3
    Output:
        result(tensor): batch x 3 x height x width
    """

    def __init__(self):
        super(ycbcr_to_rgb_jpeg, self).__init__()

        # Rows hold the weights producing R, G and B from (Y, Cb, Cr). The
        # matrix is transposed so a channels-last pixel can be
        # right-multiplied by it.
        matrix = np.array(
            [[1.0, 0.0, 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]],
            dtype=np.float32,
        ).T
        # Added before the matrix product: subtracts 128 from Cb and Cr.
        self.shift = nn.Parameter(torch.tensor([0, -128.0, -128.0]))
        self.matrix = nn.Parameter(torch.from_numpy(matrix))

    def forward(self, image):
        # The product keeps the batch x height x width x 3 layout. The
        # original followed it with `result.reshape(image.shape)` whose
        # return value was thrown away; that statement and a commented-out
        # line have been removed.
        result = torch.tensordot(image + self.shift, self.matrix, dims=1)
        # Back to channels-first for the caller.
        return result.permute(0, 3, 1, 2)
| components = {"y": y, "cb": cb, "cr": cr} 171 | full_height, full_width = y.shape[-2] * math.isqrt(y.shape[1]), y.shape[-1] * math.isqrt(y.shape[1]) 172 | for k in components.keys(): 173 | if k in ("cb", "cr"): 174 | comp = self.c_dequantize(components[k]) 175 | height = full_height // 2 176 | width = full_width // 2 177 | else: 178 | comp = self.y_dequantize(components[k]) 179 | height = full_height 180 | width = full_width 181 | comp = self.idct(comp) 182 | components[k] = self.merging(comp, height, width) 183 | 184 | image = self.chroma(components["y"], components["cb"], components["cr"]) 185 | image = self.colors(image) 186 | 187 | if "KEEP_DIFF_JPEG_CLAMP" in os.environ: 188 | image = torch.min( 189 | 255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image) 190 | ) 191 | return image / 255 192 | -------------------------------------------------------------------------------- /benchmark/DiffJPEG/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.15.4 2 | torch==1.0.0 3 | -------------------------------------------------------------------------------- /benchmark/DiffJPEG/utils.py: -------------------------------------------------------------------------------- 1 | # Standard libraries 2 | import numpy as np 3 | 4 | # PyTorch 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | # static const unsigned int std_luminance_quant_tbl[DCTSIZE2] = { 10 | # 16, 11, 10, 16, 24, 40, 51, 61, 11 | # 12, 12, 14, 19, 26, 58, 60, 55, 12 | # 14, 13, 16, 24, 40, 57, 69, 56, 13 | # 14, 17, 22, 29, 51, 87, 80, 62, 14 | # 18, 22, 37, 56, 68, 109, 103, 77, 15 | # 24, 35, 55, 64, 81, 104, 113, 92, 16 | # 49, 64, 78, 87, 103, 121, 120, 101, 17 | # 72, 92, 95, 98, 112, 100, 103, 99 18 | # }; 19 | 20 | y_table = np.array( 21 | [ 22 | [16, 11, 10, 16, 24, 40, 51, 61], 23 | [12, 12, 14, 19, 26, 58, 60, 55], 24 | [14, 13, 16, 24, 40, 57, 69, 56], 25 | [14, 17, 22, 29, 51, 87, 80, 62], 26 | [18, 22, 37, 56, 68, 
class SurrogateDiffRound(torch.autograd.Function):
    """Exact rounding in the forward pass, surrogate gradient in the backward.

    The forward pass returns ``torch.round(x)``. The backward pass
    differentiates a blend, weighted by ``k``, between ``x`` itself and the
    cubic approximation ``round(x) + (x - round(x)) ** 3``.
    """

    @staticmethod
    def forward(ctx, x, k):
        # Both tensors are needed again to build the surrogate in backward.
        ctx.save_for_backward(x, k)
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_y):
        x, k = ctx.saved_tensors
        weight = k.to(x.device)
        # A graph has to be recorded here so the surrogate can be
        # differentiated with respect to x.
        with torch.enable_grad():
            nearest = torch.round(x)
            cubic = nearest + (x - nearest) ** 3
            # lerp moves from x (weight 0) towards the cubic term (weight 1).
            surrogate = x.lerp(cubic, weight)
            (grad_x,) = torch.autograd.grad(surrogate, x, grad_y)
        # No gradient is produced for k.
        return grad_x, None
-------------------------------------------------------------------------------- /benchmark/FFHQ-X_crops128_ncrops1000.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yohan-pg/robust-unsupervised/afa536855abb253199898cb915ecb9c8e51dbda9/benchmark/FFHQ-X_crops128_ncrops1000.npz -------------------------------------------------------------------------------- /benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | from .tasks import * 2 | -------------------------------------------------------------------------------- /benchmark/config.py: -------------------------------------------------------------------------------- 1 | # You do not need to edit this file 2 | 3 | resolution: int = 512 4 | batch_size: int = 1 5 | -------------------------------------------------------------------------------- /benchmark/degradations.py: -------------------------------------------------------------------------------- 1 | import benchmark.config as config 2 | 3 | from .prelude import * 4 | 5 | sys.path.append(os.path.dirname(__file__) + "/" + "DiffJPEG") 6 | from benchmark.DiffJPEG.DiffJPEG import DiffJPEG 7 | 8 | import torchvision.transforms.functional as TF 9 | from PIL import JpegImagePlugin 10 | import tempfile 11 | import random 12 | import cv2 13 | import os 14 | 15 | 16 | TMP_SAVE_FILEPATH = tempfile.mkstemp()[1] 17 | 18 | 19 | class Degradation(nn.Module): 20 | seed = 2022 21 | mask = None 22 | 23 | def __init__(self): 24 | super().__init__() 25 | self.seed += 1 26 | 27 | def _true_degradation(self, ground_truth): 28 | """ 29 | Applies the true degradation, which may be non-differentiable. 30 | """ 31 | raise NotImplementedError 32 | 33 | def degrade_prediction(self, pred): 34 | """ 35 | Applies the differentiable approximation to the given degradation. 
def cycle_to_file(x, save_path: str):
    """
    Writes an image to ``save_path`` and immediately reads it back.
    The round trip makes quantization and clamping really happen, so the
    differentiable approximation cannot rely on false assumptions.
    (For instance, clamping must be accounted for when adding strong noise.)

    Only a batch of exactly one image is accepted. The tensor that was read
    back is returned with the batch dimension restored, on ``x``'s device.
    """
    assert x.shape[0] == 1  # batching is not supported yet
    # Values are clamped to [0, 1] before the image is written.
    pil_image = TF.to_pil_image(x.squeeze(0).clamp(0, 1))
    pil_image.save(save_path)
    reloaded = PIL.Image.open(save_path)
    return TF.to_tensor(reloaded).unsqueeze(0).to(x.device)
class AddNoise(Degradation):
    """Poisson noise followed by a random per-pixel dropout mask.

    ``noise_amount`` is a ``(num_photons, bernoulli_p)`` pair:
    ``num_photons`` scales the Poisson noise (skipped when it is not
    positive) and ``bernoulli_p`` is the threshold below which a uniform
    draw zeroes a pixel across all channels.
    """

    # Steepness of the sigmoid whose derivative replaces the clamp's gradient.
    k = 2.0
    # Added under the square root so sigma stays well-defined at zero signal.
    eps = 1e-3

    def __init__(
        self,
        # Quoted so no extra typing import is needed; the value is unpacked
        # as a pair below, so the previous `float` annotation was wrong.
        noise_amount: "Tuple[float, float]",
    ):
        super().__init__()
        self.noise_amount = noise_amount
        self.clamp = True
        # One more bump on top of the increment done in Degradation.__init__.
        self.seed += 1

    def degrade_prediction(self, x):
        """Differentiable approximation of the noise, applied to a prediction."""
        # pre-clamp the generator output, to match the fact that the (artificial) noise is
        # added to ground truth images with pixel values in [0, 1]
        # if working with real sources of noise this can be omitted
        x = self.differentiable_clamp(x)

        num_photons, bernoulli_p = self.noise_amount

        # Approximate poisson noise with a gaussian
        if num_photons > 0:
            # A single (1, 3, H, W) noise map, broadcast over the batch.
            noise = torch.randn(1, 3, x.shape[2], x.shape[3], device=x.device)
            lambd = x * num_photons
            mu = lambd - 1 / 2
            sigma = (lambd + self.eps).sqrt()
            y = (mu + sigma * noise) / num_photons
        else:
            y = x

        # Bernoulli noise: a one-channel mask, broadcast over the channels
        y = y * (torch.rand_like(y)[:, 0:1] > bernoulli_p).float()

        return self.differentiable_clamp(y)

    @torch.no_grad()
    def _true_degradation(self, x):
        """True (sampled, non-differentiable) noise, clamped to [0, 1]."""
        num_photons, bernoulli_p = self.noise_amount

        # Add poisson noise
        if num_photons > 0:
            y = torch.poisson(x * num_photons) / num_photons
        else:
            y = x

        # Bernoulli noise: a one-channel mask, broadcast over the channels
        y = y * (torch.rand_like(y)[:, 0:1] > bernoulli_p).float()

        return y.clamp(0.0, 1.0)

    class _ClampWithSurrogateGradient(torch.autograd.Function):
        """Hard clamp to [0, 1] forward; derivative of a sigmoid backward."""

        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x.clamp(0.0, 1.0)

        @staticmethod
        def backward(ctx, grad_y):
            (x,) = ctx.saved_tensors
            # A graph has to be recorded here so the sigmoid surrogate can be
            # differentiated with respect to x.
            with torch.enable_grad():
                surrogate = torch.sigmoid(AddNoise.k * (x - 0.5))
                # forward takes a single input, so exactly one gradient is
                # returned (the original appended a superfluous None).
                return torch.autograd.grad(surrogate, x, grad_y)[0]

    differentiable_clamp = _ClampWithSurrogateGradient.apply
231 | self.mask = self._generate_mask() 232 | torch.seed() 233 | 234 | def _generate_mask(self): 235 | image_height = config.resolution * 4 236 | image_width = config.resolution * 4 237 | brush_width = int(config.resolution * 0.08) * 4 238 | 239 | mask = np.zeros((image_height, image_width)) 240 | 241 | def sample(): 242 | w = image_width - 1 243 | h = image_height - 1 244 | return random.choice( 245 | [ 246 | random.randint(0, w // 3), 247 | random.randint(2 * w // 3, w), 248 | ] 249 | ), random.choice( 250 | [ 251 | random.randint(0, h // 3), 252 | random.randint(2 * h // 3, h), 253 | ] 254 | ) 255 | 256 | for _ in range(self.num_strokes): 257 | start_x, start_y = sample() 258 | end_x, end_y = sample() 259 | mask = cv2.line( 260 | mask, 261 | (start_x, start_y), 262 | (end_x, end_y), 263 | color=1, 264 | thickness=brush_width, 265 | ) 266 | mask = cv2.circle(mask, (start_x, start_y), int(brush_width / 2), 1) 267 | 268 | return ( 269 | torch.from_numpy(1.0 - cv2.pyrDown(cv2.pyrDown(mask))) 270 | .float() 271 | .cuda()[None, None] 272 | ) 273 | 274 | def _true_degradation(self, x): 275 | return x * F.interpolate( 276 | self.mask, x.shape[-1], mode="bicubic", align_corners=False 277 | ) 278 | 279 | def degrade_prediction(self, x): 280 | return self._true_degradation(x) 281 | 282 | 283 | class IdentityDegradation(Degradation): 284 | def __init__(self, *args): 285 | super().__init__() 286 | 287 | def degrade_prediction(self, x): 288 | return x 289 | 290 | def _true_degradation(self, x): 291 | return x 292 | 293 | 294 | class ComposedDegradation(Degradation): 295 | def __init__( 296 | self, 297 | degradations: List[Degradation], 298 | ): 299 | super().__init__() 300 | self.degradations = nn.ModuleList(degradations) 301 | 302 | @property 303 | def mask(self): 304 | return self.degradations[-1].mask 305 | 306 | def parameters(self, recurse=False): 307 | return sum([list(deg.parameters()) for deg in self.degradations], []) 308 | 309 | def degrade_prediction(self, x): 310 
class ResizePrediction(Degradation):
    """Resize an image batch to a fixed size using area interpolation.

    This is a hack used to match the prediction resolution to the target
    resolutions.
    """

    def __init__(self, size: int):
        super().__init__()
        self.size = size

    def _true_degradation(self, x):
        """Return `x` resized to `self.size` with `mode="area"`."""
        resized = F.interpolate(x, size=self.size, mode="area")
        return resized

    def degrade_prediction(self, x):
        """Predictions are resized exactly like ground-truth images."""
        return self._true_degradation(x)
_DEFAULT_DISTANCE_METRICS = None


def _default_distance_metrics():
    """Build the default PSNR and LPIPS metrics on first use and reuse them afterwards."""
    global _DEFAULT_DISTANCE_METRICS
    if _DEFAULT_DISTANCE_METRICS is None:
        _DEFAULT_DISTANCE_METRICS = [
            torchmetrics.PeakSignalNoiseRatio(data_range=2.0).cuda(),
            lpips.LearnedPerceptualImagePatchSimilarity(net_type="vgg").cuda(),
        ]
    return _DEFAULT_DISTANCE_METRICS


def eval_experiment(
    expr_path: str,
    suffixes: List[str],
    distance_metrics=None,
    compute_fid: bool = True,
    compute_distance_metrics: bool = True,
):
    """Evaluate one experiment directory and write the scores next to it as JSON.

    For every suffix this optionally (a) crops the `pred{suffix}.png` images,
    runs `python -m pytorch_fid` against the precomputed reference statistics
    and stores the FID, and (b) averages each distance metric over the images
    found under `{expr_path}/inversions`, once between the degraded prediction
    and the target and once between the prediction and the ground truth.

    Args:
        expr_path: Root directory of the experiment.
        suffixes: Prediction file suffixes to evaluate (e.g. "_W++").
        distance_metrics: Metric objects called as `metric(a, b)` on CUDA
            tensors. Defaults to PSNR and LPIPS, which are now created lazily
            on first use instead of at import time (the previous default
            argument allocated CUDA models as soon as the module was loaded).
        compute_fid: Whether to compute and store the FID.
        compute_distance_metrics: Whether to compute and store the distance
            metrics.

    Raises:
        subprocess.CalledProcessError: If the pytorch_fid subprocess fails.
        AssertionError: From `globr`/`replace` when an expected file is missing.
    """
    if distance_metrics is None:
        distance_metrics = _default_distance_metrics()

    dry_run = "DRY_RUN" in os.environ
    dry_run_label = ".dry_run" if dry_run else ""

    def imopen(path):
        # uint8 image -> float batch of size 1 scaled to [-1, 1]
        return (read_image(path).unsqueeze(0).float() / 255.0) * 2.0 - 1.0

    def dump_json(value, path):
        # Close the file deterministically instead of relying on garbage collection.
        with open(path, "w") as f:
            json.dump(value, f)

    def mean_scores(scores_by_name):
        return {
            name: torch.tensor(scores).mean().item()
            for name, scores in scores_by_name.items()
        }

    for suffix in suffixes:
        file_suffix = suffix.replace("/", "_")

        if compute_fid:
            crops_path = f"{expr_path}/crops{CROP_RES_LABEL}{suffix}"
            make_crops(
                globr(f"{expr_path}/**/pred{suffix}.png"),
                crops_path,
                10 if dry_run else CROP_NUM,
            )
            # An argument list (rather than a string split on spaces) keeps
            # paths containing spaces intact; sys.executable runs the same
            # interpreter as this process.
            result = subprocess.check_output(
                [
                    sys.executable,
                    "-m",
                    "pytorch_fid",
                    "benchmark/FFHQ-X_crops128_ncrops1000.npz",
                    crops_path,
                ]
            )
            fid_score = float(result.decode("utf8").strip().replace("FID: ", ""))
            dump_json(
                fid_score,
                f"{expr_path}/fid{file_suffix}{dry_run_label}{CROP_RES}{CROP_NUM_LABEL}.json",
            )

        if compute_distance_metrics:
            degraded_scores = {accronym(metric): [] for metric in distance_metrics}
            ground_truth_scores = {accronym(metric): [] for metric in distance_metrics}

            for im_path in globr(f"{expr_path}/inversions/**/pred{suffix}.png"):
                pred = imopen(im_path)
                degraded_pred = imopen(
                    replace(im_path, f"pred{suffix}", f"degraded_pred{suffix}")
                )
                target = imopen(replace(im_path, f"pred{suffix}", "target"))
                ground_truth = imopen(replace(im_path, f"pred{suffix}", "ground_truth"))

                for metric in distance_metrics:
                    degraded_scores[accronym(metric)].append(
                        metric(degraded_pred.cuda(), target.cuda()).item()
                    )
                    ground_truth_scores[accronym(metric)].append(
                        metric(pred.cuda(), ground_truth.cuda()).item()
                    )

                if dry_run:
                    # A dry run only checks that the pipeline works on one image.
                    break

            dump_json(
                mean_scores(degraded_scores),
                f"{expr_path}/degraded_scores{file_suffix}{dry_run_label}.json",
            )
            gtscores = mean_scores(ground_truth_scores)
            dump_json(
                gtscores,
                f"{expr_path}/ground_truth_scores{file_suffix}{dry_run_label}.json",
            )
            # Only present when an LPIPS metric is among distance_metrics.
            if "LPIPS" in gtscores:
                print(gtscores["LPIPS"])
@dataclass
class Task:
    """A named restoration task: a degradation constructor paired with its argument."""

    # Task name, e.g. one of `task_names` or a composed-task acronym such as "UNAP".
    name: str
    # Grouping label, e.g. "single_tasks" or "composed_tasks".
    category: str
    # NOTE(review): annotated as str, but the composed tasks built in this
    # file pass an int here.
    level: str
    # Callable invoked with `arg` to build the degradation.
    # NOTE(review): composed tasks pass the lambda returned by init_composed()
    # rather than a Degradation subclass.
    constructor: Type[Degradation]
    # The single argument forwarded to `constructor`.
    arg: Any

    def init_degradation(self):
        """Build this task's degradation, preceded by a resize to `config.resolution`."""
        return ComposedDegradation(
            [ResizePrediction(config.resolution), self.constructor(self.arg)]
        )
@dataclass
class Config:
    # Command-line configuration; `parse_config` hands this class to
    # `tyro.cli`. The bare string after each field is kept verbatim.
    # NOTE(review): those strings presumably serve as the per-option help
    # text shown by tyro - confirm before rewording them.

    # Plain string literal: the previous f-string had no placeholders.
    name: str = "restored_samples"
    "A name used to group log files."

    pkl_path: str = "pretrained_networks/ffhq.pkl"
    "The location of the pretrained StyleGAN."

    dataset_path: str = "datasets/FFHQ-X"
    "The location of the images to process."

    resolution: int = 1024
    "The resolution of your images. Images which are smaller or larger will be resized."

    global_lr_scale: float = 1.0
    "A global factor which scales up and down all learning rates. This may need adjustment for datasets other than faces."

    tasks: Literal["all", "single", "composed", "custom"] = "all"
    "Selects which tasks to run."
28 | 29 | 30 | def parse_config() -> Config: 31 | return tyro.cli(Config) 32 | -------------------------------------------------------------------------------- /datasets/samples/sample_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yohan-pg/robust-unsupervised/afa536855abb253199898cb915ecb9c8e51dbda9/datasets/samples/sample_1.png -------------------------------------------------------------------------------- /datasets/samples/sample_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yohan-pg/robust-unsupervised/afa536855abb253199898cb915ecb9c8e51dbda9/datasets/samples/sample_2.png -------------------------------------------------------------------------------- /pretrained_networks/put-pretrained-models-here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yohan-pg/robust-unsupervised/afa536855abb253199898cb915ecb9c8e51dbda9/pretrained_networks/put-pretrained-models-here -------------------------------------------------------------------------------- /pytorch_fid/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /pytorch_fid/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.2.1' -------------------------------------------------------------------------------- /pytorch_fid/__main__.py: -------------------------------------------------------------------------------- 1 | import pytorch_fid.fid_score 2 | 3 | pytorch_fid.fid_score.main() 4 | -------------------------------------------------------------------------------- /pytorch_fid/fid_score.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is a slightly modified version of https://github.com/mseitzer/pytorch-fid/tree/master, changed to allow for recursive globs and saving precomputed stats. 3 | 4 | -------------------------- 5 | 6 | Calculates the Frechet Inception Distance (FID) to evalulate GANs 7 | 8 | The FID metric calculates the distance between two distributions of images. 
9 | Typically, we have summary statistics (mean & covariance matrix) of one 10 | of these distributions, while the 2nd distribution is given by a GAN. 11 | 12 | When run as a stand-alone program, it compares the distribution of 13 | images that are stored as PNG/JPEG at a specified location with a 14 | distribution given by summary statistics (in .npz format). 15 | 16 | The FID is calculated by assuming that X_1 and X_2 are the activations of 17 | the pool_3 layer of the inception net for generated samples and real world 18 | samples respectively. 19 | 20 | See --help for further details. 21 | 22 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 23 | of TensorFlow 24 | 25 | Copyright 2018 Institute of Bioinformatics, JKU Linz 26 | 27 | Licensed under the Apache License, Version 2.0 (the "License"); 28 | you may not use this file except in compliance with the License. 29 | You may obtain a copy of the License at 30 | 31 | http://www.apache.org/licenses/LICENSE-2.0 32 | 33 | Unless required by applicable law or agreed to in writing, software 34 | distributed under the License is distributed on an "AS IS" BASIS, 35 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 36 | See the License for the specific language governing permissions and 37 | limitations under the License.
38 | """ 39 | import os 40 | import pathlib 41 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 42 | 43 | import numpy as np 44 | import torch 45 | import torchvision.transforms as TF 46 | from PIL import Image 47 | from scipy import linalg 48 | from torch.nn.functional import adaptive_avg_pool2d 49 | 50 | try: 51 | from tqdm import tqdm 52 | except ImportError: 53 | # If tqdm is not available, provide a mock version of it 54 | def tqdm(x): 55 | return x 56 | 57 | from pytorch_fid.inception import InceptionV3 58 | 59 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 60 | parser.add_argument('--batch-size', type=int, default=50, 61 | help='Batch size to use') 62 | parser.add_argument('--num-workers', type=int, 63 | help=('Number of processes to use for data loading. ' 64 | 'Defaults to `min(8, num_cpus)`')) 65 | parser.add_argument('--device', type=str, default=None, 66 | help='Device to use. Like cuda, cuda:0 or cpu') 67 | parser.add_argument('--dims', type=int, default=2048, 68 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), 69 | help=('Dimensionality of Inception features to use. ' 70 | 'By default, uses pool3 features')) 71 | parser.add_argument('--save-stats', action='store_true', 72 | help=('Generate an npz archive from a directory of samples. 
class ImagePathDataset(torch.utils.data.Dataset):
    """Dataset that opens image files as RGB and optionally transforms them.

    Params:
    -- files      : Sequence of image file paths
    -- transforms : Optional callable applied to each opened image
    """

    def __init__(self, files, transforms=None):
        self.files, self.transforms = files, transforms

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        image = Image.open(self.files[i]).convert('RGB')
        if self.transforms is None:
            return image
        return self.transforms(image)
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.

    The Frechet distance between two multivariate Gaussians
    X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1    : Mean vector of the first distribution.
    -- sigma1 : Covariance matrix of the first distribution.
    -- mu2    : Mean vector of the second distribution.
    -- sigma2 : Covariance matrix of the second distribution.
    -- eps    : Value added to the diagonal of both covariances when the
                square root of their product is not finite.

    Returns:
    --   : The Frechet Distance.
    """
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    mean_delta = mu1 - mu2

    # The product might be almost singular, in which case its square root is
    # retried with a small offset on the diagonals.
    sqrt_prod, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(sqrt_prod).all():
        print('fid calculation produces singular product; '
              'adding %s to diagonal of cov estimates' % eps)
        jitter = np.eye(sigma1.shape[0]) * eps
        sqrt_prod = linalg.sqrtm((sigma1 + jitter).dot(sigma2 + jitter))

    # Numerical error might leave a slight imaginary component.
    if np.iscomplexobj(sqrt_prod):
        if not np.allclose(np.diagonal(sqrt_prod).imag, 0, atol=1e-3):
            largest = np.max(np.abs(sqrt_prod.imag))
            raise ValueError('Imaginary component {}'.format(largest))
        sqrt_prod = sqrt_prod.real

    trace_sqrt_prod = np.trace(sqrt_prod)

    return (mean_delta.dot(mean_delta) + np.trace(sigma1)
            + np.trace(sigma2) - 2 * trace_sqrt_prod)
def compute_statistics_of_path(path, model, batch_size, dims, device,
                               num_workers=1):
    """Return the (mu, sigma) activation statistics for `path`.

    Params:
    -- path        : Either a '.npz' file holding 'mu' and 'sigma' arrays, or
                     a directory that is searched recursively for images with
                     one of the IMAGE_EXTENSIONS. `str` and `os.PathLike`
                     values are both accepted.
    -- model       : Instance of inception model (unused for '.npz' input)
    -- batch_size  : Batch size used when computing activations
    -- dims        : Dimensionality of features returned by Inception
    -- device      : Device to run calculations
    -- num_workers : Number of parallel dataloader workers

    Returns:
    -- mu, sigma : Mean vector and covariance matrix of the activations.

    Raises:
    -- ValueError : If `path` is a directory that contains no image files.
    """
    # Normalise pathlib.Path & co. so the suffix test below works for them.
    path = os.fspath(path)

    if path.endswith('.npz'):
        with np.load(path) as f:
            return f['mu'][:], f['sigma'][:]

    root = pathlib.Path(path)
    files = sorted(file for ext in IMAGE_EXTENSIONS
                   for file in root.rglob('*.{}'.format(ext)))
    if not files:
        # Fail here with a readable message instead of deep inside the
        # DataLoader, which rejects the resulting batch size of zero.
        raise ValueError('No image files found under: %s' % path)

    return calculate_activation_statistics(files, model, batch_size,
                                           dims, device, num_workers)


def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1):
    """Calculates the FID of two paths.

    Params:
    -- paths       : Two paths, each an image directory or a '.npz' file
    -- batch_size  : Batch size used when computing activations
    -- device      : Device to run calculations
    -- dims        : Inception feature dimensionality; must be a key of
                     InceptionV3.BLOCK_INDEX_BY_DIM
    -- num_workers : Number of parallel dataloader workers

    Raises:
    -- RuntimeError : If one of the paths does not exist.
    """
    for p in paths:
        if not os.path.exists(p):
            raise RuntimeError('Invalid path: %s' % p)

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]

    model = InceptionV3([block_idx]).to(device)

    m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
                                        dims, device, num_workers)
    m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
                                        dims, device, num_workers)
    return calculate_frechet_distance(m1, s1, m2, s2)
def main():
    """Command-line entry point: compute the FID, or save statistics.

    Reads the module-level `parser`. With `--save-stats` the statistics of
    the first path are written to the second path; otherwise the FID between
    the two paths is printed.
    """
    args = parser.parse_args()

    if args.device is None:
        device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
    else:
        device = torch.device(args.device)

    if args.num_workers is None:
        try:
            num_avail_cpus = len(os.sched_getaffinity(0))
        except AttributeError:
            # os.sched_getaffinity only exists on some Unix platforms (not
            # on macOS or Windows); fall back to the total CPU count, which
            # may itself be None when it cannot be determined.
            num_avail_cpus = os.cpu_count() or 1
        num_workers = min(num_avail_cpus, 8)
    else:
        num_workers = args.num_workers

    if args.save_stats:
        save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers)
        return

    fid_value = calculate_fid_given_paths(args.path,
                                          args.batch_size,
                                          device,
                                          args.dims,
                                          num_workers)
    print('FID: ', fid_value)
BLOCK_INDEX_BY_DIM = { 25 | 64: 0, # First max pooling features 26 | 192: 1, # Second max pooling featurs 27 | 768: 2, # Pre-aux classifier features 28 | 2048: 3 # Final average pooling features 29 | } 30 | 31 | def __init__(self, 32 | output_blocks=(DEFAULT_BLOCK_INDEX,), 33 | resize_input=True, 34 | normalize_input=True, 35 | requires_grad=False, 36 | use_fid_inception=True): 37 | """Build pretrained InceptionV3 38 | 39 | Parameters 40 | ---------- 41 | output_blocks : list of int 42 | Indices of blocks to return features of. Possible values are: 43 | - 0: corresponds to output of first max pooling 44 | - 1: corresponds to output of second max pooling 45 | - 2: corresponds to output which is fed to aux classifier 46 | - 3: corresponds to output of final average pooling 47 | resize_input : bool 48 | If true, bilinearly resizes input to width and height 299 before 49 | feeding input to model. As the network without fully connected 50 | layers is fully convolutional, it should be able to handle inputs 51 | of arbitrary size, so resizing might not be strictly needed 52 | normalize_input : bool 53 | If true, scales the input from range (0, 1) to the range the 54 | pretrained Inception network expects, namely (-1, 1) 55 | requires_grad : bool 56 | If true, parameters of the model require gradients. Possibly useful 57 | for finetuning the network 58 | use_fid_inception : bool 59 | If true, uses the pretrained Inception model used in Tensorflow's 60 | FID implementation. If false, uses the pretrained Inception model 61 | available in torchvision. The FID Inception model has different 62 | weights and a slightly different structure from torchvision's 63 | Inception model. If you want to compute FID scores, you are 64 | strongly advised to set this parameter to true to get comparable 65 | results. 
66 | """ 67 | super(InceptionV3, self).__init__() 68 | 69 | self.resize_input = resize_input 70 | self.normalize_input = normalize_input 71 | self.output_blocks = sorted(output_blocks) 72 | self.last_needed_block = max(output_blocks) 73 | 74 | assert self.last_needed_block <= 3, \ 75 | 'Last possible output block index is 3' 76 | 77 | self.blocks = nn.ModuleList() 78 | 79 | if use_fid_inception: 80 | inception = fid_inception_v3() 81 | else: 82 | inception = _inception_v3(pretrained=True) 83 | 84 | # Block 0: input to maxpool1 85 | block0 = [ 86 | inception.Conv2d_1a_3x3, 87 | inception.Conv2d_2a_3x3, 88 | inception.Conv2d_2b_3x3, 89 | nn.MaxPool2d(kernel_size=3, stride=2) 90 | ] 91 | self.blocks.append(nn.Sequential(*block0)) 92 | 93 | # Block 1: maxpool1 to maxpool2 94 | if self.last_needed_block >= 1: 95 | block1 = [ 96 | inception.Conv2d_3b_1x1, 97 | inception.Conv2d_4a_3x3, 98 | nn.MaxPool2d(kernel_size=3, stride=2) 99 | ] 100 | self.blocks.append(nn.Sequential(*block1)) 101 | 102 | # Block 2: maxpool2 to aux classifier 103 | if self.last_needed_block >= 2: 104 | block2 = [ 105 | inception.Mixed_5b, 106 | inception.Mixed_5c, 107 | inception.Mixed_5d, 108 | inception.Mixed_6a, 109 | inception.Mixed_6b, 110 | inception.Mixed_6c, 111 | inception.Mixed_6d, 112 | inception.Mixed_6e, 113 | ] 114 | self.blocks.append(nn.Sequential(*block2)) 115 | 116 | # Block 3: aux classifier to final avgpool 117 | if self.last_needed_block >= 3: 118 | block3 = [ 119 | inception.Mixed_7a, 120 | inception.Mixed_7b, 121 | inception.Mixed_7c, 122 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 123 | ] 124 | self.blocks.append(nn.Sequential(*block3)) 125 | 126 | for param in self.parameters(): 127 | param.requires_grad = requires_grad 128 | 129 | def forward(self, inp): 130 | """Get Inception feature maps 131 | 132 | Parameters 133 | ---------- 134 | inp : torch.autograd.Variable 135 | Input tensor of shape Bx3xHxW. 
def _inception_v3(*args, **kwargs):
    """Wraps `torchvision.models.inception_v3`

    Skips default weight initialization if supported by torchvision version.
    See https://github.com/mseitzer/pytorch-fid/issues/28.
    """
    major_minor = torchvision.__version__.split('.')[:2]
    try:
        version = tuple(int(part) for part in major_minor)
    except ValueError:
        # Just a caution against weird version strings
        version = (0,)

    if version >= (0, 6):
        kwargs['init_weights'] = False

    return torchvision.models.inception_v3(*args, **kwargs)
class FIDInceptionA(torchvision.models.inception.InceptionA):
    """InceptionA block patched for FID computation"""
    def __init__(self, in_channels, pool_features):
        super().__init__(in_channels, pool_features)

    def forward(self, x):
        out_1x1 = self.branch1x1(x)
        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))
        out_3x3dbl = self.branch3x3dbl_3(
            self.branch3x3dbl_2(self.branch3x3dbl_1(x)))

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                              count_include_pad=False)
        out_pool = self.branch_pool(pooled)

        # Concatenate the four branches along the channel dimension.
        return torch.cat((out_1x1, out_5x5, out_3x3dbl, out_pool), 1)
class FIDInceptionE_1(torchvision.models.inception.InceptionE):
    """First InceptionE block patched for FID computation"""
    def __init__(self, in_channels):
        super().__init__(in_channels)

    def forward(self, x):
        out_1x1 = self.branch1x1(x)

        # 3x3 branch: one stem feeding two parallel convolutions.
        stem_3x3 = self.branch3x3_1(x)
        out_3x3 = torch.cat(
            (self.branch3x3_2a(stem_3x3), self.branch3x3_2b(stem_3x3)), 1)

        # Double 3x3 branch: two stacked stems feeding two parallel
        # convolutions.
        stem_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_dbl = torch.cat(
            (self.branch3x3dbl_3a(stem_dbl), self.branch3x3dbl_3b(stem_dbl)), 1)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                              count_include_pad=False)
        out_pool = self.branch_pool(pooled)

        return torch.cat((out_1x1, out_3x3, out_dbl, out_pool), 1)
def open_image(path: str, resolution: int):
    """Load an image as a square, 3-channel CUDA batch of one.

    The image is converted to RGB (so grayscale, palette and RGBA files all
    yield exactly three channels), center-cropped to a square whose side is
    the shorter image dimension, and resized to `resolution` with area
    interpolation.

    Args:
        path: Path of the image file to open.
        resolution: Side length of the returned square image.

    Returns:
        Tensor of shape [1, 3, resolution, resolution] on the CUDA device.
    """
    # The context manager closes the file handle as soon as the pixel data
    # has been converted, instead of leaving it to the garbage collector.
    with Image.open(path) as handle:
        image = TF.to_tensor(handle.convert("RGB"))

    image = image.cuda().unsqueeze(0)
    image = TF.center_crop(image, min(image.shape[2:]))
    return F.interpolate(image, resolution, mode="area")
@contextlib.contextmanager
def directory(dir_path: str) -> "Iterator[None]":
    """Context manager for entering a directory, creating it if needed.

    The previous working directory is restored on exit, including when the
    body of the `with` statement raises.

    Args:
        dir_path: Directory to create (with parents) and change into.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(dir_path, exist_ok=True)

    previous = os.getcwd()
    os.chdir(dir_path)
    try:
        yield
    finally:
        # Without this, an exception inside the block would leave the
        # process stranded in `dir_path`.
        os.chdir(previous)
class NGD(torch.optim.SGD):
    """Normalized gradient descent.

    Each parameter is moved by `lr` times its gradient after the gradient
    has been scaled to unit L2 norm along the last dimension. Slices whose
    gradient norm is zero (or whose normalized gradient is otherwise
    non-finite) are left unchanged.
    """

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss, as for other `torch.optim` optimizers.

        Returns:
            The value returned by `closure`, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            for param in group["params"]:
                if param.grad is None:
                    # Parameter did not take part in the backward pass.
                    continue

                assert not param.isnan().any(), "parameter contains NaN values"

                # Out-of-place division: the stored `param.grad` is left as
                # computed by autograd instead of being overwritten.
                direction = param.grad / param.grad.norm(dim=-1, keepdim=True)
                # A zero norm produces NaN/inf above; map those to "no move".
                direction = torch.nan_to_num(
                    direction, nan=0.0, posinf=0.0, neginf=0.0
                )
                param -= lr * direction

        return loss
class Variable(nn.Module):
    """Base class for a latent code bound to a generator `G`.

    `data` holds the latent tensor (optionally an `nn.Parameter`). Subclasses
    define how it is sampled and how it is expanded into the tensor that is
    fed to `G.synthesis`.
    """

    def __init__(self, G: "networks.Generator", data: torch.Tensor):
        super().__init__()
        self.G = G
        self.data = data

    # ------------------------------------

    @staticmethod
    def sample_from(G: "networks.Generator", batch_size: int = 1):
        raise NotImplementedError

    @staticmethod
    def sample_random_from(G: "networks.Generator", batch_size: int = 1):
        raise NotImplementedError

    def to_input_tensor(self):
        raise NotImplementedError

    # ------------------------------------

    def parameters(self):
        """Return only the latent tensor (the generator is not optimized)."""
        return [self.data]

    def from_data(self, data: torch.Tensor):
        """Build a new variable of the same class and generator around `data`."""
        return self.__class__(self.G, data)

    def to_image(self):
        return self.render_image(self.to_input_tensor())

    def render_image(self, ws: torch.Tensor):  # todo
        """
        ws shape: [batch_size, num_layers, 512]

        The synthesis output is shifted and scaled by (x + 1) / 2.
        """
        return (self.G.synthesis(ws, noise_mode="const", force_fp32=True) + 1.0) / 2.0

    def _wrap_like_data(self, data: torch.Tensor) -> torch.Tensor:
        """Wrap `data` in an `nn.Parameter` iff `self.data` is one."""
        if isinstance(self.data, nn.Parameter):
            return nn.Parameter(data, requires_grad=self.data.requires_grad)
        return data

    def detach(self):
        """Return a copy of this variable cut off from the autograd graph.

        The new tensor shares storage with `self.data` and keeps its
        `requires_grad` flag.
        """
        data = self.data.detach().requires_grad_(self.data.requires_grad)
        # Previously the detached tensor was dropped for non-Parameter data
        # and the original (still attached) tensor was passed on instead.
        return self.__class__(self.G, self._wrap_like_data(data))

    def clone(self):
        """Return an independent copy of this variable (new storage)."""
        data = self.data.detach().clone().requires_grad_(self.data.requires_grad)
        # Previously non-Parameter data was aliased rather than cloned.
        return self.__class__(self.G, self._wrap_like_data(data))

    def interpolate(self, other: "Variable", alpha: float = 0.5):
        assert self.G == other.G
        return self.__class__(self.G, self.data.lerp(other.data, alpha))

    def __add__(self, other: "Variable"):
        return self.from_data(self.data + other.data)

    def __sub__(self, other: "Variable"):
        return self.from_data(self.data - other.data)

    def __mul__(self, scalar: float):
        return self.from_data(self.data * scalar)

    def unbind(self):
        """
        Splits this (batched) variable into a list of variables with batch size 1.
        """
        return [
            self.__class__(
                self.G,
                nn.Parameter(p.unsqueeze(0))
                if isinstance(self.data, nn.Parameter)
                else p.unsqueeze(0),
            )
            for p in self.data
        ]


class WVariable(Variable):
    """One latent vector per sample, data shape [batch_size, w_dim]."""

    @staticmethod
    def sample_from(G: nn.Module, batch_size: int = 1):
        """Start every sample at the generator's average latent `w_avg`."""
        data = G.mapping.w_avg.reshape(1, G.w_dim).repeat(batch_size, 1)

        return WVariable(G, nn.Parameter(data))

    @staticmethod
    def sample_random_from(G: nn.Module, batch_size: int = 1):
        """Map random normal z vectors through `G.mapping`."""
        data = G.mapping(
            torch.randn(batch_size, G.z_dim).cuda(),
            None,
            skip_w_avg_update=True,
        )[:, 0]

        return WVariable(G, nn.Parameter(data))

    def to_input_tensor(self):
        """Repeat the vector for each of the `G.num_ws` synthesis inputs."""
        return self.data.unsqueeze(1).repeat(1, self.G.num_ws, 1)

    @torch.no_grad()
    def truncate(self, truncation: float = 1.0):
        """Pull the data towards `w_avg` in place; 1.0 leaves it unchanged."""
        assert 0.0 <= truncation <= 1.0
        # reshape(1, -1) instead of a hard-coded 512 so any w_dim works.
        self.data.lerp_(self.G.mapping.w_avg.reshape(1, -1), 1.0 - truncation)
        return self


class WpVariable(Variable):
    """One latent vector per layer, data shape [batch_size, num_ws, w_dim]."""

    def __init__(self, G, data: torch.Tensor):
        super().__init__(G, data)

    @staticmethod
    def sample_from(G: nn.Module, batch_size: int = 1):
        data = WVariable.to_input_tensor(WVariable.sample_from(G, batch_size))

        return WpVariable(G, nn.Parameter(data))

    @staticmethod
    def sample_random_from(G: nn.Module, batch_size: int = 1):
        """Draw an independent random latent for every layer of every sample."""
        data = (
            G.mapping(
                (torch.randn(batch_size * G.mapping.num_ws, G.z_dim).cuda()),
                None,
                skip_w_avg_update=True,
            )
            .mean(dim=1)
            .reshape(batch_size, G.mapping.num_ws, G.w_dim)
        )

        return WpVariable(G, nn.Parameter(data))

    def to_input_tensor(self):
        return self.data

    def mix(self, other: "WpVariable", num_layers: int):
        """Take the first `num_layers` layers from self, the rest from `other`."""
        return WpVariable(
            self.G,
            torch.cat(
                (self.data[:, :num_layers, :], other.data[:, num_layers:, :]), dim=1
            ),
        )

    @staticmethod
    def from_W(W: WVariable):
        return WpVariable(
            W.G, nn.parameter.Parameter(W.to_input_tensor())
        )

    @torch.no_grad()
    def truncate(self, truncation=1.0, *, layer_start=0, layer_end: Optional[int] = None):
        """Pull layers [layer_start:layer_end] towards `w_avg` in place."""
        assert 0.0 <= truncation <= 1.0
        mu = self.G.mapping.w_avg
        # reshape(1, 1, -1) instead of a hard-coded 512 so any w_dim works.
        target = mu.reshape(1, 1, -1).repeat(1, self.G.mapping.num_ws, 1)
        self.data[:, layer_start:layer_end].lerp_(target[:, layer_start:layer_end], 1.0 - truncation)
        return self


class WppVariable(Variable):
    """Each per-layer latent repeated 512 times along dim 1.

    Data shape is [batch_size, num_ws * 512, w_dim].
    """

    @staticmethod
    def sample_from(G: nn.Module, batch_size: int = 1):
        data = WVariable.sample_from(G, batch_size).to_input_tensor().repeat(1, 512, 1)

        return WppVariable(G, nn.Parameter(data))

    @staticmethod
    def sample_random_from(G: nn.Module, batch_size: int = 1):
        data = (
            WVariable.sample_random_from(G, batch_size)
            .to_input_tensor()
            .repeat(1, 512, 1)
        )

        return WppVariable(G, nn.Parameter(data))

    @staticmethod
    def from_w(W: WVariable):
        # unsqueeze(1) keeps the batch dimension separate; repeating the 2-D
        # tensor directly folded the batch into the layer axis for
        # batch sizes above 1 (the result is identical for batch size 1).
        data = W.data.detach().unsqueeze(1).repeat(1, 512 * W.G.num_ws, 1)

        return WppVariable(W.G, nn.parameter.Parameter(data))

    @staticmethod
    def from_Wp(Wp: WpVariable):
        data = Wp.data.detach().repeat_interleave(512, dim=1)

        return WppVariable(Wp.G, nn.parameter.Parameter(data))

    def to_input_tensor(self):
        return self.data
def run_phase(label: str, variable: Variable, lr: float, num_steps: int = 150):
    """Optimize `variable` against the current target, then write result images.

    Relies on module-level state set up by the main script: `loss_fn`,
    `config`, and the per-image globals `degradation`, `target` and
    `ground_truth`. Images are saved relative to the current working
    directory.

    Args:
        label: Phase name; used for the progress bar and as file-name suffix.
        variable: Latent variable to optimize in place.
        lr: Learning rate passed to the NGD optimizer.
        num_steps: Number of optimization iterations (default 150).
    """
    # Run optimization loop
    optimizer = NGD(variable.parameters(), lr=lr)
    try:
        for _ in tqdm.tqdm(range(num_steps), desc=label):
            x = variable.to_image()
            loss = loss_fn(degradation.degrade_prediction, x, target, degradation.mask).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    except KeyboardInterrupt:
        # Deliberate: Ctrl-C ends this phase early but still logs the
        # current state below.
        pass

    # Log results
    suffix = "_" + label
    pred = resize_for_logging(variable.to_image(), config.resolution)

    approx_degraded_pred = degradation.degrade_prediction(pred)
    degraded_pred = degradation.degrade_ground_truth(pred)

    save_image(pred, f"pred{suffix}.png", padding=0)
    save_image(degraded_pred, f"degraded_pred{suffix}.png", padding=0)

    save_image(
        torch.cat([approx_degraded_pred, degraded_pred]),
        f"degradation_approximation{suffix}.jpg",
        padding=0,
    )

    save_image(
        torch.cat(
            [
                ground_truth,
                resize_for_logging(target, config.resolution),
                resize_for_logging(degraded_pred, config.resolution),
                pred,
            ]
        ),
        f"side_by_side{suffix}.jpg",
        padding=0,
    )
    save_image(
        torch.cat([resize_for_logging(target, config.resolution), pred]),
        f"result{suffix}.jpg",
        padding=0,
    )
    save_image(
        torch.cat([target, degraded_pred, (target - degraded_pred).abs()]),
        f"fidelity{suffix}.jpg",
        padding=0,
    )
    save_image(
        torch.cat([ground_truth, pred, (ground_truth - pred).abs()]),
        f"accuracy{suffix}.jpg",
        padding=0,
    )
126 | 127 | with directory(experiment_path): 128 | print(experiment_path) 129 | print(os.path.abspath(config.dataset_path)) 130 | 131 | for j, image_path in enumerate(image_paths): 132 | with directory(f"inversions/{j:04d}"): 133 | print(f"- {j:04d}") 134 | 135 | ground_truth = open_image(image_path, config.resolution) 136 | degradation = task.init_degradation() 137 | save_image(ground_truth, f"ground_truth.png") 138 | target = degradation.degrade_ground_truth(ground_truth) 139 | save_image(target, f"target.png") 140 | 141 | W_variable = WVariable.sample_from(G) 142 | run_phase("W", W_variable, config.global_lr_scale * 0.08) 143 | 144 | Wp_variable = WpVariable.from_W(W_variable) 145 | run_phase("W+", Wp_variable, config.global_lr_scale * 0.02) 146 | 147 | Wpp_variable = WppVariable.from_Wp(Wp_variable) 148 | run_phase("W++", Wpp_variable, config.global_lr_scale * 0.005) -------------------------------------------------------------------------------- /stylegan2_ada/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 2 | 3 | 4 | NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA) 5 | 6 | 7 | ======================================================================= 8 | 9 | 1. Definitions 10 | 11 | "Licensor" means any person or entity that distributes its Work. 12 | 13 | "Software" means the original work of authorship made available under 14 | this License. 15 | 16 | "Work" means the Software and any additions to or derivative works of 17 | the Software that are made available under this License. 18 | 19 | The terms "reproduce," "reproduction," "derivative works," and 20 | "distribution" have the meaning as provided under U.S. 
copyright law; 21 | provided, however, that for the purposes of this License, derivative 22 | works shall not include works that remain separable from, or merely 23 | link (or bind by name) to the interfaces of, the Work. 24 | 25 | Works, including the Software, are "made available" under this License 26 | by including in or with the Work either (a) a copyright notice 27 | referencing the applicability of this License to the Work, or (b) a 28 | copy of this License. 29 | 30 | 2. License Grants 31 | 32 | 2.1 Copyright Grant. Subject to the terms and conditions of this 33 | License, each Licensor grants to you a perpetual, worldwide, 34 | non-exclusive, royalty-free, copyright license to reproduce, 35 | prepare derivative works of, publicly display, publicly perform, 36 | sublicense and distribute its Work and any resulting derivative 37 | works in any form. 38 | 39 | 3. Limitations 40 | 41 | 3.1 Redistribution. You may reproduce or distribute the Work only 42 | if (a) you do so under this License, (b) you include a complete 43 | copy of this License with your distribution, and (c) you retain 44 | without modification any copyright, patent, trademark, or 45 | attribution notices that are present in the Work. 46 | 47 | 3.2 Derivative Works. You may specify that additional or different 48 | terms apply to the use, reproduction, and distribution of your 49 | derivative works of the Work ("Your Terms") only if (a) Your Terms 50 | provide that the use limitation in Section 3.3 applies to your 51 | derivative works, and (b) you identify the specific derivative 52 | works that are subject to Your Terms. Notwithstanding Your Terms, 53 | this License (including the redistribution requirements in Section 54 | 3.1) will continue to apply to the Work itself. 55 | 56 | 3.3 Use Limitation. The Work and any derivative works thereof only 57 | may be used or intended for use non-commercially. 
Notwithstanding 58 | the foregoing, NVIDIA and its affiliates may use the Work and any 59 | derivative works commercially. As used herein, "non-commercially" 60 | means for research or evaluation purposes only. 61 | 62 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 63 | against any Licensor (including any claim, cross-claim or 64 | counterclaim in a lawsuit) to enforce any patents that you allege 65 | are infringed by any Work, then your rights under this License from 66 | such Licensor (including the grant in Section 2.1) will terminate 67 | immediately. 68 | 69 | 3.5 Trademarks. This License does not grant any rights to use any 70 | Licensor’s or its affiliates’ names, logos, or trademarks, except 71 | as necessary to reproduce the notices described in this License. 72 | 73 | 3.6 Termination. If you violate any term of this License, then your 74 | rights under this License (including the grant in Section 2.1) will 75 | terminate immediately. 76 | 77 | 4. Disclaimer of Warranty. 78 | 79 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 80 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 81 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 82 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 83 | THIS LICENSE. 84 | 85 | 5. Limitation of Liability. 
86 | 87 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 88 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 89 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 90 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 91 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 92 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 93 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 94 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 95 | THE POSSIBILITY OF SUCH DAMAGES. 96 | 97 | ======================================================================= 98 | -------------------------------------------------------------------------------- /stylegan2_ada/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | 11 | -------------------------------------------------------------------------------- /stylegan2_ada/dnnlib/legacy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
# Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import click
import pickle
import re
import copy
import numpy as np
import torch
import dnnlib
import torch_utils as torch_utils
from torch_utils import misc

#----------------------------------------------------------------------------

def load_network_pkl(f, force_fp16=False):
    """Load a network pickle and normalize it to a dict of networks.

    Args:
        f: Binary file-like object handed to the unpickler.
        force_fp16: When True, `G`, `D` and `G_ema` are re-instantiated with
            `num_fp16_res = 4` and `conv_clamp = 256` (if that differs from
            their stored init kwargs) and their weights are copied over.

    Returns:
        A dict containing at least the keys `G`, `D`, `G_ema`,
        `training_set_kwargs` and `augment_pipe`.

    NOTE(security): this unpickles `f`. Unpickling can execute arbitrary
    code, so only load pickles that come from a trusted source.
    """
    data = _LegacyUnpickler(f).load()

    # Legacy TensorFlow pickle => convert.
    if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
        tf_G, tf_D, tf_Gs = data
        G = convert_tf_generator(tf_G)
        D = convert_tf_discriminator(tf_D)
        G_ema = convert_tf_generator(tf_Gs)
        data = dict(G=G, D=D, G_ema=G_ema)

    # Add missing fields.
    if 'training_set_kwargs' not in data:
        data['training_set_kwargs'] = None
    if 'augment_pipe' not in data:
        data['augment_pipe'] = None

    # Validate contents.
    assert isinstance(data['G'], torch.nn.Module)
    assert isinstance(data['D'], torch.nn.Module)
    assert isinstance(data['G_ema'], torch.nn.Module)
    assert isinstance(data['training_set_kwargs'], (dict, type(None)))
    assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))

    # Force FP16.
    if force_fp16:
        for key in ['G', 'D', 'G_ema']:
            old = data[key]
            # NOTE(review): `init_kwargs` is presumably attached by the
            # persistence machinery of the network classes -- confirm.
            kwargs = copy.deepcopy(old.init_kwargs)
            if key.startswith('G'):
                kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {}))
                kwargs.synthesis_kwargs.num_fp16_res = 4
                kwargs.synthesis_kwargs.conv_clamp = 256
            if key.startswith('D'):
                kwargs.num_fp16_res = 4
                kwargs.conv_clamp = 256
            # Only rebuild the network when the kwargs actually changed.
            if kwargs != old.init_kwargs:
                new = type(old)(**kwargs).eval().requires_grad_(False)
                misc.copy_params_and_buffers(old, new, require_all=True)
                data[key] = new
    return data

#----------------------------------------------------------------------------

class _TFNetworkStub(dnnlib.EasyDict):
    """Placeholder type that legacy TensorFlow `Network` objects unpickle into."""
    pass

class _LegacyUnpickler(pickle.Unpickler):
    """Unpickler that maps the legacy TensorFlow network class to a stub."""

    def find_class(self, module, name):
        """Resolve `module.name`, substituting `_TFNetworkStub` for the TF network class.

        Any other lookup is delegated to `pickle.Unpickler.find_class`. Lookup
        failures propagate to the caller; the previous implementation caught
        every exception, dropped into `breakpoint()` and then returned None,
        which stalled non-interactive runs and hid the real error.
        """
        if module == 'dnnlib.tflib.network' and name == 'Network':
            return _TFNetworkStub
        return super().find_class(module, name)

#----------------------------------------------------------------------------

def _collect_tf_params(tf_net):
    """Flatten the variables of `tf_net` and its nested components into one dict.

    Keys are the variable names prefixed with the '/'-joined component path.
    """
    # pylint: disable=protected-access
    tf_params = dict()
    def recurse(prefix, tf_net):
        for name, value in tf_net.variables:
            tf_params[prefix + name] = value
        for name, comp in tf_net.components.items():
            recurse(prefix + name + '/', comp)
    recurse('', tf_net)
    return tf_params

#----------------------------------------------------------------------------

def _populate_module_params(module, *patterns):
    """Fill the parameters and buffers of `module` from pattern/value-function pairs.

    `patterns` is a flat sequence `pattern0, fn0, pattern1, fn1, ...`. For each
    parameter or buffer name the first pattern that fully matches wins; its
    function is called with the regex groups and the result is copied into the
    tensor. A function of None means "matched, leave the tensor untouched".

    Raises:
        AssertionError: If a parameter/buffer name matches no pattern. The
            offending name and shape are printed before re-raising.
    """
    for name, tensor in misc.named_params_and_buffers(module):
        found = False
        value = None
        for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
            match = re.fullmatch(pattern, name)
            if match:
                found = True
                if value_fn is not None:
                    value = value_fn(*match.groups())
                break
        try:
            assert found
            if value is not None:
                tensor.copy_(torch.from_numpy(np.array(value)))
        except Exception:
            # Report which tensor failed, then let the error propagate.
            print(name, list(tensor.shape))
            raise

#----------------------------------------------------------------------------

def convert_tf_generator(tf_G):
    """Build a PyTorch `networks.Generator` from an unpickled TensorFlow generator stub.

    Raises:
        ValueError: If the pickle version is below 4 or the stub carries a
            static kwarg this converter does not know about.
    """
    if tf_G.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs.
    tf_kwargs = tf_G.static_kwargs
    known_kwargs = set()
    def kwarg(tf_name, default=None, none=None):
        # Record every name we look at so unknown kwargs can be detected below.
        known_kwargs.add(tf_name)
        val = tf_kwargs.get(tf_name, default)
        return val if val is not None else none

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        z_dim = kwarg('latent_size', 512),
        c_dim = kwarg('label_size', 0),
        w_dim = kwarg('dlatent_size', 512),
        img_resolution = kwarg('resolution', 1024),
        img_channels = kwarg('num_channels', 3),
        mapping_kwargs = dnnlib.EasyDict(
            num_layers = kwarg('mapping_layers', 8),
            embed_features = kwarg('label_fmaps', None),
            layer_features = kwarg('mapping_fmaps', None),
            activation = kwarg('mapping_nonlinearity', 'lrelu'),
            lr_multiplier = kwarg('mapping_lrmul', 0.01),
            w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
        ),
        synthesis_kwargs = dnnlib.EasyDict(
            channel_base = kwarg('fmap_base', 16384) * 2,
            channel_max = kwarg('fmap_max', 512),
            num_fp16_res = kwarg('num_fp16_res', 0),
            conv_clamp = kwarg('conv_clamp', None),
            architecture = kwarg('architecture', 'skip'),
            resample_filter = kwarg('resample_kernel', [1,3,3,1]),
            use_noise = kwarg('use_noise', True),
            activation = kwarg('nonlinearity', 'lrelu'),
        ),
    )

    # Check for unknown kwargs.
    kwarg('truncation_psi')
    kwarg('truncation_cutoff')
    kwarg('style_mixing_prob')
    kwarg('structure')
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_G)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
            # The architecture lives under `synthesis_kwargs`; the previous
            # `kwargs.synthesis.kwargs` path does not exist on the EasyDict
            # and raised AttributeError. This mirrors the discriminator's
            # `kwargs.architecture = 'orig'`.
            kwargs.synthesis_kwargs.architecture = 'orig'
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    from training import networks
    G = networks.Generator(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    _populate_module_params(G,
        r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
        r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
        r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
        r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
        r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
        r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
        r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
        r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
        r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
        r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
        r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
        r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
        r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
        r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
        r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'.*\.resample_filter', None,
    )
    return G

#----------------------------------------------------------------------------

def convert_tf_discriminator(tf_D):
    """Build a PyTorch `networks.Discriminator` from an unpickled TensorFlow discriminator stub.

    Raises:
        ValueError: If the pickle version is below 4 or the stub carries a
            static kwarg this converter does not know about.
    """
    if tf_D.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs.
    tf_kwargs = tf_D.static_kwargs
    known_kwargs = set()
    def kwarg(tf_name, default=None):
        # Record every name we look at so unknown kwargs can be detected below.
        known_kwargs.add(tf_name)
        return tf_kwargs.get(tf_name, default)

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        c_dim = kwarg('label_size', 0),
        img_resolution = kwarg('resolution', 1024),
        img_channels = kwarg('num_channels', 3),
        architecture = kwarg('architecture', 'resnet'),
        channel_base = kwarg('fmap_base', 16384) * 2,
        channel_max = kwarg('fmap_max', 512),
        num_fp16_res = kwarg('num_fp16_res', 0),
        conv_clamp = kwarg('conv_clamp', None),
        cmap_dim = kwarg('mapping_fmaps', None),
        block_kwargs = dnnlib.EasyDict(
            activation = kwarg('nonlinearity', 'lrelu'),
            resample_filter = kwarg('resample_kernel', [1,3,3,1]),
            freeze_layers = kwarg('freeze_layers', 0),
        ),
        mapping_kwargs = dnnlib.EasyDict(
            num_layers = kwarg('mapping_layers', 0),
            embed_features = kwarg('mapping_fmaps', None),
            layer_features = kwarg('mapping_fmaps', None),
            activation = kwarg('nonlinearity', 'lrelu'),
            lr_multiplier = kwarg('mapping_lrmul', 0.1),
        ),
        epilogue_kwargs = dnnlib.EasyDict(
            mbstd_group_size = kwarg('mbstd_group_size', None),
            mbstd_num_channels = kwarg('mbstd_num_features', 1),
            activation = kwarg('nonlinearity', 'lrelu'),
        ),
    )

    # Check for unknown kwargs.
    kwarg('structure')
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_D)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
            kwargs.architecture = 'orig'
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    from training import networks
    D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    _populate_module_params(D,
        r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
        r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
        r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
        r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
        r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
        r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
        r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
        r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
        r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
        r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
        r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
        r'.*\.resample_filter', None,
    )
    return D

#----------------------------------------------------------------------------

@click.command()
@click.option('--source', help='Input pickle', required=True, metavar='PATH')
@click.option('--dest', help='Output pickle', required=True, metavar='PATH')
@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
def convert_network_pickle(source, dest, force_fp16):
    """Convert legacy network pickle into the native PyTorch format.

    The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
    It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.

    Example:

    \b
    python legacy.py \\
        --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
        --dest=stylegan2-cat-config-f.pkl
    """
    print(f'Loading "{source}"...')
    with dnnlib.util.open_url(source) as f:
        data = load_network_pkl(f, force_fp16=force_fp16)
    print(f'Saving "{dest}"...')
    with open(dest, 'wb') as f:
        pickle.dump(data, f)
    print('Done.')

#----------------------------------------------------------------------------

if __name__ == "__main__":
    convert_network_pickle() # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# /stylegan2_ada/torch_utils/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION.
All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylegan2_ada/torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import glob 11 | import torch 12 | import torch.utils.cpp_extension 13 | import importlib 14 | import hashlib 15 | import shutil 16 | from pathlib import Path 17 | 18 | from torch.utils.file_baton import FileBaton 19 | 20 | #---------------------------------------------------------------------------- 21 | # Global options. 22 | 23 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 24 | 25 | #---------------------------------------------------------------------------- 26 | # Internal helper funcs. 

def _find_compiler_bindir():
    """Return a directory containing the MSVC compiler binaries, or None.

    The glob patterns are tried in order; for the first pattern with any
    match, the last match in sorted order is returned.
    """
    patterns = [
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
    ]
    for pattern in patterns:
        matches = sorted(glob.glob(pattern))
        if len(matches):
            return matches[-1]
    return None

#----------------------------------------------------------------------------
# Main entry point for compiling and loading C++/CUDA plugins.

# Maps module_name -> imported plugin module, so each plugin is built at most
# once per process.
_cached_plugins = dict()

def get_plugin(module_name, sources, **build_kwargs):
    """Compile (if needed), import and cache a C++/CUDA extension module.

    Args:
        module_name: Name of the extension; also the cache key and the name
            passed to `importlib.import_module`.
        sources: Paths of the source files handed to
            `torch.utils.cpp_extension.load`.
        **build_kwargs: Forwarded unchanged to `torch.utils.cpp_extension.load`.

    Returns:
        The imported extension module.

    Raises:
        RuntimeError: On Windows, when `cl.exe` is not on PATH and no compiler
            directory can be located.
        Any exception raised by the build or import is re-raised after
        printing 'Failed!' in 'brief' verbosity.

    Progress output is controlled by the module-level `verbosity` setting.
    """
    assert verbosity in ['none', 'brief', 'full']

    # Already cached?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Print status.
    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)

    try: # pylint: disable=too-many-nested-blocks
        # Make sure we can find the necessary compiler binaries.
        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
            os.environ['PATH'] += ';' + compiler_bindir

        # Compile and load.
        verbose_build = (verbosity == 'full')

        # Incremental build md5sum trickery.  Copies all the input source files
        # into a cached build directory under a combined md5 digest of the input
        # source files.  Copying is done only if the combined digest has changed.
        # This keeps input file timestamps and filenames the same as in previous
        # extension builds, allowing for fast incremental rebuilds.
        #
        # This optimization is done only in case all the source files reside in
        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
        # environment variable is set (we take this as a signal that the user
        # actually cares about this.)
        source_dirs_set = set(os.path.dirname(source) for source in sources)
        if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
            # Every regular file in the source directory contributes to the
            # digest, not only the files listed in `sources`.
            all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))

            # Compute a combined hash digest for all source files in the same
            # custom op directory (usually .cu, .cpp, .py and .h files).
            hash_md5 = hashlib.md5()
            for src in all_source_files:
                with open(src, 'rb') as f:
                    hash_md5.update(f.read())
            build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
            digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())

            if not os.path.isdir(digest_build_dir):
                os.makedirs(digest_build_dir, exist_ok=True)
                # The file baton serializes the copy between processes that
                # reach this point at the same time.
                baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
                if baton.try_acquire():
                    try:
                        for src in all_source_files:
                            shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
                    finally:
                        baton.release()
                else:
                    # Someone else is copying source files under the digest dir,
                    # wait until done and continue.
                    baton.wait()
            digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
            torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
                verbose=verbose_build, sources=digest_sources, **build_kwargs)
        else:
            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
        module = importlib.import_module(module_name)

    except:
        # Finish the 'brief' status line before the error propagates.
        if verbosity == 'brief':
            print('Failed!')
        raise

    # Print status and add to cache.
    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')
    _cached_plugins[module_name] = module
    return module

#----------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# /stylegan2_ada/torch_utils/misc.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import re
import contextlib
import numpy as np
import torch
import warnings
import dnnlib as dnnlib

#----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.

_constant_cache = {}

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a cached tensor holding `value`, building it on first use.

    Args:
        value: Anything `np.asarray` accepts.
        shape: Optional shape to broadcast the tensor to.
        dtype: Tensor dtype; the torch default dtype when omitted.
        device: Target device; CPU when omitted.
        memory_format: Memory format; `torch.contiguous_format` when omitted.

    Returns:
        The tensor for this combination of arguments. Repeated calls with the
        same arguments return the very same tensor object.
    """
    array = np.asarray(value)
    shape = None if shape is None else tuple(shape)
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device('cpu') if device is None else device
    memory_format = torch.contiguous_format if memory_format is None else memory_format

    # The raw bytes are part of the key so equal values share one tensor.
    cache_key = (array.shape, array.dtype, array.tobytes(), shape, dtype, device, memory_format)
    cached = _constant_cache.get(cache_key, None)
    if cached is not None:
        return cached

    result = torch.as_tensor(array.copy(), dtype=dtype, device=device)
    if shape is not None:
        result, _ = torch.broadcast_tensors(result, torch.empty(shape))
    result = result.contiguous(memory_format=memory_format)
    _constant_cache[cache_key] = result
    return result

#----------------------------------------------------------------------------
# Replace NaN/Inf with specified numerical values.

if hasattr(torch, 'nan_to_num'):
    nan_to_num = torch.nan_to_num # 1.8.0a0
else:
    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
        """Fallback for torch builds without `torch.nan_to_num`; only `nan == 0` is supported."""
        assert isinstance(input, torch.Tensor)
        limits = torch.finfo(input.dtype)
        upper = limits.max if posinf is None else posinf
        lower = limits.min if neginf is None else neginf
        assert nan == 0
        # nansum over a fresh leading axis of size 1 turns NaN into 0.
        without_nan = input.unsqueeze(0).nansum(0)
        return torch.clamp(without_nan, min=lower, max=upper, out=out)

#----------------------------------------------------------------------------
# Symbolic assert.

# Prefer torch._assert (1.8.0a0); fall back to torch.Assert (1.7.0).
symbolic_assert = torch._assert if hasattr(torch, '_assert') else torch.Assert # pylint: disable=protected-access

#----------------------------------------------------------------------------
# Context manager to suppress known warnings in torch.jit.trace().

class suppress_tracer_warnings(warnings.catch_warnings):
    """`warnings.catch_warnings` variant that also ignores `torch.jit.TracerWarning`."""

    def __enter__(self):
        super().__enter__()
        # Install the filter inside the saved/restored warnings state.
        warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
        return self

#----------------------------------------------------------------------------
# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().

def assert_shape(tensor, ref_shape):
    """Check `tensor.shape` against `ref_shape`, where None entries match any size.

    Raises:
        AssertionError: On a rank mismatch, or when a plain-integer size
            differs from its reference. Sizes that are tensors on either side
            are checked through `symbolic_assert` instead.
    """
    expected_rank = len(ref_shape)
    if tensor.ndim != expected_rank:
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {expected_rank}')

    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            continue
        if isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
            continue
        if isinstance(size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
            continue
        if size != ref_size:
            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')

#----------------------------------------------------------------------------
# Function decorator that calls torch.autograd.profiler.record_function().

def profiled_function(fn):
    """Decorator that runs `fn` inside `torch.autograd.profiler.record_function()`.

    The profiler range is labeled with `fn.__name__`. The wrapper keeps the
    metadata of `fn` (`__name__`, `__qualname__`, `__doc__`, `__module__`) and
    exposes the original callable as `__wrapped__`.

    Args:
        fn: Callable to wrap.

    Returns:
        A callable with the same signature and return value as `fn`.
    """
    import functools # Local import so this block does not depend on the file's import list.

    @functools.wraps(fn)
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)
    return decorator

#----------------------------------------------------------------------------
# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.

class InfiniteSampler(torch.utils.data.Sampler):
    """Endless index sampler with optional ongoing shuffling and replica striding.

    Args:
        dataset:      Sized dataset; must be non-empty.
        rank:         Index of this replica, `0 <= rank < num_replicas`.
        num_replicas: Total number of replicas sharing the index stream.
        shuffle:      Shuffle the order (seeded) and keep re-shuffling while iterating.
        seed:         Seed for the `np.random.RandomState` used for shuffling.
        window_size:  Fraction of the dataset (0..1) that bounds how far back
                      an already visited position may be swapped.
    """

    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size

    def __iter__(self):
        """Yield dataset indices forever; never raises `StopIteration`."""
        order = np.arange(len(self.dataset))
        rnd = None
        window = 0
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            window = int(np.rint(order.size * self.window_size))

        idx = 0
        while True:
            i = idx % order.size
            # Every replica walks the same sequence but only yields its own stride.
            if idx % self.num_replicas == self.rank:
                yield order[i]
            # The swap happens on every step (not only on yielded ones), so replicas
            # that share a seed keep identical `order` arrays.
            if window >= 2:
                j = (i - rnd.randint(window)) % order.size
                order[i], order[j] = order[j], order[i]
            idx += 1

#----------------------------------------------------------------------------
# Utilities for operating with torch.nn.Module parameters and buffers.

def params_and_buffers(module):
    """Return all parameters of `module` followed by all of its buffers, as a list."""
    assert isinstance(module, torch.nn.Module)
    return list(module.parameters()) + list(module.buffers())

def named_params_and_buffers(module):
    """Return `(name, tensor)` pairs for all parameters, then all buffers, of `module`."""
    assert isinstance(module, torch.nn.Module)
    return list(module.named_parameters()) + list(module.named_buffers())

def copy_params_and_buffers(src_module, dst_module, require_all=False):
    """Copy same-named parameters and buffers from `src_module` into `dst_module`.

    Tensors are matched by name and copied in place; each destination tensor
    keeps its own `requires_grad` setting.

    Args:
        src_module:  Module to read from.
        dst_module:  Module to write into.
        require_all: If true, assert that every tensor of `dst_module` has a
                     counterpart in `src_module`.
    """
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = dict(named_params_and_buffers(src_module))
    for name, tensor in named_params_and_buffers(dst_module):
        assert (name in src_tensors) or (not require_all)
        if name in src_tensors:
            tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)

#----------------------------------------------------------------------------
# Context manager for easily enabling/disabling DistributedDataParallel
# synchronization.

@contextlib.contextmanager
def ddp_sync(module, sync):
    """Run the `with` body with DistributedDataParallel synchronization on or off.

    Synchronization is only suppressed (via `module.no_sync()`) when `sync` is
    false and `module` is a `DistributedDataParallel` instance; otherwise the
    body runs unchanged.
    """
    assert isinstance(module, torch.nn.Module)
    if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
        yield
    else:
        with module.no_sync():
            yield

#----------------------------------------------------------------------------
# Check DistributedDataParallel consistency across processes.

def check_ddp_consistency(module, ignore_regex=None):
    """Assert that every parameter/buffer of `module` equals the copy held by rank 0.

    Args:
        module:       Module to check.
        ignore_regex: Optional regular expression; tensors whose full name
                      (`<ModuleClass>.<tensor name>`) fully matches are skipped.

    Raises:
        AssertionError: With the tensor's full name if a mismatch is found.
    """
    assert isinstance(module, torch.nn.Module)
    for name, tensor in named_params_and_buffers(module):
        fullname = type(module).__name__ + '.' + name
        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
            continue
        tensor = tensor.detach()
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        if tensor.is_floating_point():
            # Only floating-point tensors are sanitized: the nan_to_num() fallback used on
            # older PyTorch calls torch.finfo(), which rejects integer dtypes.
            tensor, other = nan_to_num(tensor), nan_to_num(other)
        assert (tensor == other).all(), fullname

#----------------------------------------------------------------------------
# Print summary table of module hierarchy.

def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    """Run `module(*inputs)` once and print a per-submodule summary table.

    The table lists, for each submodule, the number of parameter and buffer
    elements not already counted for an earlier row, and the shape and dtype of
    each tensor output.

    Args:
        module:         Module to summarize; must not be a `torch.jit.ScriptModule`.
        inputs:         Tuple or list of positional arguments for the forward pass.
        max_nesting:    Deepest submodule nesting level that gets a row.
        skip_redundant: Drop rows that contribute no new parameters, buffers or outputs.

    Returns:
        Whatever `module(*inputs)` returned.
    """
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks.
    entries = []
    nesting = [0]
    def pre_hook(_mod, _inputs):
        nesting[0] += 1
    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module. The hooks are removed even if the forward pass raises, so a
    # failed summary does not leave them attached to the module.
    try:
        outputs = module(*inputs)
    finally:
        for hook in hooks:
            hook.remove()

    # Identify unique outputs, parameters, and buffers.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}

    # Filter out redundant entries.
    if skip_redundant:
        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]

    # Construct table.
    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
    rows += [['---'] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = '' if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        # Each output row shows that output's own shape (not the first output's).
        output_shapes = [str(list(t.shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
        rows += [[
            name + (':0' if len(e.outputs) >= 2 else ''),
            str(param_size) if param_size else '-',
            str(buffer_size) if buffer_size else '-',
            (output_shapes + ['-'])[0],
            (output_dtypes + ['-'])[0],
        ]]
        for idx in range(1, len(e.outputs)):
            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
        param_total += param_size
        buffer_total += buffer_size
    rows += [['---'] * len(rows[0])]
    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]

    # Print table.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
    print()
    return outputs

#----------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# /stylegan2_ada/torch_utils/ops/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylegan2_ada/torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments. 
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? 
xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /stylegan2_ada/torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 

#include <c10/util/Half.h>
#include "bias_act.h"

//------------------------------------------------------------------------
// Helpers.

// Maps the storage type T to the type used for arithmetic inside the kernel:
// double stays double, float and half are both computed in float.
template <class T> struct InternalType;
template <> struct InternalType<double>     { typedef double scalar_t; };
template <> struct InternalType<float>      { typedef float  scalar_t; };
template <> struct InternalType<c10::Half>  { typedef float  scalar_t; };

//------------------------------------------------------------------------
// CUDA kernel.

// T is the element type of the tensors; A is the activation index (1..9, see
// the branches below). p.grad (G) selects which quantity is evaluated:
// 0 = the activation itself; 1 and 2 = the expressions used by the Python
// autograd wrappers for their first- and second-order backward passes.
template <class T, int A>
__global__ void bias_act_kernel(bias_act_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;
    int G                 = p.grad;
    scalar_t alpha        = (scalar_t)p.alpha;
    scalar_t gain         = (scalar_t)p.gain;
    scalar_t clamp        = (scalar_t)p.clamp;
    scalar_t one          = (scalar_t)1;
    scalar_t two          = (scalar_t)2;
    scalar_t expRange     = (scalar_t)80;
    scalar_t halfExpRange = (scalar_t)40;
    scalar_t seluScale    = (scalar_t)1.0507009873554804934193349852946;
    scalar_t seluAlpha    = (scalar_t)1.6732632423543772848170429916717;

    // Loop over elements: each thread processes up to p.loopX elements,
    // spaced blockDim.x apart.
    int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
    for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
    {
        // Load. Missing optional inputs read as 0 (dy as 1).
        scalar_t x = (scalar_t)((const T*)p.x)[xi];
        scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
        scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
        scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
        scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
        scalar_t yy = (gain != 0) ? yref / gain : 0; // Reference output with the gain divided out.
        scalar_t y = 0;

        // Apply bias: to x when G == 0, otherwise to the reference input.
        ((G == 0) ? x : xref) += b;

        // linear
        if (A == 1)
        {
            if (G == 0) y = x;
            if (G == 1) y = x;
        }

        // relu
        if (A == 2)
        {
            if (G == 0) y = (x > 0) ? x : 0;
            if (G == 1) y = (yy > 0) ? x : 0;
        }

        // lrelu
        if (A == 3)
        {
            if (G == 0) y = (x > 0) ? x : x * alpha;
            if (G == 1) y = (yy > 0) ? x : x * alpha;
        }

        // tanh
        if (A == 4)
        {
            if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
            if (G == 1) y = x * (one - yy * yy);
            if (G == 2) y = x * (one - yy * yy) * (-two * yy);
        }

        // sigmoid
        if (A == 5)
        {
            if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
            if (G == 1) y = x * yy * (one - yy);
            if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
        }

        // elu
        if (A == 6)
        {
            if (G == 0) y = (x >= 0) ? x : exp(x) - one;
            if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
        }

        // selu
        if (A == 7)
        {
            if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
            if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
        }

        // softplus
        if (A == 8)
        {
            if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
            if (G == 1) y = x * (one - exp(-yy));
            if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
        }

        // swish
        if (A == 9)
        {
            if (G == 0)
                y = (x < -expRange) ? 0 : x / (exp(-x) + one);
            else
            {
                scalar_t c = exp(xref);
                scalar_t d = c + one;
                if (G == 1)
                    y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
                else
                    y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
                // Recompute the reference output from xref so that the clamp test below can use it.
                yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
            }
        }

        // Apply gain.
        y *= gain * dy;

        // Clamp. A negative clamp disables this step. For G == 0 the output is
        // limited to [-clamp, +clamp]; otherwise the result is zeroed wherever
        // the reference output lies outside the open interval (-clamp, +clamp).
        if (clamp >= 0)
        {
            if (G == 0)
                y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
            else
                y = (yref > -clamp & yref < clamp) ? y : 0;
        }

        // Store.
        ((T*)p.y)[xi] = (T)y;
    }
}

//------------------------------------------------------------------------
// CUDA kernel selection.

// Returns the kernel instantiation matching p.act, or NULL for an unknown index.
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
{
    if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
    if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
    if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
    if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
    if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
    if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
    if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
    if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
    if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
    return NULL;
}

//------------------------------------------------------------------------
// Template specializations.

template void* choose_bias_act_kernel<double>       (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<float>        (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<c10::Half>    (const bias_act_kernel_params& p);

//------------------------------------------------------------------------

--------------------------------------------------------------------------------
/stylegan2_ada/torch_utils/ops/bias_act.h:
--------------------------------------------------------------------------------
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto.
Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /stylegan2_ada/torch_utils/ops/bias_act.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom PyTorch ops for efficient bias and activation.""" 10 | 11 | import os 12 | import warnings 13 | import numpy as np 14 | import torch 15 | import dnnlib as dnnlib 16 | import traceback 17 | 18 | from .. import custom_ops 19 | from .. 
import misc

#----------------------------------------------------------------------------

# Per-activation settings: reference function, default alpha/gain, the index of
# the matching CUDA kernel, which tensors the backward pass needs ('x', 'y' or
# none), and whether a second-order gradient is implemented.
activation_funcs = {
    'linear':   dnnlib.EasyDict(func=lambda x, **_:         x,                                          def_alpha=0,    def_gain=1,             cuda_idx=1, ref='',  has_2nd_grad=False),
    'relu':     dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.relu(x),                def_alpha=0,    def_gain=np.sqrt(2),    cuda_idx=2, ref='y', has_2nd_grad=False),
    'lrelu':    dnnlib.EasyDict(func=lambda x, alpha, **_:  torch.nn.functional.leaky_relu(x, alpha),   def_alpha=0.2,  def_gain=np.sqrt(2),    cuda_idx=3, ref='y', has_2nd_grad=False),
    'tanh':     dnnlib.EasyDict(func=lambda x, **_:         torch.tanh(x),                              def_alpha=0,    def_gain=1,             cuda_idx=4, ref='y', has_2nd_grad=True),
    'sigmoid':  dnnlib.EasyDict(func=lambda x, **_:         torch.sigmoid(x),                           def_alpha=0,    def_gain=1,             cuda_idx=5, ref='y', has_2nd_grad=True),
    'elu':      dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.elu(x),                 def_alpha=0,    def_gain=1,             cuda_idx=6, ref='y', has_2nd_grad=True),
    'selu':     dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.selu(x),                def_alpha=0,    def_gain=1,             cuda_idx=7, ref='y', has_2nd_grad=True),
    'softplus': dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.softplus(x),            def_alpha=0,    def_gain=1,             cuda_idx=8, ref='y', has_2nd_grad=True),
    'swish':    dnnlib.EasyDict(func=lambda x, **_:         torch.sigmoid(x) * x,                       def_alpha=0,    def_gain=np.sqrt(2),    cuda_idx=9, ref='x', has_2nd_grad=True),
}

#----------------------------------------------------------------------------

_inited = False
_plugin = None
_null_tensor = torch.empty([0])

def _init():
    """Build/load the CUDA plugin on first use; return True if it is available.

    The build is attempted only once. If it fails, a warning with the traceback
    is issued and every later call returns False, so callers fall back to the
    reference implementation.
    """
    global _inited, _plugin
    if not _inited:
        _inited = True
        sources = ['bias_act.cpp', 'bias_act.cu']
        sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
        try:
            _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
        except Exception: # Not a bare except: KeyboardInterrupt/SystemExit must still propagate.
            warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
    return _plugin is not None

#----------------------------------------------------------------------------

def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
    r"""Fused bias and activation function.

    Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
    and scales the result by `gain`. Each of the steps is optional. In most cases,
    the fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports first and second order gradients,
    but not third order gradients.

    Args:
        x:      Input activation tensor. Can be of any shape.
        b:      Bias vector, or `None` to disable. Must be a 1D tensor of the same type
                as `x`. The shape must be known, and it must match the dimension of `x`
                corresponding to `dim`.
        dim:    The dimension in `x` corresponding to the elements of `b`.
                The value of `dim` is ignored if `b` is not specified.
        act:    Name of the activation function to evaluate, or `"linear"` to disable.
                Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
                See `activation_funcs` for a full list. `None` is not allowed.
        alpha:  Shape parameter for the activation function, or `None` to use the default.
        gain:   Scaling factor for the output tensor, or `None` to use default.
                See `activation_funcs` for the default scaling of each activation function.
                If unsure, consider specifying 1.
        clamp:  Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
                the clamping (default).
        impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the same shape and datatype as `x`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    # The CUDA path needs a CUDA tensor and a successfully built plugin; otherwise use the reference path.
    if impl == 'cuda' and x.device.type == 'cuda' and _init():
        return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
    return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)

#----------------------------------------------------------------------------

@misc.profiled_function
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Slow reference implementation of `bias_act()` using standard PyTorch ops.
    """
    assert isinstance(x, torch.Tensor)
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1) # -1 means "no clamping".

    # Add bias.
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.ndim == 1
        assert 0 <= dim < x.ndim
        assert b.shape[0] == x.shape[dim]
        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])

    # Evaluate activation function.
    x = spec.func(x, alpha=alpha)

    # Scale by gain.
    if gain != 1:
        x = x * gain

    # Clamp.
    if clamp >= 0:
        x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
    return x

#----------------------------------------------------------------------------

_bias_act_cuda_cache = dict()

def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Fast CUDA implementation of `bias_act()` using custom ops.

    Returns a `torch.autograd.Function` subclass specialized for the given
    settings; classes are cached per `(dim, act, alpha, gain, clamp)`.
    """
    # Parse arguments.
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Lookup from cache.
    key = (dim, act, alpha, gain, clamp)
    if key in _bias_act_cuda_cache:
        return _bias_act_cuda_cache[key]

    # Forward op.
    class BiasActCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, b): # pylint: disable=arguments-differ
            ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
            x = x.contiguous(memory_format=ctx.memory_format)
            b = b.contiguous() if b is not None else _null_tensor
            y = x
            # Skip the kernel launch entirely when the op would be an identity.
            if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
                y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
            # Save only the tensors that the backward passes of this activation read.
            ctx.save_for_backward(
                x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
                b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
                y if 'y' in spec.ref else _null_tensor)
            return y

        @staticmethod
        def backward(ctx, dy): # pylint: disable=arguments-differ
            dy = dy.contiguous(memory_format=ctx.memory_format)
            x, b, y = ctx.saved_tensors
            dx = None
            db = None

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                dx = dy
                if act != 'linear' or gain != 1 or clamp >= 0:
                    dx = BiasActCudaGrad.apply(dy, x, b, y)

            if ctx.needs_input_grad[1]:
                # The bias gradient is dx summed over every dimension except `dim`.
                db = dx.sum([i for i in range(dx.ndim) if i != dim])

            return dx, db

    # Backward op.
    class BiasActCudaGrad(torch.autograd.Function):
        @staticmethod
        def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
            ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
            dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
            ctx.save_for_backward(
                dy if spec.has_2nd_grad else _null_tensor,
                x, b, y)
            return dx

        @staticmethod
        def backward(ctx, d_dx): # pylint: disable=arguments-differ
            d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
            dy, x, b, y = ctx.saved_tensors
            d_dy = None
            d_x = None
            d_b = None
            d_y = None

            if ctx.needs_input_grad[0]:
                d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)

            if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
                d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)

            if spec.has_2nd_grad and ctx.needs_input_grad[2]:
                d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])

            return d_dy, d_x, d_b, d_y

    # Add to cache.
    _bias_act_cuda_cache[key] = BiasActCuda
    return BiasActCuda

#----------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# /stylegan2_ada/torch_utils/ops/conv2d_gradfix.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.
# Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom replacement for `torch.nn.functional.conv2d` that supports
arbitrarily high order gradients with zero performance penalty."""

import warnings
import contextlib
import torch

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access

#----------------------------------------------------------------------------

enabled = False                     # Enable the custom op by setting this to true.
weight_gradients_disabled = False   # Forcefully disable computation of gradients with respect to the weights.

@contextlib.contextmanager
def no_weight_gradients():
    """Context manager that sets `weight_gradients_disabled` for the `with` body.

    The previous value of the flag is restored on exit, including when the
    body raises an exception.
    """
    global weight_gradients_disabled
    old = weight_gradients_disabled
    weight_gradients_disabled = True
    try:
        yield
    finally:
        weight_gradients_disabled = old

#----------------------------------------------------------------------------

def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Same arguments and result as `torch.nn.functional.conv2d`; routed through
    the custom op when `_should_use_custom_op(input)` allows it."""
    if _should_use_custom_op(input):
        return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
    return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)

def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    """Same arguments and result as `torch.nn.functional.conv_transpose2d`; routed
    through the custom op when `_should_use_custom_op(input)` allows it."""
    if _should_use_custom_op(input):
        return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
    return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)

#----------------------------------------------------------------------------

def _should_use_custom_op(input):
    """Return True if the custom op should handle `input`.

    Requires the module-level `enabled` flag, cuDNN being enabled, a CUDA
    tensor, and a PyTorch version matching one of the listed prefixes.
    On any other PyTorch version a warning is issued and False is returned.
    """
    assert isinstance(input, torch.Tensor)
    if (not enabled) or (not torch.backends.cudnn.enabled):
        return False
    if input.device.type != 'cuda':
        return False
    if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
        return True
    warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
    return False

def _tuple_of_ints(xs, ndim):
    """Normalize `xs` to a tuple of `ndim` ints; a scalar is repeated `ndim` times."""
    xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
    assert len(xs) == ndim
    assert all(isinstance(x, int) for x in xs)
    return xs

#----------------------------------------------------------------------------

_conv2d_gradfix_cache = dict()

def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
    # Parse arguments.
    ndim = 2
    weight_shape = tuple(weight_shape)
    stride = _tuple_of_ints(stride, ndim)
    padding = _tuple_of_ints(padding, ndim)
    output_padding = _tuple_of_ints(output_padding, ndim)
    dilation = _tuple_of_ints(dilation, ndim)

    # Lookup from cache.
    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    if key in _conv2d_gradfix_cache:
        return _conv2d_gradfix_cache[key]

    # Validate arguments.
83 | assert groups >= 1 84 | assert len(weight_shape) == ndim + 2 85 | assert all(stride[i] >= 1 for i in range(ndim)) 86 | assert all(padding[i] >= 0 for i in range(ndim)) 87 | assert all(dilation[i] >= 0 for i in range(ndim)) 88 | if not transpose: 89 | assert all(output_padding[i] == 0 for i in range(ndim)) 90 | else: # transpose 91 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) 92 | 93 | # Helpers. 94 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) 95 | def calc_output_padding(input_shape, output_shape): 96 | if transpose: 97 | return [0, 0] 98 | return [ 99 | input_shape[i + 2] 100 | - (output_shape[i + 2] - 1) * stride[i] 101 | - (1 - 2 * padding[i]) 102 | - dilation[i] * (weight_shape[i + 2] - 1) 103 | for i in range(ndim) 104 | ] 105 | 106 | # Forward & backward. 107 | class Conv2d(torch.autograd.Function): 108 | @staticmethod 109 | def forward(ctx, input, weight, bias): 110 | assert weight.shape == weight_shape 111 | if not transpose: 112 | output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 113 | else: # transpose 114 | output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) 115 | ctx.save_for_backward(input, weight) 116 | return output 117 | 118 | @staticmethod 119 | def backward(ctx, grad_output): 120 | input, weight = ctx.saved_tensors 121 | grad_input = None 122 | grad_weight = None 123 | grad_bias = None 124 | 125 | if ctx.needs_input_grad[0]: 126 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 127 | grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None) 128 | assert grad_input.shape == input.shape 129 | 130 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 131 | grad_weight = Conv2dGradWeight.apply(grad_output, 
input) 132 | assert grad_weight.shape == weight_shape 133 | 134 | if ctx.needs_input_grad[2]: 135 | grad_bias = grad_output.sum([0, 2, 3]) 136 | 137 | return grad_input, grad_weight, grad_bias 138 | 139 | # Gradient with respect to the weights. 140 | class Conv2dGradWeight(torch.autograd.Function): 141 | @staticmethod 142 | def forward(ctx, grad_output, input): 143 | op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') 144 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] 145 | grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) 146 | assert grad_weight.shape == weight_shape 147 | ctx.save_for_backward(grad_output, input) 148 | return grad_weight 149 | 150 | @staticmethod 151 | def backward(ctx, grad2_grad_weight): 152 | grad_output, input = ctx.saved_tensors 153 | grad2_grad_output = None 154 | grad2_input = None 155 | 156 | if ctx.needs_input_grad[0]: 157 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) 158 | assert grad2_grad_output.shape == grad_output.shape 159 | 160 | if ctx.needs_input_grad[1]: 161 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 162 | grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) 163 | assert grad2_input.shape == input.shape 164 | 165 | return grad2_grad_output, grad2_input 166 | 167 | _conv2d_gradfix_cache[key] = Conv2d 168 | return Conv2d 169 | 170 | #---------------------------------------------------------------------------- 171 | -------------------------------------------------------------------------------- /stylegan2_ada/torch_utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA 
#----------------------------------------------------------------------------

def _get_weight_shape(w):
    """Return the shape of `w` as a list of Python ints.

    The conversion runs under `misc.suppress_tracer_warnings()`, and the
    result is then checked against `w` with `misc.assert_shape()`.
    """
    with misc.suppress_tracer_warnings(): # this value will be treated as a constant
        shape = [int(sz) for sz in w.shape]
    misc.assert_shape(w, shape)
    return shape

#----------------------------------------------------------------------------

def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
    """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.

    Args:
        x:           Input tensor.
        w:           Weight tensor; its four dimensions are unpacked as
                     `out_channels, in_channels_per_group, kh, kw`.
        stride:      Forwarded to the underlying op (default: 1).
        padding:     Forwarded to the underlying op (default: 0).
        groups:      Forwarded to the underlying op (default: 1).
        transpose:   True = use `conv2d_gradfix.conv_transpose2d`,
                     False = use `conv2d_gradfix.conv2d` (default: False).
        flip_weight: False = flip `w` along its last two dims before use,
                     True = use `w` as-is (default: True).

    Returns:
        The result of the selected convolution.
    """
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)

    # Flip weight if requested.
    if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
        w = w.flip([2, 3])

    # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
    # 1x1 kernel + memory_format=channels_last + less than 64 channels.
    if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
        # x.stride()[1] == 1 is used here as the channels_last test.
        if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
            if out_channels <= 4 and groups == 1:
                # Very few output channels: evaluate the 1x1 conv as a matrix product.
                in_shape = x.shape
                x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
                x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
            else:
                # Otherwise run the conv in contiguous format.
                x = x.to(memory_format=torch.contiguous_format)
                w = w.to(memory_format=torch.contiguous_format)
                x = conv2d_gradfix.conv2d(x, w, groups=groups)
            # Either way, hand the result back in channels_last format.
            return x.to(memory_format=torch.channels_last)

    # Otherwise => execute using conv2d_gradfix.
    op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
    return op(x, w, stride=stride, padding=padding, groups=groups)

#----------------------------------------------------------------------------

@misc.profiled_function
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
    r"""2D convolution with optional up/downsampling.

    Padding is performed only once at the beginning, not between the operations.

    Args:
        x:              Input tensor of shape
                        `[batch_size, in_channels, in_height, in_width]`.
        w:              Weight tensor of shape
                        `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
        f:              Low-pass filter for up/downsampling. Must be prepared beforehand by
                        calling upfirdn2d.setup_filter(). None = identity (default).
        up:             Integer upsampling factor (default: 1).
        down:           Integer downsampling factor (default: 1).
        padding:        Padding with respect to the upsampled image. Can be a single number
                        or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                        (default: 0).
        groups:         Split input channels into N groups (default: 1).
        flip_weight:    False = convolution, True = correlation (default: True).
        flip_filter:    False = convolution, True = correlation (default: False).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and (x.ndim == 4)
    assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
    assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
    assert isinstance(up, int) and (up >= 1)
    assert isinstance(down, int) and (down >= 1)
    assert isinstance(groups, int) and (groups >= 1)
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
    fw, fh = _get_filter_size(f)
    px0, px1, py0, py1 = _parse_padding(padding)

    # Adjust padding to account for up/downsampling.
    if up > 1:
        px0 += (fw + up - 1) // 2
        px1 += (fw - up) // 2
        py0 += (fh + up - 1) // 2
        py1 += (fh - up) // 2
    if down > 1:
        px0 += (fw - down + 1) // 2
        px1 += (fw - down) // 2
        py0 += (fh - down + 1) // 2
        py1 += (fh - down) // 2

    # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
    if kw == 1 and kh == 1 and (down > 1 and up == 1):
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        return x

    # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
    if kw == 1 and kh == 1 and (up > 1 and down == 1):
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
        return x

    # Fast path: downsampling only => use strided convolution.
    if down > 1 and up == 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
        x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
        return x

    # Fast path: upsampling with optional downsampling => use transpose strided convolution.
    if up > 1:
        # Swap the output/input channel axes of the weight (per group when
        # groups > 1) before handing it to the transposed convolution.
        if groups == 1:
            w = w.transpose(0, 1)
        else:
            w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
            w = w.transpose(1, 2)
            w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
        px0 -= kw - 1
        px1 -= kw - up
        py0 -= kh - 1
        py1 -= kh - up
        # pxt/pyt: non-negative padding passed to the transposed conv itself;
        # whatever remains (px*+pxt, py*+pyt) is applied by upfirdn2d afterwards.
        pxt = max(min(-px0, -px1), 0)
        pyt = max(min(-py0, -py1), 0)
        x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
        if down > 1:
            x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
        return x

    # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
    if up == 1 and down == 1:
        if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
            return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)

    # Fallback: Generic reference implementation.
    x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
    x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
    if down > 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
    return x

#----------------------------------------------------------------------------
#----------------------------------------------------------------------------

def fma(a, b, c): # => a * b + c
    """Fused multiply-add: return `a * b + c` with a custom backward pass.

    The operands may have different shapes; each gradient is reduced back
    to the shape of its own operand, including 0-dim (scalar) operands.
    """
    return _FusedMultiplyAdd.apply(a, b, c)

#----------------------------------------------------------------------------

class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
    """Autograd function behind `fma()`."""

    @staticmethod
    def forward(ctx, a, b, c): # pylint: disable=arguments-differ
        out = torch.addcmul(c, a, b)
        ctx.save_for_backward(a, b)
        ctx.c_shape = c.shape   # only the shape of `c` is needed in backward
        return out

    @staticmethod
    def backward(ctx, dout): # pylint: disable=arguments-differ
        a, b = ctx.saved_tensors
        c_shape = ctx.c_shape
        da = None
        db = None
        dc = None

        if ctx.needs_input_grad[0]:
            da = _unbroadcast(dout * b, a.shape)

        if ctx.needs_input_grad[1]:
            db = _unbroadcast(dout * a, b.shape)

        if ctx.needs_input_grad[2]:
            dc = _unbroadcast(dout, c_shape)

        return da, db, dc

#----------------------------------------------------------------------------

def _unbroadcast(x, shape):
    """Reduce `x` to `shape` by summing over the dimensions that were broadcast.

    Args:
        x:     Tensor with at least `len(shape)` dimensions.
        shape: Target shape; may be empty (0-dim tensor).

    Returns:
        Tensor of shape `shape`.

    Raises:
        AssertionError: If `x` has fewer dimensions than `shape`, or the
            reduced tensor does not end up with shape `shape`.
    """
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0
    # Sum over leading dims that `shape` lacks and over dims where `shape` is 1.
    dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
    if len(dim):
        x = x.sum(dim=dim, keepdim=True)
    if extra_dims:
        # Drop the (now size-1) leading dims. Unlike `reshape(-1, ...)`, this
        # also yields a 0-dim result when `shape` is empty.
        x = x.reshape(x.shape[extra_dims:])
    assert x.shape == shape
    return x

#----------------------------------------------------------------------------
# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access

#----------------------------------------------------------------------------

enabled = False  # Enable the custom op by setting this to true.

#----------------------------------------------------------------------------

def grid_sample(input, grid):
    """Sample `input` at `grid` (bilinear, zero padding, align_corners=False).

    Falls back to `torch.nn.functional.grid_sample` unless the custom op is
    both enabled and supported on the running PyTorch version.
    """
    if not _should_use_custom_op():
        return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
    return _GridSample2dForward.apply(input, grid)

#----------------------------------------------------------------------------

def _should_use_custom_op():
    """Return True if the custom op is enabled and the PyTorch version is supported.

    Emits a warning and returns False when enabled on an unsupported version.
    """
    if not enabled:
        return False
    supported_prefixes = ('1.7.', '1.8.', '1.9')
    if torch.__version__.startswith(supported_prefixes):
        return True
    warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
    return False

#----------------------------------------------------------------------------

class _GridSample2dForward(torch.autograd.Function):
    """Forward sampling op; its gradient is delegated to `_GridSample2dBackward`."""

    @staticmethod
    def forward(ctx, input, grid):
        assert input.ndim == 4
        assert grid.ndim == 4
        sampled = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
        ctx.save_for_backward(input, grid)
        return sampled

    @staticmethod
    def backward(ctx, grad_output):
        saved_input, saved_grid = ctx.saved_tensors
        grads = _GridSample2dBackward.apply(grad_output, saved_input, saved_grid)
        return grads[0], grads[1]

#----------------------------------------------------------------------------

class _GridSample2dBackward(torch.autograd.Function):
    """Gradient of the sampling op, itself differentiable w.r.t. `grad_output`."""

    @staticmethod
    def forward(ctx, grad_output, input, grid):
        backward_op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
        grad_input, grad_grid = backward_op(grad_output, input, grid, 0, 0, False)
        ctx.save_for_backward(grid)
        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad2_grad_input, grad2_grad_grid):
        del grad2_grad_grid  # unused
        (grid,) = ctx.saved_tensors

        # Only the gradient w.r.t. `grad_output` is provided.
        if ctx.needs_input_grad[0]:
            grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
        else:
            grad2_grad_output = None

        # A gradient w.r.t. `grid` through this op is not supported.
        assert not ctx.needs_input_grad[2]
        return grad2_grad_output, None, None

#----------------------------------------------------------------------------
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments. 19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 25 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 26 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 27 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 28 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 29 | 30 | // Create output tensor. 
31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 32 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 33 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 34 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 35 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 36 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 37 | 38 | // Initialize CUDA kernel parameters. 39 | upfirdn2d_kernel_params p; 40 | p.x = x.data_ptr(); 41 | p.f = f.data_ptr(); 42 | p.y = y.data_ptr(); 43 | p.up = make_int2(upx, upy); 44 | p.down = make_int2(downx, downy); 45 | p.pad0 = make_int2(padx0, pady0); 46 | p.flip = (flip) ? 1 : 0; 47 | p.gain = gain; 48 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 49 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 50 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 51 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 52 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 53 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 54 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 55 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 56 | 57 | // Choose CUDA kernel. 58 | upfirdn2d_kernel_spec spec; 59 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 60 | { 61 | spec = choose_upfirdn2d_kernel(p); 62 | }); 63 | 64 | // Set looping options. 65 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 66 | p.loopMinor = spec.loopMinor; 67 | p.loopX = spec.loopX; 68 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 69 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 70 | 71 | // Compute grid size. 
72 | dim3 blockSize, gridSize; 73 | if (spec.tileOutW < 0) // large 74 | { 75 | blockSize = dim3(4, 32, 1); 76 | gridSize = dim3( 77 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 78 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 79 | p.launchMajor); 80 | } 81 | else // small 82 | { 83 | blockSize = dim3(256, 1, 1); 84 | gridSize = dim3( 85 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 86 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 87 | p.launchMajor); 88 | } 89 | 90 | // Launch CUDA kernel. 91 | void* args[] = {&p}; 92 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 93 | return y; 94 | } 95 | 96 | //------------------------------------------------------------------------ 97 | 98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 99 | { 100 | m.def("upfirdn2d", &upfirdn2d); 101 | } 102 | 103 | //------------------------------------------------------------------------ 104 | -------------------------------------------------------------------------------- /stylegan2_ada/torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 
13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /stylegan2_ada/torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for pickling Python code alongside other data. 10 | 11 | The pickled code is automatically imported into a separate Python module 12 | during unpickling. 
#----------------------------------------------------------------------------

_version            = 6         # internal version number
_decorators         = set()     # {decorator_class, ...}
_import_hooks       = []        # [hook_function, ...]
_module_to_src_dict = dict()    # {module: src, ...}
_src_to_module_dict = dict()    # {src: module, ...}

#----------------------------------------------------------------------------

def persistent_class(orig_class):
    r"""Class decorator that extends a given class to save its source code
    when pickled.

    Example:

        from torch_utils import persistence

        @persistence.persistent_class
        class MyNetwork(torch.nn.Module):
            def __init__(self, num_inputs, num_outputs):
                super().__init__()
                self.fc = MyLayer(num_inputs, num_outputs)
                ...

        @persistence.persistent_class
        class MyLayer(torch.nn.Module):
            ...

    When pickled, any instance of `MyNetwork` and `MyLayer` will save its
    source code alongside other internal state (e.g., parameters, buffers,
    and submodules). This way, any previously exported pickle will remain
    usable even if the class definitions have been modified or are no
    longer available.

    The decorator saves the source code of the entire Python module
    containing the decorated class. It does *not* save the source code of
    any imported modules. Thus, the imported modules must be available
    during unpickling, also including `torch_utils.persistence` itself.

    It is ok to call functions defined in the same module from the
    decorated class. However, if the decorated class depends on other
    classes defined in the same module, they must be decorated as well.
    This is illustrated in the above example in the case of `MyLayer`.

    It is also possible to employ the decorator just-in-time before
    calling the constructor. For example:

        cls = MyLayer
        if want_to_make_it_persistent:
            cls = persistence.persistent_class(cls)
        layer = cls(num_inputs, num_outputs)

    As an additional feature, the decorator also keeps track of the
    arguments that were used to construct each instance of the decorated
    class. The arguments can be queried via `obj.init_args` and
    `obj.init_kwargs`, and they are automatically pickled alongside other
    object state. A typical use case is to first unpickle a previous
    instance of a persistent class, and then upgrade it to use the latest
    version of the source code:

        with open('old_pickle.pkl', 'rb') as f:
            old_net = pickle.load(f)
        new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
        misc.copy_params_and_buffers(old_net, new_net, require_all=True)
    """
    assert isinstance(orig_class, type)
    if is_persistent(orig_class):
        return orig_class

    assert orig_class.__module__ in sys.modules
    orig_module = sys.modules[orig_class.__module__]
    orig_module_src = _module_to_src(orig_module)

    class Decorator(orig_class):
        _orig_module_src = orig_module_src
        _orig_class_name = orig_class.__name__

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Keep private copies so later mutation of the caller's objects
            # does not change what gets pickled.
            self._init_args = copy.deepcopy(args)
            self._init_kwargs = copy.deepcopy(kwargs)
            assert orig_class.__name__ in orig_module.__dict__
            _check_pickleable(self.__reduce__())

        @property
        def init_args(self):
            return copy.deepcopy(self._init_args)

        @property
        def init_kwargs(self):
            return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))

        def __reduce__(self):
            fields = list(super().__reduce__())
            fields += [None] * max(3 - len(fields), 0)
            if fields[0] is not _reconstruct_persistent_obj:
                # Replace the default reconstruction recipe with one that
                # carries the module source and the original state.
                meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
                fields[0] = _reconstruct_persistent_obj # reconstruct func
                fields[1] = (meta,) # reconstruct args
                fields[2] = None # state dict
            return tuple(fields)

    Decorator.__name__ = orig_class.__name__
    _decorators.add(Decorator)
    return Decorator

#----------------------------------------------------------------------------

def is_persistent(obj):
    r"""Test whether the given object or class is persistent, i.e.,
    whether it will save its source code when pickled.
    """
    try:
        if obj in _decorators:
            return True
    except TypeError:
        pass # unhashable objects cannot be decorator classes; fall through to the type check
    return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck

#----------------------------------------------------------------------------

def import_hook(hook):
    r"""Register an import hook that is called whenever a persistent object
    is being unpickled. A typical use case is to patch the pickled source
    code to avoid errors and inconsistencies when the API of some imported
    module has changed.

    The hook should have the following signature:

        hook(meta) -> modified meta

    `meta` is an instance of `dnnlib.EasyDict` with the following fields:

        type:       Type of the persistent object, e.g. `'class'`.
        version:    Internal version number of `torch_utils.persistence`.
        module_src  Original source code of the Python module.
        class_name: Class name in the original Python module.
        state:      Internal state of the object.

    Example:

        @persistence.import_hook
        def wreck_my_network(meta):
            if meta.class_name == 'MyNetwork':
                print('MyNetwork is being imported. I will wreck it!')
                meta.module_src = meta.module_src.replace("True", "False")
            return meta

    Returns:
        The `hook` itself, so that the decorated name stays bound to the
        function when this is used as a decorator.
    """
    assert callable(hook)
    _import_hooks.append(hook)
    return hook

#----------------------------------------------------------------------------

def _reconstruct_persistent_obj(meta):
    r"""Hook that is called internally by the `pickle` module to unpickle
    a persistent object.
    """
    meta = dnnlib.EasyDict(meta)
    meta.state = dnnlib.EasyDict(meta.state)
    for hook in _import_hooks:
        meta = hook(meta)
        assert meta is not None

    assert meta.version == _version
    module = _src_to_module(meta.module_src)

    assert meta.type == 'class'
    orig_class = module.__dict__[meta.class_name]
    decorator_class = persistent_class(orig_class)
    # Bypass __init__: the state is restored below instead.
    obj = decorator_class.__new__(decorator_class)

    setstate = getattr(obj, '__setstate__', None)
    if callable(setstate):
        setstate(meta.state) # pylint: disable=not-callable
    else:
        obj.__dict__.update(meta.state)
    return obj

#----------------------------------------------------------------------------

def _module_to_src(module):
    r"""Query the source code of a given Python module.
    """
    src = _module_to_src_dict.get(module, None)
    if src is None:
        src = inspect.getsource(module)
        _module_to_src_dict[module] = src
        _src_to_module_dict[src] = module
    return src

def _src_to_module(src):
    r"""Get or create a Python module for the given source code.

    The module is registered in `sys.modules` and in the source<->module
    caches before its source is executed. If executing the source raises,
    those registrations are rolled back and the exception is re-raised, so
    that a partially initialized module is never returned by a later call.

    NOTE: this executes `src` verbatim. When `src` originates from a pickle,
    only load pickles from trusted sources.
    """
    module = _src_to_module_dict.get(src, None)
    if module is None:
        module_name = "_imported_module_" + uuid.uuid4().hex
        module = types.ModuleType(module_name)
        sys.modules[module_name] = module
        _module_to_src_dict[module] = src
        _src_to_module_dict[src] = module
        try:
            exec(src, module.__dict__) # pylint: disable=exec-used
        except BaseException:
            # Undo the registrations made above before propagating the error.
            sys.modules.pop(module_name, None)
            _module_to_src_dict.pop(module, None)
            _src_to_module_dict.pop(src, None)
            raise
    return module

#----------------------------------------------------------------------------

def _check_pickleable(obj):
    r"""Check that the given object is pickleable, raising an exception if
    it is not. This function is expected to be considerably more efficient
    than actually pickling the object.
    """
    def recurse(obj):
        if isinstance(obj, (list, tuple, set)):
            return [recurse(x) for x in obj]
        if isinstance(obj, dict):
            return [[recurse(x), recurse(y)] for x, y in obj.items()]
        if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
            return None # Python primitive types are pickleable.
        if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
            return None # NumPy arrays and PyTorch tensors are pickleable.
        if is_persistent(obj):
            return None # Persistent objects are pickleable, by virtue of the constructor check.
        return obj
    # Only the leaves that could not be vouched for above are actually pickled.
    with io.BytesIO() as f:
        pickle.dump(recurse(obj), f)

#----------------------------------------------------------------------------
# Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Facilities for reporting and collecting training statistics across
multiple processes and devices. The interface is designed to minimize
synchronization overhead as well as the amount of boilerplate in user
code."""

import re
import numpy as np
import torch
import dnnlib

from . import misc

#----------------------------------------------------------------------------

_num_moments    = 3             # [num_scalars, sum_of_scalars, sum_of_squares]
_reduce_dtype   = torch.float32 # Data type to use for initial per-tensor reduction.
_counter_dtype  = torch.float64 # Data type to use for the internal counters.
_rank           = 0             # Rank of the current process.
_sync_device    = None          # Device to use for multiprocess communication. None = single-process.
_sync_called    = False         # Has _sync() been called yet?
_counters       = dict()        # Running counters on each device, updated by report(): name => device => torch.Tensor
_cumulative     = dict()        # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor

#----------------------------------------------------------------------------

def init_multiprocessing(rank, sync_device):
    r"""Initializes `torch_utils.training_stats` for collecting statistics
    across multiple processes.

    This function must be called after
    `torch.distributed.init_process_group()` and before `Collector.update()`.
    The call is not necessary if multi-process collection is not needed.

    Args:
        rank:           Rank of the current process.
        sync_device:    PyTorch device to use for inter-process
                        communication, or None to disable multi-process
                        collection. Typically `torch.device('cuda', rank)`.
    """
    global _rank, _sync_device
    # Changing the setup after the first synchronization is not supported.
    assert not _sync_called
    _rank = rank
    _sync_device = sync_device

#----------------------------------------------------------------------------

@misc.profiled_function
def report(name, value):
    r"""Broadcasts the given set of scalars to all interested instances of
    `Collector`, across device and process boundaries.

    This function is expected to be extremely cheap and can be safely
    called from anywhere in the training loop, loss function, or inside a
    `torch.nn.Module`.

    Warning: The current implementation expects the set of unique names to
    be consistent across processes. Please make sure that `report()` is
    called at least once for each unique name by each process, and in the
    same order. If a given process has no scalars to broadcast, it can do
    `report(name, [])` (empty list).

    Args:
        name:   Arbitrary string specifying the name of the statistic.
                Averages are accumulated separately for each unique name.
        value:  Arbitrary set of scalars. Can be a list, tuple,
                NumPy array, PyTorch tensor, or Python scalar.

    Returns:
        The same `value` that was passed in.
    """
    # Register the name even if there is nothing to accumulate, so that it
    # shows up in `Collector.names()`.
    if name not in _counters:
        _counters[name] = dict()

    elems = torch.as_tensor(value)
    if elems.numel() == 0:
        return value

    # Reduce to [count, sum, sum of squares] on the device of the input.
    elems = elems.detach().flatten().to(_reduce_dtype)
    moments = torch.stack([
        torch.ones_like(elems).sum(),
        elems.sum(),
        elems.square().sum(),
    ])
    assert moments.ndim == 1 and moments.shape[0] == _num_moments
    moments = moments.to(_counter_dtype)

    # Accumulate into the running counter kept for this device.
    device = moments.device
    if device not in _counters[name]:
        _counters[name][device] = torch.zeros_like(moments)
    _counters[name][device].add_(moments)
    return value

#----------------------------------------------------------------------------

def report0(name, value):
    r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
    but ignores any scalars provided by the other processes.
    See `report()` for further details.
    """
    # Other ranks still report an empty list so that the name is registered.
    report(name, value if _rank == 0 else [])
    return value

#----------------------------------------------------------------------------

class Collector:
    r"""Collects the scalars broadcasted by `report()` and `report0()` and
    computes their long-term averages (mean and standard deviation) over
    user-defined periods of time.

    The averages are first collected into internal counters that are not
    directly visible to the user. They are then copied to the user-visible
    state as a result of calling `update()` and can then be queried using
    `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
    internal counters for the next round, so that the user-visible state
    effectively reflects averages collected between the last two calls to
    `update()`.

    Args:
        regex:          Regular expression defining which statistics to
                        collect. The default is to collect everything.
        keep_previous:  Whether to retain the previous averages if no
                        scalars were collected on a given round
                        (default: True).
    """
    def __init__(self, regex='.*', keep_previous=True):
        self._regex = re.compile(regex)
        self._keep_previous = keep_previous
        self._cumulative = dict()
        self._moments = dict()
        # Record the current cumulative values as the baseline, then discard
        # the moments this produced so the collector starts out empty.
        self.update()
        self._moments.clear()

    def names(self):
        r"""Returns the names of all statistics broadcasted so far that
        match the regular expression specified at construction time.
        """
        return [name for name in _counters if self._regex.fullmatch(name)]

    def update(self):
        r"""Copies current values of the internal counters to the
        user-visible state and resets them for the next round.

        If `keep_previous=True` was specified at construction time, the
        operation is skipped for statistics that have received no scalars
        since the last update, retaining their previous averages.

        This method performs a number of GPU-to-CPU transfers and one
        `torch.distributed.all_reduce()`. It is intended to be called
        periodically in the main training loop, typically once every
        N training steps.
        """
        if not self._keep_previous:
            self._moments.clear()
        for name, cumulative in _sync(self.names()):
            if name not in self._cumulative:
                self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
            # What was accumulated globally since this collector's last update.
            delta = cumulative - self._cumulative[name]
            self._cumulative[name].copy_(cumulative)
            if float(delta[0]) != 0:
                self._moments[name] = delta

    def _get_delta(self, name):
        r"""Returns the raw moments that were accumulated for the given
        statistic between the last two calls to `update()`, or zero if
        no scalars were collected.
        """
        assert self._regex.fullmatch(name)
        if name not in self._moments:
            self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
        return self._moments[name]

    def num(self, name):
        r"""Returns the number of scalars that were accumulated for the given
        statistic between the last two calls to `update()`, or zero if
        no scalars were collected.
        """
        delta = self._get_delta(name)
        return int(delta[0])

    def mean(self, name):
        r"""Returns the mean of the scalars that were accumulated for the
        given statistic between the last two calls to `update()`, or NaN if
        no scalars were collected.
        """
        delta = self._get_delta(name)
        if int(delta[0]) == 0:
            return float('nan')
        return float(delta[1] / delta[0])

    def std(self, name):
        r"""Returns the standard deviation of the scalars that were
        accumulated for the given statistic between the last two calls to
        `update()`, or NaN if no scalars were collected.

        The result is always a Python `float`.
        """
        delta = self._get_delta(name)
        if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
            return float('nan')
        if int(delta[0]) == 1:
            return float(0)
        mean = float(delta[1] / delta[0])
        raw_var = float(delta[2] / delta[0])
        # Clamp at zero: rounding can make E[x^2] - E[x]^2 slightly negative.
        return float(np.sqrt(max(raw_var - np.square(mean), 0)))

    def as_dict(self):
        r"""Returns the averages accumulated between the last two calls to
        `update()` as a `dnnlib.EasyDict`. The contents are as follows:

            dnnlib.EasyDict(
                NAME = dnnlib.EasyDict(num=INT, mean=FLOAT, std=FLOAT),
                ...
            )
        """
        stats = dnnlib.EasyDict()
        for name in self.names():
            stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
        return stats

    def __getitem__(self, name):
        r"""Convenience getter.
        `collector[name]` is a synonym for `collector.mean(name)`.
        """
        return self.mean(name)

#----------------------------------------------------------------------------

def _sync(names):
    r"""Synchronize the global cumulative counters across devices and
    processes. Called internally by `Collector.update()`.

    Args:
        names: List of statistic names to synchronize. Every name must
            already be present in `_counters`.

    Returns:
        List of `(name, cumulative)` pairs in the order of `names`, where
        `cumulative` is the CPU tensor stored in `_cumulative`. An empty
        list is returned, and nothing is synchronized, if `names` is empty.
    """
    global _sync_called
    if not names:
        return []
    _sync_called = True

    # Collect deltas within current rank.
    deltas = []
    device = _sync_device if _sync_device is not None else torch.device('cpu')
    for name in names:
        delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
        for counter in _counters[name].values():
            delta.add_(counter.to(device))
            # Reset in place; `report()` keeps adding to the same tensor.
            counter.zero_()
        deltas.append(delta)
    deltas = torch.stack(deltas)

    # Sum deltas across ranks.
    if _sync_device is not None:
        torch.distributed.all_reduce(deltas)

    # Update cumulative values.
    deltas = deltas.cpu()
    for idx, name in enumerate(names):
        if name not in _cumulative:
            _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
        _cumulative[name].add_(deltas[idx])

    # Return name-value pairs.
    return [(name, _cumulative[name]) for name in names]

#----------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# /stylegan2_ada/training/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | --------------------------------------------------------------------------------