├── .gitignore ├── LICENSE ├── LICENSE-DISCRIMINATOR ├── LICENSE-LPIPS ├── LICENSE-NVIDIA ├── data ├── __init__.py ├── create_beton_file.py ├── datamodules.py └── datasets.py ├── environment.yml ├── example_confs ├── ema_vqvae.yaml ├── entropy_vqvae.yaml ├── gumbel_vqgan.yaml ├── standard_vqvae.yaml └── standard_vqvae_reinit.yaml ├── readme.md ├── setup.py └── vqvae ├── __init__.py ├── common_utils.py ├── evaluate.py ├── model.py ├── modules ├── __init__.py ├── abstract_modules │ ├── __init__.py │ ├── base_autoencoder.py │ └── base_quantizer.py ├── autoencoder.py ├── loss │ ├── __init__.py │ ├── loss.py │ ├── lpips_pytorch │ │ ├── __init__.py │ │ └── modules │ │ │ ├── __init__.py │ │ │ ├── lpips.py │ │ │ ├── networks.py │ │ │ └── utils.py │ └── stylegan2_discriminator │ │ ├── __init__.py │ │ ├── discriminator.py │ │ └── utils │ │ ├── __init__.py │ │ ├── custom_ops.py │ │ ├── dnnlib │ │ ├── __init__.py │ │ └── util.py │ │ ├── misc.py │ │ ├── ops │ │ ├── __init__.py │ │ ├── bias_act.cpp │ │ ├── bias_act.cu │ │ ├── bias_act.h │ │ ├── bias_act.py │ │ ├── conv2d_gradfix.py │ │ ├── conv2d_resample.py │ │ ├── fma.py │ │ ├── grid_sample_gradfix.py │ │ ├── upfirdn2d.cpp │ │ ├── upfirdn2d.cu │ │ ├── upfirdn2d.h │ │ └── upfirdn2d.py │ │ └── persistence.py └── vector_quantizers.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | .idea/ 10 | /vqvae.egg-info/ 11 | /data/ffcv_toy_example/ 12 | /data/toy_example/ 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Dario Serez 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /LICENSE-DISCRIMINATOR: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /LICENSE-LPIPS: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2020, Sou Uchida 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /LICENSE-NVIDIA: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, NVIDIA Corporation & affiliates. All rights reserved. 2 | 3 | 4 | NVIDIA Source Code License for StyleGAN3 5 | 6 | 7 | ======================================================================= 8 | 9 | 1. Definitions 10 | 11 | "Licensor" means any person or entity that distributes its Work. 12 | 13 | "Software" means the original work of authorship made available under 14 | this License. 
15 | 16 | "Work" means the Software and any additions to or derivative works of 17 | the Software that are made available under this License. 18 | 19 | The terms "reproduce," "reproduction," "derivative works," and 20 | "distribution" have the meaning as provided under U.S. copyright law; 21 | provided, however, that for the purposes of this License, derivative 22 | works shall not include works that remain separable from, or merely 23 | link (or bind by name) to the interfaces of, the Work. 24 | 25 | Works, including the Software, are "made available" under this License 26 | by including in or with the Work either (a) a copyright notice 27 | referencing the applicability of this License to the Work, or (b) a 28 | copy of this License. 29 | 30 | 2. License Grants 31 | 32 | 2.1 Copyright Grant. Subject to the terms and conditions of this 33 | License, each Licensor grants to you a perpetual, worldwide, 34 | non-exclusive, royalty-free, copyright license to reproduce, 35 | prepare derivative works of, publicly display, publicly perform, 36 | sublicense and distribute its Work and any resulting derivative 37 | works in any form. 38 | 39 | 3. Limitations 40 | 41 | 3.1 Redistribution. You may reproduce or distribute the Work only 42 | if (a) you do so under this License, (b) you include a complete 43 | copy of this License with your distribution, and (c) you retain 44 | without modification any copyright, patent, trademark, or 45 | attribution notices that are present in the Work. 46 | 47 | 3.2 Derivative Works. You may specify that additional or different 48 | terms apply to the use, reproduction, and distribution of your 49 | derivative works of the Work ("Your Terms") only if (a) Your Terms 50 | provide that the use limitation in Section 3.3 applies to your 51 | derivative works, and (b) you identify the specific derivative 52 | works that are subject to Your Terms. Notwithstanding Your Terms, 53 | this License (including the redistribution requirements in Section 54 | 3.1) will continue to apply to the Work itself. 55 | 56 | 3.3 Use Limitation. The Work and any derivative works thereof only 57 | may be used or intended for use non-commercially. Notwithstanding 58 | the foregoing, NVIDIA and its affiliates may use the Work and any 59 | derivative works commercially. As used herein, "non-commercially" 60 | means for research or evaluation purposes only. 61 | 62 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 63 | against any Licensor (including any claim, cross-claim or 64 | counterclaim in a lawsuit) to enforce any patents that you allege 65 | are infringed by any Work, then your rights under this License from 66 | such Licensor (including the grant in Section 2.1) will terminate 67 | immediately. 68 | 69 | 3.5 Trademarks. This License does not grant any rights to use any 70 | Licensor’s or its affiliates’ names, logos, or trademarks, except 71 | as necessary to reproduce the notices described in this License. 72 | 73 | 3.6 Termination. If you violate any term of this License, then your 74 | rights under this License (including the grant in Section 2.1) will 75 | terminate immediately. 76 | 77 | 4. Disclaimer of Warranty. 78 | 79 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 80 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 81 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 82 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 83 | THIS LICENSE. 84 | 85 | 5. Limitation of Liability. 
86 | 87 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 88 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 89 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 90 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 91 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 92 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 93 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 94 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 95 | THE POSSIBILITY OF SUCH DAMAGES. 96 | 97 | ======================================================================= -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SerezD/vqvae-vqgan-pytorch-lightning/277d909d27e55c36eb410f23684bfdf68f8f38dc/data/__init__.py -------------------------------------------------------------------------------- /data/create_beton_file.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from ffcv.fields import RGBImageField 5 | from ffcv_pl.generate_dataset import create_beton_wrapper 6 | 7 | from data.datasets import ImageDataset 8 | 9 | 10 | def get_args(): 11 | 12 | parser = argparse.ArgumentParser( 13 | description='Define an Image Dataset using ffcv for fast data loading') 14 | 15 | parser.add_argument('--max_resolution', type=int, default=256) 16 | parser.add_argument('--output_folder', type=str, required=True) 17 | parser.add_argument('--train_folder', type=str, default=None) 18 | parser.add_argument('--val_folder', type=str, default=None) 19 | parser.add_argument('--test_folder', type=str, default=None) 20 | parser.add_argument('--predict_folder', type=str, default=None) 21 | 22 | return parser.parse_args() 23 | 24 | 25 | def main(): 26 | 27 | args = get_args() 28 | 29 | if not os.path.exists(args.output_folder): 30 | os.makedirs(args.output_folder) 31 | 32 | if args.train_folder is not None: 33 | train_dataset = ImageDataset(folder=args.train_folder, image_size=args.max_resolution, ffcv=True) 34 | 35 | # https://docs.ffcv.io/api/fields.html 36 | fields = (RGBImageField(write_mode='jpg', max_resolution=args.max_resolution),) 37 | create_beton_wrapper(train_dataset, f"{args.output_folder}/train.beton", fields=fields) 38 | 39 | if args.val_folder is not None: 40 | val_dataset = ImageDataset(folder=args.val_folder, image_size=args.max_resolution, ffcv=True) 41 | 42 | # https://docs.ffcv.io/api/fields.html 43 | fields = (RGBImageField(write_mode='jpg', max_resolution=args.max_resolution),) 44 | create_beton_wrapper(val_dataset, f"{args.output_folder}/validation.beton", fields=fields) 45 | 46 | if args.test_folder is not None: 47 | test_dataset = ImageDataset(folder=args.test_folder, image_size=args.max_resolution, ffcv=True) 48 | 49 | # https://docs.ffcv.io/api/fields.html 50 | fields = (RGBImageField(write_mode='jpg', max_resolution=args.max_resolution),) 51 | create_beton_wrapper(test_dataset, f"{args.output_folder}/test.beton", fields=fields) 52 | 53 | if args.predict_folder is not None: 54 | predict_dataset = ImageDataset(folder=args.predict_folder, image_size=args.max_resolution, ffcv=True) 55 | 56 | # https://docs.ffcv.io/api/fields.html 57 | fields = (RGBImageField(write_mode='jpg', 
max_resolution=args.max_resolution),) 58 | create_beton_wrapper(predict_dataset, f"{args.output_folder}/predict.beton", fields=fields) 59 | 60 | 61 | if __name__ == '__main__': 62 | 63 | main() 64 | -------------------------------------------------------------------------------- /data/datamodules.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from torch.utils.data import DataLoader 3 | 4 | from data.datasets import ImageDataset 5 | 6 | 7 | class ImageDataModule(pl.LightningDataModule): 8 | 9 | def __init__(self, image_size: int, batch_size: int, num_workers: int, train_folder: str = None, 10 | val_folder: str = None, test_folder: str = None, predict_folder: str = None): 11 | """ 12 | :param image_size: all images in the dataset will be resized like this (squared). 13 | :param batch_size: the same batch size is specified for every loader (train, test). 14 | :param num_workers: the same number of workers is specified for every loader (train, test). 15 | :param train_folder: the folder containing train images. 16 | :param test_folder: the folder containing test images. 17 | """ 18 | 19 | super().__init__() 20 | 21 | self.train_folder = train_folder 22 | self.val_folder = val_folder 23 | self.test_folder = test_folder 24 | self.predict_folder = predict_folder 25 | 26 | self.image_size = image_size 27 | self.batch_size = batch_size 28 | self.num_workers = num_workers 29 | 30 | self.train = None 31 | self.val = None 32 | self.test = None 33 | self.predict = None 34 | 35 | def setup(self, stage: str): 36 | 37 | if stage == 'fit': 38 | if self.train_folder is not None: 39 | self.train = ImageDataset(self.train_folder, self.image_size, ffcv=False) 40 | if self.val_folder is not None: 41 | self.val = ImageDataset(self.val_folder, self.image_size, ffcv=False) 42 | elif stage == 'validate': 43 | if self.val_folder is not None: 44 | self.val = ImageDataset(self.val_folder, self.image_size, ffcv=False) 45 | elif stage == 'test': 46 | if self.test_folder is not None: 47 | self.test = ImageDataset(self.test_folder, self.image_size, ffcv=False) 48 | elif stage == 'predict': 49 | if self.predict_folder is not None: 50 | self.predict = ImageDataset(self.predict_folder, self.image_size, ffcv=False) 51 | else: 52 | pass 53 | 54 | def train_dataloader(self): 55 | if self.train is None: 56 | pass 57 | return DataLoader(self.train, batch_size=self.batch_size, num_workers=self.num_workers, 58 | pin_memory=True, shuffle=True, drop_last=True) 59 | 60 | def val_dataloader(self): 61 | if self.val is None: 62 | pass 63 | return DataLoader(self.val, batch_size=self.batch_size, num_workers=self.num_workers, 64 | pin_memory=True, shuffle=False, drop_last=False) 65 | 66 | def test_dataloader(self): 67 | if self.test is None: 68 | pass 69 | return DataLoader(self.test, batch_size=self.batch_size, num_workers=self.num_workers, 70 | pin_memory=True, shuffle=False, drop_last=False) 71 | 72 | def predict_dataloader(self): 73 | if self.predict is None: 74 | pass 75 | return DataLoader(self.predict, batch_size=self.batch_size, num_workers=self.num_workers, 76 | pin_memory=True, shuffle=False, drop_last=False) 77 | -------------------------------------------------------------------------------- /data/datasets.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from torchvision.transforms import Compose, ToTensor, Resize 3 | from torch.utils.data import Dataset 4 | 5 | from PIL import Image 6 | 7 | 8 
| class ImageDataset(Dataset): 9 | 10 | def __init__(self, folder: str, image_size: int, ffcv: bool = False): 11 | 12 | self.samples = sorted(list(pathlib.Path(folder).rglob('*.png')) + list(pathlib.Path(folder).rglob('*.jpg')) + 13 | list(pathlib.Path(folder).rglob('*.bmp')) + list(pathlib.Path(folder).rglob('*.JPEG'))) 14 | self.ffcv = ffcv 15 | 16 | self.transforms = Compose([ToTensor(), Resize((image_size, image_size), antialias=True)]) 17 | 18 | def __len__(self): 19 | return len(self.samples) 20 | 21 | def __getitem__(self, idx): 22 | 23 | # path to string 24 | image_path = self.samples[idx].absolute().as_posix() 25 | 26 | image = Image.open(image_path).convert('RGB') 27 | 28 | return (image, ) if self.ffcv else self.transforms(image) 29 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: vqvae 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - nvidia 6 | - defaults 7 | dependencies: 8 | - python<3.11 9 | - pytorch-cuda 10 | - torchaudio>=2.0.1 11 | - pytorch>=2.0.0 12 | - torchvision>=0.15.1 13 | - cupy 14 | - pkg-config 15 | - libjpeg-turbo>=2.1.4 16 | - opencv 17 | - numba==0.56.2 18 | - pytorch-lightning>=2.0.0 19 | - einops 20 | - kornia 21 | - cudatoolkit-dev 22 | - wandb 23 | - scipy 24 | - ninja 25 | - matplotlib 26 | - pip 27 | - pip: 28 | - ffcv>=1.0.0 29 | - ffcv_pl 30 | - lightning-utilities 31 | - torch-fidelity 32 | - scheduling_utils 33 | -------------------------------------------------------------------------------- /example_confs/ema_vqvae.yaml: -------------------------------------------------------------------------------- 1 | image_size: 256 2 | 3 | # will encode to 16x16 resolution (256 / 2^4) 4 | autoencoder: 5 | channels: 128 6 | num_res_blocks: 2 7 | channel_multipliers: 8 | - 1 9 | - 2 10 | - 2 11 | - 4 12 | 13 | quantizer: 14 | num_embeddings: 4096 15 | embedding_dim: 256 16 | type: ema 17 | params: 18 | commitment_cost: 0.25 19 | decay: 0.95 20 | epsilon: 1e-5 21 | reinit_every_n_epochs: # may be useful in Standard or EMA modes 22 | 23 | training: 24 | cumulative_bs: 256 25 | base_lr: 1e-4 # refers to LR for cumulative_bs = 256. Will scale automatically if bs is increased/reduced 26 | betas: 27 | - 0.0 28 | - 0.99 29 | eps: 1e-8 30 | weight_decay: 1e-4 31 | decay_epochs: 250 32 | max_epochs: 300 33 | -------------------------------------------------------------------------------- /example_confs/entropy_vqvae.yaml: -------------------------------------------------------------------------------- 1 | image_size: 256 2 | 3 | # will encode to 16x16 resolution (256 / 2^4) 4 | autoencoder: 5 | channels: 128 6 | num_res_blocks: 2 7 | channel_multipliers: 8 | - 1 9 | - 2 10 | - 2 11 | - 4 12 | 13 | quantizer: 14 | num_embeddings: 1024 15 | embedding_dim: 256 16 | type: entropy 17 | params: 18 | ent_loss_ratio: 0.1 19 | ent_temperature: 0.01 20 | ent_loss_type: softmax 21 | commitment_cost: 0.25 22 | reinit_every_n_epochs: # may be useful in Standard or EMA modes 23 | 24 | training: 25 | cumulative_bs: 256 26 | base_lr: 1e-4 # refers to LR for cumulative_bs = 256. 
Will scale automatically if bs is increased/reduced 27 | betas: 28 | - 0.0 29 | - 0.99 30 | eps: 1e-8 31 | weight_decay: 1e-4 32 | decay_epochs: 250 33 | max_epochs: 300 34 | -------------------------------------------------------------------------------- /example_confs/gumbel_vqgan.yaml: -------------------------------------------------------------------------------- 1 | image_size: 256 2 | 3 | # will encode to 16x16 resolution (256 / 2**4) 4 | autoencoder: 5 | channels: 128 6 | num_res_blocks: 2 7 | channel_multipliers: 8 | - 1 9 | - 2 10 | - 2 11 | - 4 12 | 13 | quantizer: 14 | num_embeddings: 1024 # codebook dimension 15 | embedding_dim: 256 # single vector dimension 16 | type: gumbel # one of ['standard', 'ema', 'gumbel', 'entropy'] 17 | params: # Taken from DALL-E paper: https://arxiv.org/pdf/2102.12092.pdf - appendix A.2 18 | straight_through: False 19 | temp: 1.0 20 | kl_cost: 0.00859375 # 6.6 / [(256 * 256 * 3) / (16 * 16 * 1)] scale the initial value of 6.6 by the compression factor (768 in our case) 21 | kl_warmup_epochs: 0.48 # 5000 / 3M total updates in DALL-E = 0.0016. In this case -> 300 epochs * 0.0016 = 0.48 22 | temp_decay_epochs: 15 # 150000 / 3M total updates in DALL-E = 0.05. In this case -> 300 epochs * 0.05 = 15 23 | temp_final: 0.0625 # 1./16. 24 | reinit_every_n_epochs: # may be useful in Standard or EMA modes 25 | 26 | loss: 27 | l1_weight: 0.8 28 | l2_weight: 0.2 29 | perc_weight: 1.0 30 | adversarial_params: 31 | start_epoch: 100 # if None - don't use Discriminator (for ablation purposes) 32 | loss_type: non-saturating 33 | g_weight: 0.1 # 0.8 if use adaptive else 0.1 (maskgit) 34 | use_adaptive: False # compute g weight adaptively according to grad of last layers (then scaled by g_weight) 35 | r1_reg_weight: 10. # if None don't use r1 regularization (ablation purposes) 36 | r1_reg_every: 16 # steps 37 | 38 | training: 39 | cumulative_bs: 256 40 | base_lr: 1e-4 # refers to LR for cumulative_bs = 256. Will scale automatically if bs is increased/reduced 41 | betas: 42 | - 0.0 43 | - 0.99 44 | eps: 1e-8 45 | weight_decay: 1e-4 46 | decay_epochs: 250 47 | max_epochs: 300 48 | -------------------------------------------------------------------------------- /example_confs/standard_vqvae.yaml: -------------------------------------------------------------------------------- 1 | image_size: 256 2 | 3 | # will encode to 16x16 resolution (256 / 2^4) 4 | autoencoder: 5 | channels: 128 6 | num_res_blocks: 2 7 | channel_multipliers: 8 | - 1 9 | - 2 10 | - 2 11 | - 4 12 | 13 | quantizer: 14 | num_embeddings: 1024 15 | embedding_dim: 256 16 | type: standard 17 | params: 18 | commitment_cost: 0.25 19 | reinit_every_n_epochs: # may be useful in Standard or EMA modes 20 | 21 | training: 22 | cumulative_bs: 256 23 | base_lr: 1e-4 # refers to LR for cumulative_bs = 256. 
Will scale automatically if bs is increased/reduced 24 | betas: 25 | - 0.0 26 | - 0.99 27 | eps: 1e-8 28 | weight_decay: 1e-4 29 | decay_epochs: 250 30 | max_epochs: 300 31 | -------------------------------------------------------------------------------- /example_confs/standard_vqvae_reinit.yaml: -------------------------------------------------------------------------------- 1 | image_size: 256 2 | 3 | # will encode to 16x16 resolution (256 / 2^4) 4 | autoencoder: 5 | channels: 128 6 | num_res_blocks: 2 7 | channel_multipliers: 8 | - 1 9 | - 2 10 | - 2 11 | - 4 12 | 13 | quantizer: 14 | num_embeddings: 1024 15 | embedding_dim: 256 16 | type: standard 17 | params: 18 | commitment_cost: 0.25 19 | reinit_every_n_epochs: 10 # may be useful in Standard or EMA modes 20 | 21 | training: 22 | cumulative_bs: 256 23 | base_lr: 1e-4 # refers to LR for cumulative_bs = 256. Will scale automatically if bs is increased/reduced 24 | betas: 25 | - 0.0 26 | - 0.99 27 | eps: 1e-8 28 | weight_decay: 1e-4 29 | decay_epochs: 250 30 | max_epochs: 300 31 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # VQ-VAE/GAN Pytorch Lightning 2 | 3 | Pytorch lightning implementation of both VQVAE/VQGAN, with different quantization algorithms. 4 | Uses [FFCV](https://github.com/libffcv/ffcv) for fast data loading and [WandB](https://github.com/wandb/wandb) 5 | for logging. 6 | 7 | ### Acknowledgments and Citations 8 | 9 | Original vqvae paper: https://arxiv.org/abs/1711.00937 10 | Original vqgan paper: https://arxiv.org/abs/2012.09841 11 | 12 | Original vqvae code: https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py 13 | Original vqgan code: https://github.com/CompVis/taming-transformers 14 | 15 | Some architectural improvements are taken by: 16 | MaskGit: 17 | - paper https://arxiv.org/abs/2202.04200 18 | - code https://github.com/google-research/maskgit 19 | 20 | Improved VQGAN: https://arxiv.org/abs/2110.04627 21 | 22 | Perceptual Loss part cloned from: https://github.com/S-aiueo32/lpips-pytorch/tree/master 23 | 24 | Discriminator cloned from: https://github.com/NVlabs/stylegan2-ada-pytorch 25 | Discriminator Losses (hinge / non-saturating): https://github.com/google-research/maskgit 26 | 27 | Quantization Algorithms: 28 | - Standard and EMA update: Original VQVAE paper. 29 | - Gumbel Softmax: code taken from https://github.com/karpathy/deep-vector-quantization, 30 | parameters from DALL-E paper: https://arxiv.org/abs/2102.12092. Also check: https://arxiv.org/abs/1611.01144 31 | for a theoretical understanding. 32 | - "Entropy" Quantizer: code taken from https://github.com/google-research/maskgit 33 | 34 | Fast Data Loading: 35 | - FFCV: https://github.com/libffcv/ffcv 36 | - FFCV_PL: https://github.com/SerezD/ffcv_pytorch_lightning 37 | 38 | ### Installation 39 | 40 | For fast solving, I suggest to use libmamba: 41 | https://www.anaconda.com/blog/a-faster-conda-for-a-growing-community 42 | 43 | *Note: Check the `pytorch-cuda` version in `environment.yml` to ensure it is compatible with your cuda version.* 44 | 45 | ``` 46 | # Dependencies Install 47 | conda env create --file environment.yml 48 | conda activate vqvae 49 | 50 | # package install (after cloning) 51 | pip install . 52 | ``` 53 | 54 | #### Stylegan custom ops 55 | 56 | StyleGan discriminator uses custom cuda operations, written by the NVIDIA team to speed up training. 
57 | This requires installing the NVIDIA-CUDA TOOLKIT: https://github.com/NVlabs/stylegan3/blob/main/docs/troubleshooting.md 58 | 59 | In this repo, instead of the NVIDIA-CUDA TOOLKIT, the `environment.yml` installs: https://anaconda.org/conda-forge/cudatoolkit-dev 60 | I found this to be an easier option, and everything appears to work fine. 61 | 62 | ### Datasets and DataLoaders 63 | 64 | This repository allows for both fast (`FFCV`) and standard (`pytorch`) data loading. 65 | 66 | In each case, your dataset can be composed of images in `.png .jpg .bmp .JPEG` formats. 67 | The dataset structure must be like the following: 68 | ``` 69 | 🗂 path/to/dataset/ 70 | 📂 train/ 71 | ┣ 000.jpeg 72 | ┣ 001.jpg 73 | ┗ 002.png 74 | 📂 validation/ 75 | ┣ 003.jpeg 76 | ┣ 004.bmp 77 | ┗ 005.png 78 | 📂 test/ 79 | ┣ 006.jpeg 80 | ┣ 007.jpg 81 | ┗ 008.bmp 82 | ``` 83 | 84 | If you want to use `FFCV`, you must first create the `.beton` files. For this, you can use the `create_beton_file.py` script 85 | in the `/data` directory. 86 | 87 | ``` 88 | # example 89 | # creates 2 beton files (one for validation and one for training) 90 | # in the /home/datasets/examples/beton_dataset directory. 91 | # the max resolution of the preprocessed images will be 256x256 92 | 93 | python ./data/create_beton_file.py --max_resolution 256 \ 94 | --output_folder "/home/datasets/examples/beton_dataset" \ 95 | --train_folder "/home/datasets/examples/train" \ 96 | --val_folder "/home/datasets/examples/validation" 97 | ``` 98 | 99 | For more information on fast loading, check: 100 | - FFCV: https://github.com/libffcv/ffcv 101 | - FFCV_PL: https://github.com/SerezD/ffcv_pytorch_lightning 102 | 103 | 104 | ### Configuration Files 105 | 106 | The `.yaml` configuration file provides all the details on the type of autoencoder that 107 | you want to train (check the folder "./example_confs"). 108 | 109 | ### Training 110 | 111 | Once the dataset and configuration file are created, run the training script like: 112 | ``` 113 | python ./vqvae/train.py --params_file "./example_confs/standard_vqvae_cb1024.yaml" \ 114 | --dataloader ffcv \ # uses ffcv data-loader 115 | --dataset_path "/home/datasets/examples/" \ # contains train/validation .beton file 116 | --save_path "./runs/" \ 117 | --run_name vqvae_standard_quantization \ 118 | --seed 1234 \ # fix seed for reproducibility 119 | --logging \ # will log results to wandb 120 | --workers 8 121 | ``` 122 | 123 | ### Evaluation 124 | 125 | To evaluate a pre-trained model, run: 126 | ``` 127 | python ./vqvae/evaluate.py --params_file "./example_confs/standard_vqvae_cb1024.yaml" \ # config of pretrained model 128 | --dataloader ffcv \ # uses ffcv data-loader 129 | --dataset_path "/home/datasets/examples/" \ # contains test.beton file 130 | --batch_size 64 \ # evaluation is done on a single gpu 131 | --seed 1234 \ # fix seed for reproducibility 132 | --loading_path "/home/runs/standard_vqvae_cb1024/last.ckpt" \ # checkpoint file 133 | --workers 8 134 | ``` 135 | 136 | The evaluation process is based on the `torchmetrics` library (https://lightning.ai/docs/torchmetrics/stable/). For each run, 137 | the computed measures are L2, PSNR, SSIM and rFID for reconstruction, plus Perplexity and Codebook usage for quantization (all computed on the whole test set). 138 | 139 | ### Attempts to reproduce the original VQGAN results on _ImageNet-1K_ 140 | 141 | Reproduction is hard, mainly due to the high compression rate (256x256 to 16x16) and the relatively small 142 | codebook size (1024 indices).
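The following back-of-the-envelope arithmetic makes that compression rate concrete. It is derived from the example configs (4 channel multipliers → 4 downsampling stages, 1024 codebook entries) and is illustrative only, not an official figure from the papers:

```
# 4 downsampling stages: 256 / 2**4 = 16, so each image becomes 16 * 16 = 256 tokens
# bits per token with a 1024-entry codebook: log2(1024) = 10
# bits per encoded image: 256 * 10 = 2560, versus 256 * 256 * 3 * 8 = 1572864 bits of raw RGB
# "compression factor" used in the gumbel config: (256 * 256 * 3) / (16 * 16 * 1) = 768
```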
143 | 144 | The pretrained models and configuration files used can be downloaded at 145 | [https://huggingface.co/SerezD/vqvae-vqgan-pytorch-lightning](https://huggingface.co/SerezD/vqvae-vqgan-pytorch-lightning) 146 | 147 | 148 | | Run Name | Codebook Usage ↑ | Perplexity ↑ | L2 ↓ | SSIM ↑ | PSNR ↑ | rFID ↓ | # (trainable) params | 149 | |-------------------------------|-----------------:|-------------:|--------|--------|--------|--------|---------------------:| 150 | | original VQGAN (Esser et al.) | - | - | - | - | - | 7.94 | - | 151 | | Maskgit VQGAN (Chang et al.) | - | - | - | - | - | 2.28 | - | 152 | | Gumbel Reproduction | 99.61 % | 892.00 | 0.0075 | 0.61 | 21.23 | 6.30 | 72.5 M | 153 | | Entropy Reproduction | 99.70 % | 896.78 | 0.0082 | 0.62 | 20.82 | 6.17 | 71.1 M | 154 | 155 | 156 | _Note:_ NVIDIA A100 GPUs with Tensor Cores were used for training. 157 | 158 | ### Details on Quantization Algorithms 159 | 160 | Classic and EMA VQ-VAEs are known to encounter codebook-collapse issues, where only a subset of the codebook indices 161 | is used. See for example: _Theory and Experiments on 162 | Vector Quantized Autoencoders_ (https://arxiv.org/pdf/1805.11063.pdf) 163 | 164 | To avoid collapse, some solutions have been proposed (and are implemented in this repo): 165 | 1. Re-initialize the unused codebook indices every _n_ epochs. This can be applied with standard 166 | or EMA Vector Quantization; it is not needed 167 | in the Gumbel Softmax and Entropy Quantization algorithms. 168 | 2. Change the Quantization algorithm entirely, adding a regularization term (Gumbel, Entropy) that increases the entropy 169 | of the codebook distribution. 170 | 171 | ### Details on the Discriminator part 172 | 173 | In general, it is better to wait as long as possible before the Discriminator kicks in. 174 | Check these issues in the original VQGAN repo: 175 | - https://github.com/CompVis/taming-transformers/issues/31 176 | - https://github.com/CompVis/taming-transformers/issues/61 177 | - https://github.com/CompVis/taming-transformers/issues/93 178 | 179 | In the reproduction, the Discriminator starts only after 100 epochs. Training then continues for as long as possible: at a certain 180 | point, the loss collapses (typical behavior in GANs). 181 | 182 | I found that both R1 regularization and the adaptive generator weight may help in preventing collapse.
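### Loading a pretrained checkpoint in code

The sketch below is a minimal example of loading a downloaded checkpoint outside of the `Trainer`, pieced together from `vqvae/evaluate.py` and the `BaseVQVAE` helpers. The checkpoint path and the random batch are placeholders, and whether `get_tokens` expects raw `[0, 1]` images or inputs pre-normalized with `preprocess_batch` is decided in `model.py` (not shown in this section), so treat the pre-processing as an assumption to verify.

```
import torch

from vqvae.common_utils import get_model_conf
from vqvae.model import VQVAE

# same config parsing and loading call used in vqvae/evaluate.py
conf = get_model_conf('./example_confs/gumbel_vqgan.yaml')
model = VQVAE.load_from_checkpoint('/path/to/last.ckpt', map_location='cpu', strict=False,
                                   image_size=conf['image_size'], ae_conf=conf['autoencoder'],
                                   q_conf=conf['quantizer'], l_conf=None, t_conf=None,
                                   init_cb=False, load_loss=False).eval()

# placeholder batch of images in [0, 1], shape (B, 3, H, W)
images = torch.rand(4, 3, conf['image_size'], conf['image_size'])

with torch.no_grad():
    tokens = model.get_tokens(images)               # (B, S) codebook indices
    recons = model.reconstruct_from_tokens(tokens)  # (B, 3, H, W) reconstructions
```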
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name='vqvae', 4 | py_modules=["vqvae", "data"] 5 | ) 6 | -------------------------------------------------------------------------------- /vqvae/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SerezD/vqvae-vqgan-pytorch-lightning/277d909d27e55c36eb410f23684bfdf68f8f38dc/vqvae/__init__.py -------------------------------------------------------------------------------- /vqvae/common_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import yaml 4 | 5 | from ffcv.fields.rgb_image import CenterCropRGBImageDecoder 6 | from ffcv.loader import OrderOption 7 | from ffcv.transforms import ToTensor, ToTorchImage 8 | 9 | from ffcv_pl.data_loading import FFCVDataModule 10 | from ffcv_pl.ffcv_utils.augmentations import DivideImage255 11 | from ffcv_pl.ffcv_utils.utils import FFCVPipelineManager 12 | 13 | from data.datamodules import ImageDataModule 14 | 15 | 16 | def set_matmul_precision(): 17 | """ 18 | If using Ampere Gpus enable using tensor cores. 19 | Don't know exactly which other devices can benefit from this, but torch should throw a warning in case. 20 | Docs: https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html 21 | """ 22 | 23 | gpu_cores = os.popen('nvidia-smi -L').readlines()[0] 24 | 25 | if 'A100' in gpu_cores: 26 | torch.set_float32_matmul_precision('high') 27 | print('[INFO] set matmul precision "high"') 28 | 29 | 30 | def get_model_conf(filepath: str): 31 | # load params 32 | with open(filepath, 'r', encoding='utf-8') as stream: 33 | params = yaml.safe_load(stream) 34 | 35 | return params 36 | 37 | 38 | def get_datamodule(loader_type: str, dirpath: str, image_size: int, batch_size: int, workers: int, seed: int, 39 | is_dist: bool, mode: str = 'train'): 40 | 41 | if not os.path.isdir(dirpath): 42 | raise FileNotFoundError(f"dataset path not found: {dirpath}") 43 | 44 | else: 45 | 46 | if loader_type == 'standard': 47 | 48 | if mode == 'train': 49 | train_folder = f'{dirpath}train/' 50 | val_folder = f'{dirpath}validation/' 51 | return ImageDataModule(image_size, batch_size, workers, train_folder, val_folder) 52 | else: 53 | test_folder = f'{dirpath}test/' 54 | return ImageDataModule(image_size, batch_size, workers, test_folder=test_folder) 55 | 56 | elif loader_type == 'ffcv': 57 | 58 | if mode == 'train': 59 | 60 | train_manager = FFCVPipelineManager( 61 | file_path=f'{dirpath}train.beton', 62 | pipeline_transforms=[ 63 | [ 64 | CenterCropRGBImageDecoder((image_size, image_size), ratio=1), 65 | ToTensor(), 66 | ToTorchImage(), 67 | DivideImage255(dtype=torch.float16) 68 | ] 69 | ], 70 | ordering=OrderOption.RANDOM 71 | ) 72 | 73 | val_manager = FFCVPipelineManager( 74 | file_path=f'{dirpath}validation.beton', 75 | pipeline_transforms=[ 76 | [ 77 | CenterCropRGBImageDecoder((image_size, image_size), ratio=1), 78 | ToTensor(), 79 | ToTorchImage(), 80 | DivideImage255(dtype=torch.float16) 81 | ] 82 | ] 83 | ) 84 | 85 | return FFCVDataModule(batch_size, workers, is_dist, train_manager, val_manager, seed=seed) 86 | 87 | else: 88 | test_manager = FFCVPipelineManager( 89 | file_path=f'{dirpath}test.beton', 90 | pipeline_transforms=[ 91 | [ 92 | CenterCropRGBImageDecoder((image_size, image_size), 
ratio=1), 93 | ToTensor(), 94 | ToTorchImage(), 95 | DivideImage255(dtype=torch.float16) 96 | ] 97 | ] 98 | ) 99 | 100 | return FFCVDataModule(batch_size, workers, is_dist, test_manager=test_manager, seed=seed) 101 | 102 | else: 103 | raise ValueError(f"loader type not recognized: {loader_type}") 104 | -------------------------------------------------------------------------------- /vqvae/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pytorch_lightning as pl 3 | 4 | from vqvae.common_utils import get_model_conf, get_datamodule 5 | from vqvae.model import VQVAE 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser() 10 | 11 | parser.add_argument('--params_file', type=str, required=True, help='path to yaml file with model params') 12 | parser.add_argument('--dataloader', type=str, choices=['standard', 'ffcv'], default='standard', 13 | help='defines what type of dataloader to use.') 14 | parser.add_argument('--dataset_path', type=str, required=True, 15 | help='path to a dataset folder containing a test sub-folder or a ' 16 | 'test.beton file, depending on the dataloader type.') 17 | parser.add_argument('--batch_size', type=int, required=True, 18 | help='evaluation runs on a single gpu, set the batch size.') 19 | parser.add_argument('--seed', type=int, required=True, help='global random seed for reproducibility') 20 | parser.add_argument('--loading_path', type=str, required=True, 21 | help='path to checkpoint to load') 22 | parser.add_argument('--workers', type=int, help='num of parallel workers', default=1) 23 | 24 | return parser.parse_args() 25 | 26 | 27 | def main(): 28 | 29 | args = parse_args() 30 | conf = get_model_conf(args.params_file) 31 | 32 | # configuration params (assumes some env variables in case of multi-node setup) 33 | workers = int(args.workers) 34 | seed = int(args.seed) 35 | 36 | batch_size = args.batch_size 37 | 38 | pl.seed_everything(seed, workers=True) 39 | 40 | load_checkpoint_path = args.loading_path 41 | 42 | # model params 43 | image_size = int(conf['image_size']) 44 | ae_conf = conf['autoencoder'] 45 | q_conf = conf['quantizer'] 46 | 47 | # get model 48 | model = VQVAE.load_from_checkpoint(load_checkpoint_path, strict=False, image_size=image_size, ae_conf=ae_conf, 49 | q_conf=q_conf, l_conf=None, t_conf=None, init_cb=False, load_loss=False) 50 | 51 | # data loading (standard pytorch lightning or ffcv) 52 | datamodule = get_datamodule(args.dataloader, args.dataset_path, image_size, batch_size, 53 | workers, seed, is_dist=False, mode='test') 54 | 55 | # trainer 56 | trainer = pl.Trainer(strategy='ddp', accelerator='gpu', devices=1, precision='16-mixed', deterministic=True) 57 | print(f"[INFO] workers: {workers}") 58 | print(f"[INFO] batch size: {batch_size}") 59 | 60 | trainer.test(model=model, datamodule=datamodule) 61 | 62 | 63 | if __name__ == '__main__': 64 | 65 | main() 66 | -------------------------------------------------------------------------------- /vqvae/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SerezD/vqvae-vqgan-pytorch-lightning/277d909d27e55c36eb410f23684bfdf68f8f38dc/vqvae/modules/__init__.py -------------------------------------------------------------------------------- /vqvae/modules/abstract_modules/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/SerezD/vqvae-vqgan-pytorch-lightning/277d909d27e55c36eb410f23684bfdf68f8f38dc/vqvae/modules/abstract_modules/__init__.py -------------------------------------------------------------------------------- /vqvae/modules/abstract_modules/base_autoencoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import torch 3 | from kornia.augmentation import AugmentationSequential, Denormalize, Normalize, RandomHorizontalFlip, RandomResizedCrop 4 | 5 | 6 | class BaseVQVAE(ABC): 7 | """ 8 | Defines the methods that can be used to work with indices in two stage models, 9 | as well as preprocessing 10 | """ 11 | 12 | def __init__(self, image_size: int): 13 | 14 | super().__init__() 15 | 16 | self.image_size = image_size 17 | self.preprocess = Normalize(mean=torch.tensor((0.5, 0.5, 0.5)), std=torch.tensor((0.5, 0.5, 0.5))) 18 | self.postprocess = Denormalize(mean=torch.tensor((0.5, 0.5, 0.5)), std=torch.tensor((0.5, 0.5, 0.5))) 19 | 20 | self.training_augmentations = AugmentationSequential( 21 | RandomResizedCrop((image_size, image_size), scale=(0.7, 1.0), ratio=(1.0, 1.0)), 22 | RandomHorizontalFlip(), same_on_batch=False) 23 | 24 | # init values for child classes (may never be implemented) 25 | self.scheduler = None 26 | 27 | # codebook usage counts for re-init (train) or logging (validation) (may never be used) 28 | self.train_epoch_usage_count = None 29 | self.val_epoch_usage_count = None 30 | 31 | @torch.no_grad() 32 | def preprocess_batch(self, images: torch.Tensor, training: bool = False): 33 | """ 34 | :param images: batch of float32 tensors B, C, H, W assumed in range 0__1. 35 | :param training: if True, additionally applies self.training_augmentations 36 | (Random HFlip and Random Resized Crop) 37 | :return normalized images (-1., 1.) of shape B, C, H, W -> ready for the forward pass. 38 | """ 39 | 40 | # ensure 0 _ 1 values (no effect if correctly loaded) 41 | images = torch.clamp(images, 0., 1.) 42 | 43 | # training data augmentation 44 | if training: 45 | images = self.training_augmentations(images) 46 | 47 | # normalize and return 48 | images = self.preprocess(images) 49 | 50 | return images 51 | 52 | @torch.no_grad() 53 | def preprocess_visualization(self, images: torch.Tensor): 54 | """ 55 | Process images by de-normalizing back to range 0_1 56 | :param images: (B C H W) output of the autoencoder in range (-1., 1.) 
57 | :return denormalized images in range 0__1 58 | """ 59 | images = self.postprocess(images) 60 | images = torch.clip(images, 0, 1) # if mean,std are correct, should have no effect 61 | return images 62 | 63 | @abstractmethod 64 | def get_tokens(self, images: torch.Tensor) -> torch.IntTensor: 65 | """ 66 | :param images: B, 3, H, W 67 | :return B, S batch of codebook indices 68 | """ 69 | pass 70 | 71 | @abstractmethod 72 | def quantize(self, images: torch.Tensor) -> torch.Tensor: 73 | """ 74 | :param images: B, 3, H, W 75 | :return B, S, D batch of quantized 76 | """ 77 | pass 78 | 79 | @abstractmethod 80 | def reconstruct(self, images: torch.Tensor) -> torch.Tensor: 81 | """ 82 | :param images: B, 3, H, W 83 | :return reconstructions (B, 3, H, W) 84 | """ 85 | pass 86 | 87 | @abstractmethod 88 | def reconstruct_from_tokens(self, tokens: torch.IntTensor) -> torch.Tensor: 89 | """ 90 | :param tokens: B, S where S is the sequence len 91 | :return (B, 3, H, W) reconstructed images 92 | """ 93 | pass 94 | -------------------------------------------------------------------------------- /vqvae/modules/abstract_modules/base_quantizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from abc import ABC, abstractmethod 4 | 5 | 6 | class BaseVectorQuantizer(ABC, nn.Module): 7 | 8 | def __init__(self, num_embeddings: int, embedding_dim: int): 9 | 10 | """ 11 | :param num_embeddings: size of the latent dictionary (num of embedding vectors). 12 | :param embedding_dim: size of a single tensor in dict. 13 | """ 14 | 15 | super().__init__() 16 | 17 | self.num_embeddings = num_embeddings 18 | self.embedding_dim = embedding_dim 19 | 20 | # create the codebook of the desired size 21 | self.codebook = nn.Embedding(self.num_embeddings, self.embedding_dim) 22 | 23 | # wu/decay init (may never be used) 24 | self.kl_warmup = None 25 | self.temp_decay = None 26 | 27 | def init_codebook(self) -> None: 28 | """ 29 | uniform initialization of the codebook 30 | """ 31 | nn.init.uniform_(self.codebook.weight, -1 / self.num_embeddings, 1 / self.num_embeddings) 32 | 33 | @abstractmethod 34 | def forward(self, x: torch.Tensor) -> (torch.Tensor, torch.IntTensor, float): 35 | """ 36 | :param x: tensors (output of the Encoder - B,D,H,W). 37 | :return quantized_x (B, D, H, W), detached codes (B, H*W), latent_loss 38 | """ 39 | pass 40 | 41 | @abstractmethod 42 | def vec_to_codes(self, x: torch.Tensor) -> torch.IntTensor: 43 | """ 44 | :param x: tensors (output of the Encoder - B,D,H,W). 45 | :return flat codebook indices (B, H * W) 46 | """ 47 | pass 48 | 49 | @torch.no_grad() 50 | def get_codebook(self) -> torch.nn.Embedding: 51 | return self.codebook.weight 52 | 53 | @torch.no_grad() 54 | def codes_to_vec(self, codes: torch.IntTensor) -> torch.Tensor: 55 | """ 56 | :param codes: int tensors to decode (B, N). 57 | :return flat codebook indices (B, N, D) 58 | """ 59 | 60 | quantized = self.get_codebook()[codes] 61 | return quantized 62 | 63 | def get_codebook_usage(self, index_count: torch.Tensor): 64 | """ 65 | :param index_count: (n, ) where n is the codebook size, express the number of times each index have been used. 
66 | :return: prob of each index to be used: (n, ); perplexity: float; codebook_usage: float 0__1 67 | """ 68 | 69 | # get used idx as probabilities 70 | used_indices = index_count / torch.sum(index_count) 71 | 72 | # perplexity 73 | perplexity = torch.exp(-torch.sum(used_indices * torch.log(used_indices + 1e-10), dim=-1)).sum().item() 74 | 75 | # get the percentage of used codebook 76 | n = index_count.shape[0] 77 | used_codebook = (torch.count_nonzero(used_indices).item() * 100) / n 78 | 79 | return used_indices, perplexity, used_codebook 80 | 81 | @torch.no_grad() 82 | def reinit_unused_codes(self, codebook_usage: torch.Tensor): 83 | """ 84 | Re-initialize unused vectors according to the likelihood of used ones. 85 | :param codebook_usage: (n, ) where n is the codebook size, distribution probability of codebook usage. 86 | """ 87 | 88 | device = codebook_usage.device 89 | n = codebook_usage.shape[0] 90 | 91 | # compute unused codes 92 | unused_codes = torch.nonzero(torch.eq(codebook_usage, torch.zeros(n, device=device))).squeeze(1) 93 | n_unused = unused_codes.shape[0] 94 | 95 | # sample according to most used codes. 96 | torch.use_deterministic_algorithms(False) 97 | replacements = torch.multinomial(codebook_usage, n_unused, replacement=True) 98 | torch.use_deterministic_algorithms(True) 99 | 100 | # update unused codes 101 | new_codes = self.codebook.weight[replacements] 102 | self.codebook.weight[unused_codes] = new_codes 103 | -------------------------------------------------------------------------------- /vqvae/modules/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from torch import nn 4 | from torch.nn import functional 5 | 6 | 7 | class GroupNorm(nn.Module): 8 | def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-6): 9 | """ 10 | We use a custom implementation for GroupNorm, since torch.nn.GroupNorm occasionally throws NaN values 11 | during Decoding (specifically, the problem seems to be caused by the very last GroupNorm). 12 | The cause of these NaNs is unknown, but the custom implementation so far has never produced them. 13 | """ 14 | super().__init__() 15 | 16 | if num_channels % num_groups != 0: 17 | raise ValueError('num_channels must be divisible by num_groups') 18 | 19 | self.num_groups = num_groups 20 | self.eps = eps 21 | 22 | self.weight = nn.Parameter(torch.ones(1, num_channels, 1, 1)) 23 | self.bias = nn.Parameter(torch.zeros(1, num_channels, 1, 1)) 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | 27 | b, c, h, w = x.shape 28 | 29 | x = rearrange(x, 'b (g n) h w -> b g (n h w)', g=self.num_groups) 30 | mean = torch.mean(x, dim=2, keepdim=True) 31 | variance = torch.var(x, dim=2, keepdim=True) 32 | 33 | x = (x - mean) / (variance + self.eps).sqrt() 34 | 35 | x = rearrange(x, 'b g (n h w) -> b (g n) h w', h=h, w=w) 36 | 37 | x = x * self.weight + self.bias 38 | 39 | return x 40 | 41 | 42 | class ResBlock(nn.Module): 43 | 44 | def __init__(self, in_channels: int, out_channels: int = None): 45 | """ 46 | :param in_channels: input channels of the residual block 47 | :param out_channels: if None, use in_channels. Else, adds a 1x1 conv layer. 
48 | """ 49 | super().__init__() 50 | 51 | if out_channels is None or out_channels == in_channels: 52 | out_channels = in_channels 53 | self.conv_shortcut = None 54 | else: 55 | self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), padding='same', bias=False) 56 | 57 | self.norm1 = GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6) 58 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding='same', bias=False) 59 | 60 | self.norm2 = GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6) 61 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding='same', bias=False) 62 | 63 | def forward(self, x): 64 | 65 | residual = functional.silu(self.norm1(x)) 66 | residual = self.conv1(residual) 67 | 68 | residual = functional.silu(self.norm2(residual)) 69 | residual = self.conv2(residual) 70 | 71 | if self.conv_shortcut is not None: 72 | # contiguous prevents warning: 73 | # https://github.com/pytorch/pytorch/issues/47163 74 | # https://discuss.pytorch.org/t/why-does-pytorch-prompt-w-accumulate-grad-h-170-warning-grad-and-param-do-not-obey-the-gradient-layout-contract-this-is-not-an-error-but-may-impair-performance/107760 75 | x = self.conv_shortcut(x.contiguous()) 76 | 77 | return x + residual 78 | 79 | 80 | class Downsample(nn.Module): 81 | 82 | def __init__(self, kernel_size: int = 2, stride: int = 2, padding: int = 0): 83 | super().__init__() 84 | 85 | self.kernel_size = kernel_size 86 | self.stride = stride 87 | self.padding = padding 88 | 89 | def forward(self, x): 90 | res = torch.nn.functional.avg_pool2d(x, self.kernel_size, self.stride, self.padding) 91 | return res 92 | 93 | 94 | class Upsample(nn.Module): 95 | 96 | def __init__(self, channels: int, scale_factor: float = 2.0, mode: str = 'nearest-exact'): 97 | super().__init__() 98 | 99 | self.scale_factor = scale_factor 100 | self.mode = mode 101 | 102 | self.conv = torch.nn.Conv2d(channels, channels, kernel_size=(3, 3), stride=(1, 1), padding='same') 103 | 104 | def forward(self, x): 105 | x = torch.nn.functional.interpolate(x, scale_factor=self.scale_factor, mode=self.mode) 106 | return self.conv(x) 107 | 108 | 109 | class Encoder(nn.Module): 110 | def __init__(self, channels: int, num_res_blocks: int, channel_multipliers: tuple, embedding_dim: int): 111 | 112 | super().__init__() 113 | 114 | self.conv_in = torch.nn.Conv2d(3, channels, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False) 115 | 116 | blocks = [] 117 | ch_in = channels 118 | 119 | for i in range(len(channel_multipliers)): 120 | 121 | ch_out = channels * channel_multipliers[i] 122 | for _ in range(num_res_blocks): 123 | blocks.append(ResBlock(ch_in, ch_out)) 124 | ch_in = ch_out 125 | 126 | blocks.append(Downsample()) 127 | 128 | self.blocks = nn.Sequential(*blocks) 129 | 130 | self.final_residual = nn.Sequential(*[ResBlock(ch_in) for _ in range(num_res_blocks)]) 131 | 132 | self.norm = GroupNorm(num_groups=32, num_channels=ch_in, eps=1e-6) 133 | self.conv_out = torch.nn.Conv2d(ch_in, embedding_dim, kernel_size=(1, 1), padding='same') 134 | 135 | def forward(self, x): 136 | 137 | x = self.conv_in(x) 138 | x = self.blocks(x) 139 | x = self.final_residual(x) 140 | x = self.norm(x) 141 | x = functional.silu(x) 142 | x = self.conv_out(x) 143 | return x 144 | 145 | 146 | class Decoder(nn.Module): 147 | def __init__(self, channels: int, num_res_blocks: int, channel_multipliers: tuple, embedding_dim: int): 148 | 149 | super().__init__() 150 | 151 | ch_in = channels * 
channel_multipliers[-1] 152 | 153 | self.conv_in = torch.nn.Conv2d(embedding_dim, ch_in, kernel_size=(3, 3), stride=(1, 1), padding=1) 154 | self.initial_residual = nn.Sequential(*[ResBlock(ch_in) for _ in range(num_res_blocks)]) 155 | 156 | blocks = [] 157 | for i in reversed(range(len(channel_multipliers))): 158 | 159 | ch_out = channels * channel_multipliers[i - 1] if i > 0 else channels 160 | 161 | for _ in range(num_res_blocks): 162 | blocks.append(ResBlock(ch_in, ch_out)) 163 | ch_in = ch_out 164 | 165 | blocks.append(Upsample(ch_out)) 166 | 167 | self.blocks = nn.Sequential(*blocks) 168 | 169 | self.norm = GroupNorm(num_groups=32, num_channels=channels, eps=1e-6) 170 | self.conv_out = torch.nn.Conv2d(channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=1) 171 | 172 | def forward(self, x): 173 | x = self.conv_in(x) 174 | x = self.initial_residual(x) 175 | x = self.blocks(x) 176 | x = self.norm(x) 177 | x = functional.silu(x) 178 | x = self.conv_out(x) 179 | x = functional.tanh(x) 180 | return x 181 | -------------------------------------------------------------------------------- /vqvae/modules/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SerezD/vqvae-vqgan-pytorch-lightning/277d909d27e55c36eb410f23684bfdf68f8f38dc/vqvae/modules/loss/__init__.py -------------------------------------------------------------------------------- /vqvae/modules/loss/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import grad 3 | import torch.nn as nn 4 | import torch.nn.functional as functional 5 | 6 | from vqvae.modules.loss.stylegan2_discriminator.discriminator import Discriminator 7 | from vqvae.modules.loss.lpips_pytorch import LPIPS 8 | from vqvae.modules.loss.stylegan2_discriminator.utils.ops import conv2d_gradfix 9 | 10 | 11 | def generator_loss(logits: torch.Tensor, loss_type: str = "hinge"): 12 | """ 13 | :param logits: discriminator output in the generator phase (fake_logits) 14 | :param loss_type: which loss to apply between 'hinge' and 'non-saturating' 15 | """ 16 | if loss_type == 'hinge': 17 | loss = -torch.mean(logits) 18 | elif loss_type == 'non-saturating': 19 | # Torch docs for bce with logits: 20 | # This loss combines a Sigmoid layer and the BCELoss in one single class. 21 | # This version is more numerically stable than using a plain Sigmoid followed by a BCELoss as, 22 | # by combining the operations into one layer, we take advantage of the log-sum-exp trick for numerical stability 23 | loss = functional.binary_cross_entropy_with_logits(logits, target=torch.ones_like(logits)) 24 | else: 25 | raise ValueError(f'unknown loss_type: {loss_type}') 26 | return loss 27 | 28 | 29 | def discriminator_loss(logits_real: torch.Tensor, logits_fake: torch.Tensor, loss_type: str = 'hinge'): 30 | """ 31 | :param logits_real: discriminator output when input is the original image 32 | :param logits_fake: discriminator output when input is the reconstructed image 33 | :param loss_type: which loss to apply between 'hinge' and 'non-saturating' 34 | """ 35 | 36 | if loss_type == 'hinge': 37 | real_loss = functional.relu(1.0 - logits_real) 38 | fake_loss = functional.relu(1.0 + logits_fake) 39 | elif loss_type == 'non-saturating': 40 | # Torch docs for bce with logits: 41 | # This loss combines a Sigmoid layer and the BCELoss in one single class. 
42 | # This version is more numerically stable than using a plain Sigmoid followed by a BCELoss as, 43 | # by combining the operations into one layer, we take advantage of the log-sum-exp trick for numerical stability 44 | real_loss = functional.binary_cross_entropy_with_logits(logits_real, 45 | target=torch.ones_like(logits_real), reduction='none') 46 | fake_loss = functional.binary_cross_entropy_with_logits(logits_fake, 47 | target=torch.zeros_like(logits_fake), reduction='none') 48 | else: 49 | raise ValueError(f'unknown loss_type: {loss_type}') 50 | 51 | return torch.mean(real_loss + fake_loss) 52 | 53 | 54 | class VQLPIPSWithDiscriminator(nn.Module): 55 | 56 | def __init__(self, image_size: int, l1_weight: float, l2_weight: float, perc_weight: float, adversarial_conf: dict): 57 | 58 | super().__init__() 59 | 60 | self.l1_loss = lambda rec, tar: (tar - rec).abs().mean() 61 | self.l1_weight = l1_weight 62 | 63 | self.l2_loss = lambda rec, tar: (tar - rec).pow(2).mean() 64 | self.l2_weight = l2_weight 65 | 66 | self.perceptual_loss = LPIPS(net_type='vgg') 67 | self.perceptual_weight = perc_weight 68 | 69 | self.discriminator = Discriminator(image_size) 70 | 71 | self.adversarial_start_epoch = adversarial_conf['start_epoch'] 72 | self.adversarial_loss_type = adversarial_conf['loss_type'] 73 | 74 | self.generator_weight = adversarial_conf['g_weight'] 75 | self.use_adaptive_g_weight = adversarial_conf['use_adaptive'] 76 | 77 | self.r1_regularization_cost = adversarial_conf['r1_reg_weight'] 78 | self.r1_regularization_every = adversarial_conf['r1_reg_every'] 79 | 80 | def calculate_adaptive_weight(self, nll_loss: float, g_loss: float, last_layer: torch.nn.Parameter): 81 | """ 82 | From Taming Transformers for High-Resolution Image Synthesis paper, Patrick Esser, Robin Rombach, Bjorn Ommer: 83 | 84 | "we compute the adaptive weight λ according to λ = ∇GL[Lrec] / (∇GL[LGAN] + δ) 85 | where Lrec is the perceptual reconstruction loss, ∇GL[·] denotes the gradient of its input w.r.t. the last 86 | layer L of the decoder, and δ = 10−6 is used for numerical stability" 87 | 88 | """ 89 | nll_grads = grad(nll_loss, last_layer, grad_outputs=torch.ones_like(nll_loss), retain_graph=True)[0].detach() 90 | g_grads = grad(g_loss, last_layer, grad_outputs=torch.ones_like(g_loss), retain_graph=True)[0].detach() 91 | 92 | adaptive_weight = torch.norm(nll_grads, p=2) / (torch.norm(g_grads, p=2) + 1e-8) 93 | adaptive_weight = torch.clamp(adaptive_weight, 0.0, 1e4).detach() 94 | adaptive_weight = adaptive_weight * self.generator_weight 95 | 96 | return adaptive_weight 97 | 98 | def calculate_r1_regularization_term(self, logits_real: torch.Tensor, images: torch.Tensor, compute_r1: bool): 99 | """ 100 | r1 term calculation: https://github.com/rosinality/stylegan2-pytorch/blob/master/train.py 101 | """ 102 | if compute_r1: 103 | 104 | # gradient 105 | with conv2d_gradfix.no_weight_gradients(): 106 | gradients = torch.autograd.grad(outputs=logits_real.sum(), inputs=images, create_graph=True)[0] 107 | 108 | r1_term = self.r1_regularization_cost * gradients.pow(2).view(gradients.shape[0], -1).sum(1).mean() 109 | else: 110 | r1_term = 0. 
111 | 112 | return r1_term 113 | 114 | def forward_autoencoder(self, quantizer_loss: float, images: torch.Tensor, reconstructions: torch.Tensor, 115 | current_epoch: int, last_layer: torch.nn.Parameter): 116 | 117 | # reconstruction losses 118 | l1_loss = self.l1_loss(reconstructions.contiguous(), images.contiguous()) 119 | l2_loss = self.l2_loss(reconstructions.contiguous(), images.contiguous()) 120 | p_loss = self.perceptual_loss(images.contiguous(), reconstructions.contiguous()) 121 | 122 | nll_loss = l1_loss * self.l1_weight + l2_loss * self.l2_weight + p_loss * self.perceptual_weight 123 | 124 | # adversarial loss 125 | if current_epoch >= self.adversarial_start_epoch: 126 | 127 | logits_fake = self.discriminator(reconstructions.contiguous()) 128 | g_loss = generator_loss(logits_fake, loss_type=self.adversarial_loss_type) 129 | 130 | if self.training and self.use_adaptive_g_weight: 131 | g_weight = self.calculate_adaptive_weight(p_loss, g_loss, last_layer=last_layer) 132 | 133 | else: 134 | g_weight = self.generator_weight 135 | 136 | loss = nll_loss + g_loss * g_weight + quantizer_loss 137 | else: 138 | g_loss = torch.zeros_like(nll_loss, requires_grad=False) 139 | g_weight = 0. # disc not started yet 140 | loss = nll_loss + quantizer_loss 141 | 142 | return loss, l1_loss, l2_loss, p_loss, g_loss, g_weight 143 | 144 | def forward_discriminator(self, images: torch.Tensor, reconstructions: torch.Tensor, 145 | current_epoch: int, current_step: int): 146 | 147 | if current_epoch >= self.adversarial_start_epoch: 148 | compute_r1 = (self.training and current_step % self.r1_regularization_every == 0 and 149 | self.r1_regularization_cost is not None) 150 | 151 | images = images.contiguous().requires_grad_(compute_r1) 152 | logits_real = self.discriminator(images) 153 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 154 | d_loss = discriminator_loss(logits_real, logits_fake, loss_type=self.adversarial_loss_type) 155 | r1_term = self.calculate_r1_regularization_term(logits_real, images, compute_r1) 156 | loss = d_loss + r1_term 157 | 158 | else: 159 | device = images.device 160 | d_loss = torch.zeros((1,), device=device) 161 | r1_term = 0. 162 | loss = None 163 | 164 | return loss, d_loss, r1_term 165 | 166 | 167 | class VQLPIPS(nn.Module): 168 | 169 | def __init__(self, l1_weight: float, l2_weight: float, perc_weight: float): 170 | """ 171 | VQGAN Loss without discriminator. Used just for ablation. 
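        Unlike VQLPIPSWithDiscriminator above, which is driven in two phases
        (forward_autoencoder for the autoencoder/generator step and forward_discriminator for
        the discriminator step), this module exposes a single forward that combines the
        quantizer loss with the weighted L1, L2 and LPIPS reconstruction terms.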
172 | """ 173 | 174 | super().__init__() 175 | 176 | self.l1_loss = lambda rec, tar: (tar - rec).abs().mean() 177 | self.l1_weight = l1_weight 178 | 179 | self.l2_loss = lambda rec, tar: (tar - rec).pow(2).mean() 180 | self.l2_weight = l2_weight 181 | 182 | self.perceptual_loss = LPIPS(net_type='alex') 183 | self.perceptual_weight = perc_weight 184 | 185 | def forward(self, quantizer_loss: float, images: torch.Tensor, reconstructions: torch.Tensor): 186 | """ 187 | :returns quant + nll loss, l1 loss, l2 loss, perceptual loss 188 | """ 189 | 190 | # reconstruction losses 191 | l1_loss = self.l1_loss(reconstructions.contiguous(), images.contiguous()) 192 | l2_loss = self.l2_loss(reconstructions.contiguous(), images.contiguous()) 193 | p_loss = self.perceptual_loss(images.contiguous(), reconstructions.contiguous()) 194 | 195 | nll_loss = l1_loss * self.l1_weight + l2_loss * self.l2_weight + p_loss * self.perceptual_weight 196 | 197 | loss = quantizer_loss + nll_loss 198 | 199 | return loss, l1_loss, l2_loss, p_loss 200 | -------------------------------------------------------------------------------- /vqvae/modules/loss/lpips_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | # cloned from https://github.com/S-aiueo32/lpips-pytorch 6 | 7 | 8 | def lpips(x: torch.Tensor, 9 | y: torch.Tensor, 10 | net_type: str = 'alex', 11 | version: str = '0.1'): 12 | r"""Function that measures 13 | Learned Perceptual Image Patch Similarity (LPIPS). 14 | 15 | Arguments: 16 | x, y (torch.Tensor): the input tensors to compare. 17 | net_type (str): the network type to compare the features: 18 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 19 | version (str): the version of LPIPS. Default: 0.1. 20 | """ 21 | device = x.device 22 | criterion = LPIPS(net_type, version).to(device) 23 | return criterion(x, y) 24 | -------------------------------------------------------------------------------- /vqvae/modules/loss/lpips_pytorch/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SerezD/vqvae-vqgan-pytorch-lightning/277d909d27e55c36eb410f23684bfdf68f8f38dc/vqvae/modules/loss/lpips_pytorch/modules/__init__.py -------------------------------------------------------------------------------- /vqvae/modules/loss/lpips_pytorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | """ 10 | Creates a criterion that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | net_type (str): the network type to compare the features: 15 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 16 | version (str): the version of LPIPS. Default: 0.1. 
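    Example (a minimal sketch; assumes x and y are image batches of shape (N, 3, H, W),
    typically scaled to the [-1, 1] range expected by the original LPIPS weights):

        criterion = LPIPS(net_type='vgg')
        dist = criterion(x, y)  # scalar tensor, averaged over the batch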
17 | """ 18 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 19 | 20 | assert version in ['0.1'], 'v0.1 is only supported now' 21 | 22 | super(LPIPS, self).__init__() 23 | 24 | # pretrained network 25 | self.net = get_network(net_type) 26 | 27 | # linear layers 28 | self.lin = LinLayers(self.net.n_channels_list) 29 | self.lin.load_state_dict(get_state_dict(net_type, version)) 30 | 31 | def forward(self, x: torch.Tensor, y: torch.Tensor): 32 | 33 | feat_x, feat_y = self.net(x), self.net(y) 34 | 35 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 36 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 37 | 38 | return torch.mean(torch.sum(torch.cat(res, 1), 1)) 39 | -------------------------------------------------------------------------------- /vqvae/modules/loss/lpips_pytorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | from torchvision.models import VGG16_Weights 9 | 10 | from .utils import normalize_activation 11 | 12 | 13 | def get_network(net_type: str): 14 | if net_type == 'alex': 15 | return AlexNet() 16 | elif net_type == 'squeeze': 17 | return SqueezeNet() 18 | elif net_type == 'vgg': 19 | return VGG16() 20 | else: 21 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 22 | 23 | 24 | class LinLayers(nn.ModuleList): 25 | def __init__(self, n_channels_list: Sequence[int]): 26 | super(LinLayers, self).__init__([ 27 | nn.Sequential( 28 | nn.Identity(), 29 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 30 | ) for nc in n_channels_list 31 | ]) 32 | 33 | for param in self.parameters(): 34 | param.requires_grad = False 35 | 36 | 37 | class BaseNet(nn.Module): 38 | def __init__(self): 39 | super(BaseNet, self).__init__() 40 | 41 | # register buffer 42 | self.register_buffer( 43 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 44 | self.register_buffer( 45 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 46 | 47 | def set_requires_grad(self, state: bool): 48 | for param in chain(self.parameters(), self.buffers()): 49 | param.requires_grad = state 50 | 51 | def z_score(self, x: torch.Tensor): 52 | return (x - self.mean) / self.std 53 | 54 | def forward(self, x: torch.Tensor): 55 | x = self.z_score(x) 56 | 57 | output = [] 58 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 59 | x = layer(x) 60 | if i in self.target_layers: 61 | output.append(normalize_activation(x)) 62 | if len(output) == len(self.target_layers): 63 | break 64 | return output 65 | 66 | 67 | class SqueezeNet(BaseNet): 68 | def __init__(self): 69 | super(SqueezeNet, self).__init__() 70 | 71 | self.layers = models.squeezenet1_1(True).features 72 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 73 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 74 | 75 | self.set_requires_grad(False) 76 | 77 | 78 | class AlexNet(BaseNet): 79 | def __init__(self): 80 | super(AlexNet, self).__init__() 81 | 82 | self.layers = models.alexnet(True).features 83 | self.target_layers = [2, 5, 8, 10, 12] 84 | self.n_channels_list = [64, 192, 384, 256, 256] 85 | 86 | self.set_requires_grad(False) 87 | 88 | 89 | class VGG16(BaseNet): 90 | def __init__(self): 91 | super(VGG16, self).__init__() 92 | 93 | self.layers = models.vgg16(weights=VGG16_Weights.DEFAULT).features 94 | self.target_layers = [4, 9, 16, 23, 30] 95 | 
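        # With the 1-based enumeration used in BaseNet.forward, these indices tap the outputs of
        # relu1_2, relu2_2, relu3_3, relu4_3 and relu5_3, i.e. the standard LPIPS feature layers.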
self.n_channels_list = [64, 128, 256, 512, 512] 96 | 97 | self.set_requires_grad(False) 98 | -------------------------------------------------------------------------------- /vqvae/modules/loss/lpips_pytorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /vqvae/modules/loss/stylegan2_discriminator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SerezD/vqvae-vqgan-pytorch-lightning/277d909d27e55c36eb410f23684bfdf68f8f38dc/vqvae/modules/loss/stylegan2_discriminator/__init__.py -------------------------------------------------------------------------------- /vqvae/modules/loss/stylegan2_discriminator/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SerezD/vqvae-vqgan-pytorch-lightning/277d909d27e55c36eb410f23684bfdf68f8f38dc/vqvae/modules/loss/stylegan2_discriminator/utils/__init__.py -------------------------------------------------------------------------------- /vqvae/modules/loss/stylegan2_discriminator/utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import glob 11 | import torch 12 | import torch.utils.cpp_extension 13 | import importlib 14 | import hashlib 15 | import shutil 16 | from pathlib import Path 17 | 18 | from torch.utils.file_baton import FileBaton 19 | 20 | # ---------------------------------------------------------------------------- 21 | # Global options. 22 | 23 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 24 | 25 | 26 | # ---------------------------------------------------------------------------- 27 | # Internal helper funcs. 
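# Note: _find_compiler_bindir() below is only consulted on Windows (os.name == 'nt') when
# cl.exe is not already on PATH; see the corresponding check at the top of get_plugin().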
28 | 29 | def _find_compiler_bindir(): 30 | patterns = [ 31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 33 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 34 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', 35 | ] 36 | for pattern in patterns: 37 | matches = sorted(glob.glob(pattern)) 38 | if len(matches): 39 | return matches[-1] 40 | return None 41 | 42 | 43 | # ---------------------------------------------------------------------------- 44 | # Main entry point for compiling and loading C++/CUDA plugins. 45 | 46 | _cached_plugins = dict() 47 | 48 | 49 | def get_plugin(module_name, sources, **build_kwargs): 50 | assert verbosity in ['none', 'brief', 'full'] 51 | 52 | # Already cached? 53 | if module_name in _cached_plugins: 54 | return _cached_plugins[module_name] 55 | 56 | # Print status. 57 | if verbosity == 'full': 58 | print(f'Setting up PyTorch plugin "{module_name}"...') 59 | elif verbosity == 'brief': 60 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) 61 | 62 | try: # pylint: disable=too-many-nested-blocks 63 | # Make sure we can find the necessary compiler binaries. 64 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 65 | compiler_bindir = _find_compiler_bindir() 66 | if compiler_bindir is None: 67 | raise RuntimeError( 68 | f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') 69 | os.environ['PATH'] += ';' + compiler_bindir 70 | 71 | # Compile and load. 72 | verbose_build = (verbosity == 'full') 73 | 74 | # Incremental build md5sum trickery. Copies all the input source files 75 | # into a cached build directory under a combined md5 digest of the input 76 | # source files. Copying is done only if the combined digest has changed. 77 | # This keeps input file timestamps and filenames the same as in previous 78 | # extension builds, allowing for fast incremental rebuilds. 79 | # 80 | # This optimization is done only in case all the source files reside in 81 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 82 | # environment variable is set (we take this as a signal that the user 83 | # actually cares about this.) 84 | source_dirs_set = set(os.path.dirname(source) for source in sources) 85 | if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): 86 | all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) 87 | 88 | # Compute a combined hash digest for all source files in the same 89 | # custom op directory (usually .cu, .cpp, .py and .h files). 
90 | hash_md5 = hashlib.md5() 91 | for src in all_source_files: 92 | with open(src, 'rb') as f: 93 | hash_md5.update(f.read()) 94 | build_dir = torch.utils.cpp_extension._get_build_directory(module_name, 95 | verbose=verbose_build) # pylint: disable=protected-access 96 | digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) 97 | 98 | if not os.path.isdir(digest_build_dir): 99 | os.makedirs(digest_build_dir, exist_ok=True) 100 | baton = FileBaton(os.path.join(digest_build_dir, 'lock')) 101 | if baton.try_acquire(): 102 | try: 103 | for src in all_source_files: 104 | shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) 105 | finally: 106 | baton.release() 107 | else: 108 | # Someone else is copying source files under the digest dir, 109 | # wait until done and continue. 110 | baton.wait() 111 | digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] 112 | torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, 113 | verbose=verbose_build, sources=digest_sources, **build_kwargs) 114 | else: 115 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 116 | module = importlib.import_module(module_name) 117 | 118 | except: 119 | if verbosity == 'brief': 120 | print('Failed!') 121 | raise 122 | 123 | # Print status and add to cache. 124 | if verbosity == 'full': 125 | print(f'Done setting up PyTorch plugin "{module_name}".') 126 | elif verbosity == 'brief': 127 | print('Done.') 128 | _cached_plugins[module_name] = module 129 | return module 130 | 131 | # ---------------------------------------------------------------------------- 132 | -------------------------------------------------------------------------------- /vqvae/modules/loss/stylegan2_discriminator/utils/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /vqvae/modules/loss/stylegan2_discriminator/utils/dnnlib/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Miscellaneous utility classes and functions.""" 10 | 11 | import ctypes 12 | import fnmatch 13 | import importlib 14 | import inspect 15 | import numpy as np 16 | import os 17 | import shutil 18 | import sys 19 | import types 20 | import io 21 | import pickle 22 | import re 23 | import requests 24 | import html 25 | import hashlib 26 | import glob 27 | import tempfile 28 | import urllib 29 | import urllib.request 30 | import uuid 31 | 32 | from distutils.util import strtobool 33 | from typing import Any, List, Tuple, Union 34 | 35 | 36 | # Util classes 37 | # ------------------------------------------------------------------------------------------ 38 | 39 | 40 | class EasyDict(dict): 41 | """Convenience class that behaves like a dict but allows access with the attribute syntax.""" 42 | 43 | def __getattr__(self, name: str) -> Any: 44 | try: 45 | return self[name] 46 | except KeyError: 47 | raise AttributeError(name) 48 | 49 | def __setattr__(self, name: str, value: Any) -> None: 50 | self[name] = value 51 | 52 | def __delattr__(self, name: str) -> None: 53 | del self[name] 54 | 55 | 56 | class Logger(object): 57 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" 58 | 59 | def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): 60 | self.file = None 61 | 62 | if file_name is not None: 63 | self.file = open(file_name, file_mode) 64 | 65 | self.should_flush = should_flush 66 | self.stdout = sys.stdout 67 | self.stderr = sys.stderr 68 | 69 | sys.stdout = self 70 | sys.stderr = self 71 | 72 | def __enter__(self) -> "Logger": 73 | return self 74 | 75 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 76 | self.close() 77 | 78 | def write(self, text: Union[str, bytes]) -> None: 79 | """Write text to stdout (and a file) and optionally flush.""" 80 | if isinstance(text, bytes): 81 | text = text.decode() 82 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash 83 | return 84 | 85 | if self.file is not None: 86 | self.file.write(text) 87 | 88 | self.stdout.write(text) 89 | 90 | if self.should_flush: 91 | self.flush() 92 | 93 | def flush(self) -> None: 94 | """Flush written text to both stdout and a file, if open.""" 95 | if self.file is not None: 96 | self.file.flush() 97 | 98 | self.stdout.flush() 99 | 100 | def close(self) -> None: 101 | """Flush, close possible files, and remove stdout/stderr mirroring.""" 102 | self.flush() 103 | 104 | # if using multiple loggers, prevent closing in wrong order 105 | if sys.stdout is self: 106 | sys.stdout = self.stdout 107 | if sys.stderr is self: 108 | sys.stderr = self.stderr 109 | 110 | if self.file is not None: 111 | self.file.close() 112 | self.file = None 113 | 114 | 115 | # Cache directories 116 | # ------------------------------------------------------------------------------------------ 117 | 118 | _dnnlib_cache_dir = None 119 | 120 | def set_cache_dir(path: str) -> None: 121 | global _dnnlib_cache_dir 122 | _dnnlib_cache_dir = path 123 | 124 | def make_cache_dir_path(*paths: str) -> str: 125 | if _dnnlib_cache_dir is not None: 126 | return os.path.join(_dnnlib_cache_dir, *paths) 127 | if 'DNNLIB_CACHE_DIR' in os.environ: 128 | return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) 129 | if 'HOME' in os.environ: 130 | return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) 131 | if 'USERPROFILE' in os.environ: 132 | return 
os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) 133 | return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) 134 | 135 | # Small util functions 136 | # ------------------------------------------------------------------------------------------ 137 | 138 | 139 | def format_time(seconds: Union[int, float]) -> str: 140 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 141 | s = int(np.rint(seconds)) 142 | 143 | if s < 60: 144 | return "{0}s".format(s) 145 | elif s < 60 * 60: 146 | return "{0}m {1:02}s".format(s // 60, s % 60) 147 | elif s < 24 * 60 * 60: 148 | return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) 149 | else: 150 | return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) 151 | 152 | 153 | def ask_yes_no(question: str) -> bool: 154 | """Ask the user the question until the user inputs a valid answer.""" 155 | while True: 156 | try: 157 | print("{0} [y/n]".format(question)) 158 | return strtobool(input().lower()) 159 | except ValueError: 160 | pass 161 | 162 | 163 | def tuple_product(t: Tuple) -> Any: 164 | """Calculate the product of the tuple elements.""" 165 | result = 1 166 | 167 | for v in t: 168 | result *= v 169 | 170 | return result 171 | 172 | 173 | _str_to_ctype = { 174 | "uint8": ctypes.c_ubyte, 175 | "uint16": ctypes.c_uint16, 176 | "uint32": ctypes.c_uint32, 177 | "uint64": ctypes.c_uint64, 178 | "int8": ctypes.c_byte, 179 | "int16": ctypes.c_int16, 180 | "int32": ctypes.c_int32, 181 | "int64": ctypes.c_int64, 182 | "float32": ctypes.c_float, 183 | "float64": ctypes.c_double 184 | } 185 | 186 | 187 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: 188 | """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" 189 | type_str = None 190 | 191 | if isinstance(type_obj, str): 192 | type_str = type_obj 193 | elif hasattr(type_obj, "__name__"): 194 | type_str = type_obj.__name__ 195 | elif hasattr(type_obj, "name"): 196 | type_str = type_obj.name 197 | else: 198 | raise RuntimeError("Cannot infer type name from input") 199 | 200 | assert type_str in _str_to_ctype.keys() 201 | 202 | my_dtype = np.dtype(type_str) 203 | my_ctype = _str_to_ctype[type_str] 204 | 205 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype) 206 | 207 | return my_dtype, my_ctype 208 | 209 | 210 | def is_pickleable(obj: Any) -> bool: 211 | try: 212 | with io.BytesIO() as stream: 213 | pickle.dump(obj, stream) 214 | return True 215 | except: 216 | return False 217 | 218 | 219 | # Functionality to import modules/objects by name, and call functions by name 220 | # ------------------------------------------------------------------------------------------ 221 | 222 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: 223 | """Searches for the underlying module behind the name to some python object. 
224 | Returns the module and the object name (original name with module part removed).""" 225 | 226 | # allow convenience shorthands, substitute them by full names 227 | obj_name = re.sub("^np.", "numpy.", obj_name) 228 | obj_name = re.sub("^tf.", "tensorflow.", obj_name) 229 | 230 | # list alternatives for (module_name, local_obj_name) 231 | parts = obj_name.split(".") 232 | name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] 233 | 234 | # try each alternative in turn 235 | for module_name, local_obj_name in name_pairs: 236 | try: 237 | module = importlib.import_module(module_name) # may raise ImportError 238 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 239 | return module, local_obj_name 240 | except: 241 | pass 242 | 243 | # maybe some of the modules themselves contain errors? 244 | for module_name, _local_obj_name in name_pairs: 245 | try: 246 | importlib.import_module(module_name) # may raise ImportError 247 | except ImportError: 248 | if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): 249 | raise 250 | 251 | # maybe the requested attribute is missing? 252 | for module_name, local_obj_name in name_pairs: 253 | try: 254 | module = importlib.import_module(module_name) # may raise ImportError 255 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 256 | except ImportError: 257 | pass 258 | 259 | # we are out of luck, but we have no idea why 260 | raise ImportError(obj_name) 261 | 262 | 263 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: 264 | """Traverses the object name and returns the last (rightmost) python object.""" 265 | if obj_name == '': 266 | return module 267 | obj = module 268 | for part in obj_name.split("."): 269 | obj = getattr(obj, part) 270 | return obj 271 | 272 | 273 | def get_obj_by_name(name: str) -> Any: 274 | """Finds the python object with the given name.""" 275 | module, obj_name = get_module_from_obj_name(name) 276 | return get_obj_from_module(module, obj_name) 277 | 278 | 279 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: 280 | """Finds the python object with the given name and calls it as a function.""" 281 | assert func_name is not None 282 | func_obj = get_obj_by_name(func_name) 283 | assert callable(func_obj) 284 | return func_obj(*args, **kwargs) 285 | 286 | 287 | def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: 288 | """Finds the python class with the given name and constructs it with the given arguments.""" 289 | return call_func_by_name(*args, func_name=class_name, **kwargs) 290 | 291 | 292 | def get_module_dir_by_obj_name(obj_name: str) -> str: 293 | """Get the directory path of the module containing the given object name.""" 294 | module, _ = get_module_from_obj_name(obj_name) 295 | return os.path.dirname(inspect.getfile(module)) 296 | 297 | 298 | def is_top_level_function(obj: Any) -> bool: 299 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" 300 | return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ 301 | 302 | 303 | def get_top_level_function_name(obj: Any) -> str: 304 | """Return the fully-qualified name of a top-level function.""" 305 | assert is_top_level_function(obj) 306 | module = obj.__module__ 307 | if module == '__main__': 308 | module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] 309 | return module + "." 
+ obj.__name__ 310 | 311 | 312 | # File system helpers 313 | # ------------------------------------------------------------------------------------------ 314 | 315 | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: 316 | """List all files recursively in a given directory while ignoring given file and directory names. 317 | Returns list of tuples containing both absolute and relative paths.""" 318 | assert os.path.isdir(dir_path) 319 | base_name = os.path.basename(os.path.normpath(dir_path)) 320 | 321 | if ignores is None: 322 | ignores = [] 323 | 324 | result = [] 325 | 326 | for root, dirs, files in os.walk(dir_path, topdown=True): 327 | for ignore_ in ignores: 328 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] 329 | 330 | # dirs need to be edited in-place 331 | for d in dirs_to_remove: 332 | dirs.remove(d) 333 | 334 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] 335 | 336 | absolute_paths = [os.path.join(root, f) for f in files] 337 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] 338 | 339 | if add_base_to_relative: 340 | relative_paths = [os.path.join(base_name, p) for p in relative_paths] 341 | 342 | assert len(absolute_paths) == len(relative_paths) 343 | result += zip(absolute_paths, relative_paths) 344 | 345 | return result 346 | 347 | 348 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: 349 | """Takes in a list of tuples of (src, dst) paths and copies files. 350 | Will create all necessary directories.""" 351 | for file in files: 352 | target_dir_name = os.path.dirname(file[1]) 353 | 354 | # will create all intermediate-level directories 355 | if not os.path.exists(target_dir_name): 356 | os.makedirs(target_dir_name) 357 | 358 | shutil.copyfile(file[0], file[1]) 359 | 360 | 361 | # URL helpers 362 | # ------------------------------------------------------------------------------------------ 363 | 364 | def is_url(obj: Any, allow_file_urls: bool = False) -> bool: 365 | """Determine whether the given object is a valid URL string.""" 366 | if not isinstance(obj, str) or not "://" in obj: 367 | return False 368 | if allow_file_urls and obj.startswith('file://'): 369 | return True 370 | try: 371 | res = requests.compat.urlparse(obj) 372 | if not res.scheme or not res.netloc or not "." in res.netloc: 373 | return False 374 | res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) 375 | if not res.scheme or not res.netloc or not "." in res.netloc: 376 | return False 377 | except: 378 | return False 379 | return True 380 | 381 | 382 | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any: 383 | """Download the given URL and return a binary-mode file object to access the data.""" 384 | assert num_attempts >= 1 385 | assert not (return_filename and (not cache)) 386 | 387 | # Doesn't look like an URL scheme so interpret it as a local filename. 388 | if not re.match('^[a-z]+://', url): 389 | return url if return_filename else open(url, "rb") 390 | 391 | # Handle file URLs. This code handles unusual file:// patterns that 392 | # arise on Windows: 393 | # 394 | # file:///c:/foo.txt 395 | # 396 | # which would translate to a local '/c:/foo.txt' filename that's 397 | # invalid. Drop the forward slash for such pathnames. 398 | # 399 | # If you touch this code path, you should test it on both Linux and 400 | # Windows. 
401 | # 402 | # Some internet resources suggest using urllib.request.url2pathname() but 403 | # but that converts forward slashes to backslashes and this causes 404 | # its own set of problems. 405 | if url.startswith('file://'): 406 | filename = urllib.parse.urlparse(url).path 407 | if re.match(r'^/[a-zA-Z]:', filename): 408 | filename = filename[1:] 409 | return filename if return_filename else open(filename, "rb") 410 | 411 | assert is_url(url) 412 | 413 | # Lookup from cache. 414 | if cache_dir is None: 415 | cache_dir = make_cache_dir_path('downloads') 416 | 417 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 418 | if cache: 419 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) 420 | if len(cache_files) == 1: 421 | filename = cache_files[0] 422 | return filename if return_filename else open(filename, "rb") 423 | 424 | # Download. 425 | url_name = None 426 | url_data = None 427 | with requests.Session() as session: 428 | if verbose: 429 | print("Downloading %s ..." % url, end="", flush=True) 430 | for attempts_left in reversed(range(num_attempts)): 431 | try: 432 | with session.get(url) as res: 433 | res.raise_for_status() 434 | if len(res.content) == 0: 435 | raise IOError("No data received") 436 | 437 | if len(res.content) < 8192: 438 | content_str = res.content.decode("utf-8") 439 | if "download_warning" in res.headers.get("Set-Cookie", ""): 440 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] 441 | if len(links) == 1: 442 | url = requests.compat.urljoin(url, links[0]) 443 | raise IOError("Google Drive virus checker nag") 444 | if "Google Drive - Quota exceeded" in content_str: 445 | raise IOError("Google Drive download quota exceeded -- please try again later") 446 | 447 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) 448 | url_name = match[1] if match else url 449 | url_data = res.content 450 | if verbose: 451 | print(" done") 452 | break 453 | except KeyboardInterrupt: 454 | raise 455 | except: 456 | if not attempts_left: 457 | if verbose: 458 | print(" failed") 459 | raise 460 | if verbose: 461 | print(".", end="", flush=True) 462 | 463 | # Save to cache. 464 | if cache: 465 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) 466 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) 467 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) 468 | os.makedirs(cache_dir, exist_ok=True) 469 | with open(temp_file, "wb") as f: 470 | f.write(url_data) 471 | os.replace(temp_file, cache_file) # atomic 472 | if return_filename: 473 | return cache_file 474 | 475 | # Return data as file object. 476 | assert not return_filename 477 | return io.BytesIO(url_data) -------------------------------------------------------------------------------- /vqvae/modules/loss/stylegan2_discriminator/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
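# A typical way to call dnnlib.util.open_url() above (a minimal sketch; the URL is illustrative):
#
#   from vqvae.modules.loss.stylegan2_discriminator.utils.dnnlib import util
#   with util.open_url('https://example.com/weights.pkl', verbose=False) as f:
#       data = f.read()
#
# With cache=True (the default) the payload is stored under make_cache_dir_path('downloads') and
# reused on later calls; pass return_filename=True to get the cached file path instead.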
8 | 9 | import re 10 | import contextlib 11 | import numpy as np 12 | import torch 13 | import warnings 14 | from vqvae.modules.loss.stylegan2_discriminator.utils import dnnlib 15 | 16 | # ---------------------------------------------------------------------------- 17 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the 18 | # same constant is used multiple times. 19 | 20 | _constant_cache = dict() 21 | 22 | 23 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 24 | value = np.asarray(value) 25 | if shape is not None: 26 | shape = tuple(shape) 27 | if dtype is None: 28 | dtype = torch.get_default_dtype() 29 | if device is None: 30 | device = torch.device('cpu') 31 | if memory_format is None: 32 | memory_format = torch.contiguous_format 33 | 34 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) 35 | tensor = _constant_cache.get(key, None) 36 | if tensor is None: 37 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 38 | if shape is not None: 39 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 40 | tensor = tensor.contiguous(memory_format=memory_format) 41 | _constant_cache[key] = tensor 42 | return tensor 43 | 44 | 45 | # ---------------------------------------------------------------------------- 46 | # Replace NaN/Inf with specified numerical values. 47 | 48 | try: 49 | nan_to_num = torch.nan_to_num # 1.8.0a0 50 | except AttributeError: 51 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin 52 | assert isinstance(input, torch.Tensor) 53 | if posinf is None: 54 | posinf = torch.finfo(input.dtype).max 55 | if neginf is None: 56 | neginf = torch.finfo(input.dtype).min 57 | assert nan == 0 58 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 59 | 60 | # ---------------------------------------------------------------------------- 61 | # Symbolic assert. 62 | 63 | try: 64 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access 65 | except AttributeError: 66 | symbolic_assert = torch.Assert # 1.7.0 67 | 68 | 69 | # ---------------------------------------------------------------------------- 70 | # Context manager to suppress known warnings in torch.jit.trace(). 71 | 72 | class suppress_tracer_warnings(warnings.catch_warnings): 73 | def __enter__(self): 74 | super().__enter__() 75 | warnings.simplefilter('ignore', category=torch.jit.TracerWarning) 76 | return self 77 | 78 | 79 | # ---------------------------------------------------------------------------- 80 | # Assert that the shape of a tensor matches the given list of integers. 81 | # None indicates that the size of a dimension is allowed to vary. 82 | # Performs symbolic assertion when used in torch.jit.trace(). 
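# Example: assert_shape(images, [None, 3, 256, 256]) accepts any batch size but raises an
# AssertionError if the channel or spatial dimensions differ; when a size is a traced tensor,
# the check falls back to symbolic_assert() so it also works under torch.jit.trace().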
83 | 84 | def assert_shape(tensor, ref_shape): 85 | if tensor.ndim != len(ref_shape): 86 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') 87 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): 88 | if ref_size is None: 89 | pass 90 | elif isinstance(ref_size, torch.Tensor): 91 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 92 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') 93 | elif isinstance(size, torch.Tensor): 94 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 95 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), 96 | f'Wrong size for dimension {idx}: expected {ref_size}') 97 | elif size != ref_size: 98 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') 99 | 100 | 101 | # ---------------------------------------------------------------------------- 102 | # Function decorator that calls torch.autograd.profiler.record_function(). 103 | 104 | def profiled_function(fn): 105 | def decorator(*args, **kwargs): 106 | with torch.autograd.profiler.record_function(fn.__name__): 107 | return fn(*args, **kwargs) 108 | 109 | decorator.__name__ = fn.__name__ 110 | return decorator 111 | 112 | 113 | # ---------------------------------------------------------------------------- 114 | # Sampler for torch.utils.data.DataLoader that loops over the dataset 115 | # indefinitely, shuffling items as it goes. 116 | 117 | class InfiniteSampler(torch.utils.data.Sampler): 118 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): 119 | assert len(dataset) > 0 120 | assert num_replicas > 0 121 | assert 0 <= rank < num_replicas 122 | assert 0 <= window_size <= 1 123 | super().__init__(dataset) 124 | self.dataset = dataset 125 | self.rank = rank 126 | self.num_replicas = num_replicas 127 | self.shuffle = shuffle 128 | self.seed = seed 129 | self.window_size = window_size 130 | 131 | def __iter__(self): 132 | order = np.arange(len(self.dataset)) 133 | rnd = None 134 | window = 0 135 | if self.shuffle: 136 | rnd = np.random.RandomState(self.seed) 137 | rnd.shuffle(order) 138 | window = int(np.rint(order.size * self.window_size)) 139 | 140 | idx = 0 141 | while True: 142 | i = idx % order.size 143 | if idx % self.num_replicas == self.rank: 144 | yield order[i] 145 | if window >= 2: 146 | j = (i - rnd.randint(window)) % order.size 147 | order[i], order[j] = order[j], order[i] 148 | idx += 1 149 | 150 | 151 | # ---------------------------------------------------------------------------- 152 | # Utilities for operating with torch.nn.Module parameters and buffers. 
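# InfiniteSampler above plugs directly into a DataLoader; a minimal sketch (rank/num_replicas
# would come from the surrounding distributed setup):
#
#   sampler = InfiniteSampler(dataset, rank=rank, num_replicas=num_replicas, seed=0)
#   loader = iter(torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=batch_size))
#   batch = next(loader)  # never raises StopIteration; the sampler yields indices forever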
153 | 154 | def params_and_buffers(module): 155 | assert isinstance(module, torch.nn.Module) 156 | return list(module.parameters()) + list(module.buffers()) 157 | 158 | 159 | def named_params_and_buffers(module): 160 | assert isinstance(module, torch.nn.Module) 161 | return list(module.named_parameters()) + list(module.named_buffers()) 162 | 163 | 164 | def copy_params_and_buffers(src_module, dst_module, require_all=False): 165 | assert isinstance(src_module, torch.nn.Module) 166 | assert isinstance(dst_module, torch.nn.Module) 167 | src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)} 168 | for name, tensor in named_params_and_buffers(dst_module): 169 | assert (name in src_tensors) or (not require_all) 170 | if name in src_tensors: 171 | tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) 172 | 173 | 174 | # ---------------------------------------------------------------------------- 175 | # Context manager for easily enabling/disabling DistributedDataParallel 176 | # synchronization. 177 | 178 | @contextlib.contextmanager 179 | def ddp_sync(module, sync): 180 | assert isinstance(module, torch.nn.Module) 181 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 182 | yield 183 | else: 184 | with module.no_sync(): 185 | yield 186 | 187 | 188 | # ---------------------------------------------------------------------------- 189 | # Check DistributedDataParallel consistency across processes. 190 | 191 | def check_ddp_consistency(module, ignore_regex=None): 192 | assert isinstance(module, torch.nn.Module) 193 | for name, tensor in named_params_and_buffers(module): 194 | fullname = type(module).__name__ + '.' + name 195 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): 196 | continue 197 | tensor = tensor.detach() 198 | other = tensor.clone() 199 | torch.distributed.broadcast(tensor=other, src=0) 200 | assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname 201 | 202 | 203 | # ---------------------------------------------------------------------------- 204 | # Print summary table of module hierarchy. 205 | 206 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): 207 | assert isinstance(module, torch.nn.Module) 208 | assert not isinstance(module, torch.jit.ScriptModule) 209 | assert isinstance(inputs, (tuple, list)) 210 | 211 | # Register hooks. 212 | entries = [] 213 | nesting = [0] 214 | 215 | def pre_hook(_mod, _inputs): 216 | nesting[0] += 1 217 | 218 | def post_hook(mod, _inputs, outputs): 219 | nesting[0] -= 1 220 | if nesting[0] <= max_nesting: 221 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 222 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 223 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) 224 | 225 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 226 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 227 | 228 | # Run module. 229 | outputs = module(*inputs) 230 | for hook in hooks: 231 | hook.remove() 232 | 233 | # Identify unique outputs, parameters, and buffers. 
234 | tensors_seen = set() 235 | for e in entries: 236 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] 237 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 238 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 239 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} 240 | 241 | # Filter out redundant entries. 242 | if skip_redundant: 243 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] 244 | 245 | # Construct table. 246 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 247 | rows += [['---'] * len(rows[0])] 248 | param_total = 0 249 | buffer_total = 0 250 | submodule_names = {mod: name for name, mod in module.named_modules()} 251 | for e in entries: 252 | name = '' if e.mod is module else submodule_names[e.mod] 253 | param_size = sum(t.numel() for t in e.unique_params) 254 | buffer_size = sum(t.numel() for t in e.unique_buffers) 255 | output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs] 256 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 257 | rows += [[ 258 | name + (':0' if len(e.outputs) >= 2 else ''), 259 | str(param_size) if param_size else '-', 260 | str(buffer_size) if buffer_size else '-', 261 | (output_shapes + ['-'])[0], 262 | (output_dtypes + ['-'])[0], 263 | ]] 264 | for idx in range(1, len(e.outputs)): 265 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 266 | param_total += param_size 267 | buffer_total += buffer_size 268 | rows += [['---'] * len(rows[0])] 269 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 270 | 271 | # Print table. 272 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 273 | print() 274 | for row in rows: 275 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) 276 | print() 277 | return outputs 278 | 279 | # ---------------------------------------------------------------------------- 280 | -------------------------------------------------------------------------------- /vqvae/modules/loss/stylegan2_discriminator/utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /vqvae/modules/loss/stylegan2_discriminator/utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | #include 10 | #include 11 | #include 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments. 35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 
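    // Each thread processes up to loopX elements (strided by the block size), so with
    // 128 threads per block the grid needs ceil(sizeX / (loopX * 128)) blocks in total.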
84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /vqvae/modules/loss/stylegan2_discriminator/utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include "bias_act.h" 11 | 12 | //------------------------------------------------------------------------ 13 | // Helpers. 14 | 15 | template struct InternalType; 16 | template <> struct InternalType { typedef double scalar_t; }; 17 | template <> struct InternalType { typedef float scalar_t; }; 18 | template <> struct InternalType { typedef float scalar_t; }; 19 | 20 | //------------------------------------------------------------------------ 21 | // CUDA kernel. 22 | 23 | template 24 | __global__ void bias_act_kernel(bias_act_kernel_params p) 25 | { 26 | typedef typename InternalType::scalar_t scalar_t; 27 | int G = p.grad; 28 | scalar_t alpha = (scalar_t)p.alpha; 29 | scalar_t gain = (scalar_t)p.gain; 30 | scalar_t clamp = (scalar_t)p.clamp; 31 | scalar_t one = (scalar_t)1; 32 | scalar_t two = (scalar_t)2; 33 | scalar_t expRange = (scalar_t)80; 34 | scalar_t halfExpRange = (scalar_t)40; 35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 37 | 38 | // Loop over elements. 39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 41 | { 42 | // Load. 43 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 48 | scalar_t yy = (gain != 0) ? yref / gain : 0; 49 | scalar_t y = 0; 50 | 51 | // Apply bias. 52 | ((G == 0) ? x : xref) += b; 53 | 54 | // linear 55 | if (A == 1) 56 | { 57 | if (G == 0) y = x; 58 | if (G == 1) y = x; 59 | } 60 | 61 | // relu 62 | if (A == 2) 63 | { 64 | if (G == 0) y = (x > 0) ? x : 0; 65 | if (G == 1) y = (yy > 0) ? x : 0; 66 | } 67 | 68 | // lrelu 69 | if (A == 3) 70 | { 71 | if (G == 0) y = (x > 0) ? x : x * alpha; 72 | if (G == 1) y = (yy > 0) ? x : x * alpha; 73 | } 74 | 75 | // tanh 76 | if (A == 4) 77 | { 78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? 
one : (c - d) / (c + d); } 79 | if (G == 1) y = x * (one - yy * yy); 80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 81 | } 82 | 83 | // sigmoid 84 | if (A == 5) 85 | { 86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); 87 | if (G == 1) y = x * yy * (one - yy); 88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 89 | } 90 | 91 | // elu 92 | if (A == 6) 93 | { 94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 97 | } 98 | 99 | // selu 100 | if (A == 7) 101 | { 102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); 104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); 105 | } 106 | 107 | // softplus 108 | if (A == 8) 109 | { 110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); 111 | if (G == 1) y = x * (one - exp(-yy)); 112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 113 | } 114 | 115 | // swish 116 | if (A == 9) 117 | { 118 | if (G == 0) 119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 120 | else 121 | { 122 | scalar_t c = exp(xref); 123 | scalar_t d = c + one; 124 | if (G == 1) 125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 126 | else 127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 129 | } 130 | } 131 | 132 | // Apply gain. 133 | y *= gain * dy; 134 | 135 | // Clamp. 136 | if (clamp >= 0) 137 | { 138 | if (G == 0) 139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; 140 | else 141 | y = (yref > -clamp & yref < clamp) ? y : 0; 142 | } 143 | 144 | // Store. 145 | ((T*)p.y)[xi] = (T)y; 146 | } 147 | } 148 | 149 | //------------------------------------------------------------------------ 150 | // CUDA kernel selection. 151 | 152 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p) 153 | { 154 | if (p.act == 1) return (void*)bias_act_kernel; 155 | if (p.act == 2) return (void*)bias_act_kernel; 156 | if (p.act == 3) return (void*)bias_act_kernel; 157 | if (p.act == 4) return (void*)bias_act_kernel; 158 | if (p.act == 5) return (void*)bias_act_kernel; 159 | if (p.act == 6) return (void*)bias_act_kernel; 160 | if (p.act == 7) return (void*)bias_act_kernel; 161 | if (p.act == 8) return (void*)bias_act_kernel; 162 | if (p.act == 9) return (void*)bias_act_kernel; 163 | return NULL; 164 | } 165 | 166 | //------------------------------------------------------------------------ 167 | // Template specializations. 168 | 169 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 170 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 171 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 172 | 173 | //------------------------------------------------------------------------ 174 | -------------------------------------------------------------------------------- /vqvae/modules/loss/stylegan2_discriminator/utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. 
Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /vqvae/modules/loss/stylegan2_discriminator/utils/ops/bias_act.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom PyTorch ops for efficient bias and activation.""" 10 | 11 | import os 12 | import warnings 13 | import numpy as np 14 | import torch 15 | from vqvae.modules.loss.stylegan2_discriminator.utils import dnnlib 16 | import traceback 17 | 18 | from .. import custom_ops 19 | from ..
import misc 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | activation_funcs = { 24 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), 25 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), 26 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), 27 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), 28 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), 29 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), 30 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), 31 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), 32 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), 33 | } 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | _inited = False 38 | _plugin = None 39 | _null_tensor = torch.empty([0]) 40 | 41 | def _init(): 42 | global _inited, _plugin 43 | if not _inited: 44 | _inited = True 45 | sources = ['bias_act.cpp', 'bias_act.cu'] 46 | sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] 47 | try: 48 | _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) 49 | except: 50 | warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) 51 | return _plugin is not None 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): 56 | r"""Fused bias and activation function. 57 | 58 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`, 59 | and scales the result by `gain`. Each of the steps is optional. In most cases, 60 | the fused op is considerably more efficient than performing the same calculation 61 | using standard PyTorch ops. It supports first and second order gradients, 62 | but not third order gradients. 63 | 64 | Args: 65 | x: Input activation tensor. Can be of any shape. 66 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 67 | as `x`. The shape must be known, and it must match the dimension of `x` 68 | corresponding to `dim`. 69 | dim: The dimension in `x` corresponding to the elements of `b`. 70 | The value of `dim` is ignored if `b` is not specified. 71 | act: Name of the activation function to evaluate, or `"linear"` to disable. 72 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. 73 | See `activation_funcs` for a full list. `None` is not allowed. 74 | alpha: Shape parameter for the activation function, or `None` to use the default. 75 | gain: Scaling factor for the output tensor, or `None` to use default. 
76 | See `activation_funcs` for the default scaling of each activation function. 77 | If unsure, consider specifying 1. 78 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable 79 | the clamping (default). 80 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 81 | 82 | Returns: 83 | Tensor of the same shape and datatype as `x`. 84 | """ 85 | assert isinstance(x, torch.Tensor) 86 | assert impl in ['ref', 'cuda'] 87 | if impl == 'cuda' and x.device.type == 'cuda' and _init(): 88 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) 89 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) 90 | 91 | #---------------------------------------------------------------------------- 92 | 93 | @misc.profiled_function 94 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): 95 | """Slow reference implementation of `bias_act()` using standard TensorFlow ops. 96 | """ 97 | assert isinstance(x, torch.Tensor) 98 | assert clamp is None or clamp >= 0 99 | spec = activation_funcs[act] 100 | alpha = float(alpha if alpha is not None else spec.def_alpha) 101 | gain = float(gain if gain is not None else spec.def_gain) 102 | clamp = float(clamp if clamp is not None else -1) 103 | 104 | # Add bias. 105 | if b is not None: 106 | assert isinstance(b, torch.Tensor) and b.ndim == 1 107 | assert 0 <= dim < x.ndim 108 | assert b.shape[0] == x.shape[dim] 109 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) 110 | 111 | # Evaluate activation function. 112 | alpha = float(alpha) 113 | x = spec.func(x, alpha=alpha) 114 | 115 | # Scale by gain. 116 | gain = float(gain) 117 | if gain != 1: 118 | x = x * gain 119 | 120 | # Clamp. 121 | if clamp >= 0: 122 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type 123 | return x 124 | 125 | #---------------------------------------------------------------------------- 126 | 127 | _bias_act_cuda_cache = dict() 128 | 129 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): 130 | """Fast CUDA implementation of `bias_act()` using custom ops. 131 | """ 132 | # Parse arguments. 133 | assert clamp is None or clamp >= 0 134 | spec = activation_funcs[act] 135 | alpha = float(alpha if alpha is not None else spec.def_alpha) 136 | gain = float(gain if gain is not None else spec.def_gain) 137 | clamp = float(clamp if clamp is not None else -1) 138 | 139 | # Lookup from cache. 140 | key = (dim, act, alpha, gain, clamp) 141 | if key in _bias_act_cuda_cache: 142 | return _bias_act_cuda_cache[key] 143 | 144 | # Forward op. 
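# Illustrative standalone sketch (not from the original NVIDIA sources): on the
# reference path, bias_act() is equivalent to clamp(gain * act(x + b)). The import
# path is assumed from this repository's layout; shapes and values are made up.
# For 'lrelu', the defaults are alpha = 0.2 and gain = sqrt(2).
import torch
from vqvae.modules.loss.stylegan2_discriminator.utils.ops import bias_act

x = torch.randn(2, 8, 4, 4)                          # activations
b = torch.randn(8)                                   # per-channel bias along dim=1
y = bias_act.bias_act(x, b, act='lrelu', impl='ref')
expected = torch.nn.functional.leaky_relu(x + b.reshape(1, -1, 1, 1), 0.2) * (2 ** 0.5)
assert torch.allclose(y, expected)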
145 | class BiasActCuda(torch.autograd.Function): 146 | @staticmethod 147 | def forward(ctx, x, b): # pylint: disable=arguments-differ 148 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format 149 | x = x.contiguous(memory_format=ctx.memory_format) 150 | b = b.contiguous() if b is not None else _null_tensor 151 | y = x 152 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: 153 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) 154 | ctx.save_for_backward( 155 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 156 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 157 | y if 'y' in spec.ref else _null_tensor) 158 | return y 159 | 160 | @staticmethod 161 | def backward(ctx, dy): # pylint: disable=arguments-differ 162 | dy = dy.contiguous(memory_format=ctx.memory_format) 163 | x, b, y = ctx.saved_tensors 164 | dx = None 165 | db = None 166 | 167 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 168 | dx = dy 169 | if act != 'linear' or gain != 1 or clamp >= 0: 170 | dx = BiasActCudaGrad.apply(dy, x, b, y) 171 | 172 | if ctx.needs_input_grad[1]: 173 | db = dx.sum([i for i in range(dx.ndim) if i != dim]) 174 | 175 | return dx, db 176 | 177 | # Backward op. 178 | class BiasActCudaGrad(torch.autograd.Function): 179 | @staticmethod 180 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ 181 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format 182 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) 183 | ctx.save_for_backward( 184 | dy if spec.has_2nd_grad else _null_tensor, 185 | x, b, y) 186 | return dx 187 | 188 | @staticmethod 189 | def backward(ctx, d_dx): # pylint: disable=arguments-differ 190 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format) 191 | dy, x, b, y = ctx.saved_tensors 192 | d_dy = None 193 | d_x = None 194 | d_b = None 195 | d_y = None 196 | 197 | if ctx.needs_input_grad[0]: 198 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) 199 | 200 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): 201 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) 202 | 203 | if spec.has_2nd_grad and ctx.needs_input_grad[2]: 204 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) 205 | 206 | return d_dy, d_x, d_b, d_y 207 | 208 | # Add to cache. 209 | _bias_act_cuda_cache[key] = BiasActCuda 210 | return BiasActCuda 211 | 212 | #---------------------------------------------------------------------------- 213 | -------------------------------------------------------------------------------- /vqvae/modules/loss/stylegan2_discriminator/utils/ops/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Custom replacement for `torch.nn.functional.conv2d` that supports 10 | arbitrarily high order gradients with zero performance penalty.""" 11 | 12 | import warnings 13 | import contextlib 14 | import torch 15 | 16 | # pylint: disable=redefined-builtin 17 | # pylint: disable=arguments-differ 18 | # pylint: disable=protected-access 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | enabled = False # Enable the custom op by setting this to true. 23 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 24 | 25 | @contextlib.contextmanager 26 | def no_weight_gradients(): 27 | global weight_gradients_disabled 28 | old = weight_gradients_disabled 29 | weight_gradients_disabled = True 30 | yield 31 | weight_gradients_disabled = old 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 36 | if _should_use_custom_op(input): 37 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) 38 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 39 | 40 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): 41 | if _should_use_custom_op(input): 42 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) 43 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) 44 | 45 | #---------------------------------------------------------------------------- 46 | 47 | def _should_use_custom_op(input): 48 | assert isinstance(input, torch.Tensor) 49 | if (not enabled) or (not torch.backends.cudnn.enabled): 50 | return False 51 | if input.device.type != 'cuda': 52 | return False 53 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): 54 | return True 55 | warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') 56 | return False 57 | 58 | def _tuple_of_ints(xs, ndim): 59 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 60 | assert len(xs) == ndim 61 | assert all(isinstance(x, int) for x in xs) 62 | return xs 63 | 64 | #---------------------------------------------------------------------------- 65 | 66 | _conv2d_gradfix_cache = dict() 67 | 68 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): 69 | # Parse arguments. 70 | ndim = 2 71 | weight_shape = tuple(weight_shape) 72 | stride = _tuple_of_ints(stride, ndim) 73 | padding = _tuple_of_ints(padding, ndim) 74 | output_padding = _tuple_of_ints(output_padding, ndim) 75 | dilation = _tuple_of_ints(dilation, ndim) 76 | 77 | # Lookup from cache. 78 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 79 | if key in _conv2d_gradfix_cache: 80 | return _conv2d_gradfix_cache[key] 81 | 82 | # Validate arguments. 
83 | assert groups >= 1 84 | assert len(weight_shape) == ndim + 2 85 | assert all(stride[i] >= 1 for i in range(ndim)) 86 | assert all(padding[i] >= 0 for i in range(ndim)) 87 | assert all(dilation[i] >= 0 for i in range(ndim)) 88 | if not transpose: 89 | assert all(output_padding[i] == 0 for i in range(ndim)) 90 | else: # transpose 91 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) 92 | 93 | # Helpers. 94 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) 95 | def calc_output_padding(input_shape, output_shape): 96 | if transpose: 97 | return [0, 0] 98 | return [ 99 | input_shape[i + 2] 100 | - (output_shape[i + 2] - 1) * stride[i] 101 | - (1 - 2 * padding[i]) 102 | - dilation[i] * (weight_shape[i + 2] - 1) 103 | for i in range(ndim) 104 | ] 105 | 106 | # Forward & backward. 107 | class Conv2d(torch.autograd.Function): 108 | @staticmethod 109 | def forward(ctx, input, weight, bias): 110 | assert weight.shape == weight_shape 111 | if not transpose: 112 | output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 113 | else: # transpose 114 | output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) 115 | ctx.save_for_backward(input, weight) 116 | return output 117 | 118 | @staticmethod 119 | def backward(ctx, grad_output): 120 | input, weight = ctx.saved_tensors 121 | grad_input = None 122 | grad_weight = None 123 | grad_bias = None 124 | 125 | if ctx.needs_input_grad[0]: 126 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 127 | grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None) 128 | assert grad_input.shape == input.shape 129 | 130 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 131 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 132 | assert grad_weight.shape == weight_shape 133 | 134 | if ctx.needs_input_grad[2]: 135 | grad_bias = grad_output.sum([0, 2, 3]) 136 | 137 | return grad_input, grad_weight, grad_bias 138 | 139 | # Gradient with respect to the weights. 
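# Illustrative standalone sketch (not from the original NVIDIA sources): the point of
# this module is that higher-order gradients keep working, which can be checked with
# torch.autograd.gradgradcheck on small double-precision inputs. Import path and shapes
# are assumptions; with the custom op disabled this exercises the reference fallback.
import torch
from vqvae.modules.loss.stylegan2_discriminator.utils.ops import conv2d_gradfix

x = torch.randn(1, 2, 5, 5, dtype=torch.double, requires_grad=True)
w = torch.randn(3, 2, 3, 3, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradgradcheck(
    lambda a, b: conv2d_gradfix.conv2d(a, b, padding=1), (x, w))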
140 | class Conv2dGradWeight(torch.autograd.Function): 141 | @staticmethod 142 | def forward(ctx, grad_output, input): 143 | op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') 144 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] 145 | grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) 146 | assert grad_weight.shape == weight_shape 147 | ctx.save_for_backward(grad_output, input) 148 | return grad_weight 149 | 150 | @staticmethod 151 | def backward(ctx, grad2_grad_weight): 152 | grad_output, input = ctx.saved_tensors 153 | grad2_grad_output = None 154 | grad2_input = None 155 | 156 | if ctx.needs_input_grad[0]: 157 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) 158 | assert grad2_grad_output.shape == grad_output.shape 159 | 160 | if ctx.needs_input_grad[1]: 161 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 162 | grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) 163 | assert grad2_input.shape == input.shape 164 | 165 | return grad2_grad_output, grad2_input 166 | 167 | _conv2d_gradfix_cache[key] = Conv2d 168 | return Conv2d 169 | 170 | #---------------------------------------------------------------------------- 171 | -------------------------------------------------------------------------------- /vqvae/modules/loss/stylegan2_discriminator/utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """2D convolution with optional up/downsampling.""" 10 | 11 | import torch 12 | 13 | from .. import misc 14 | from . import conv2d_gradfix 15 | from . import upfirdn2d 16 | from .upfirdn2d import _parse_padding 17 | from .upfirdn2d import _get_filter_size 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def _get_weight_shape(w): 22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 23 | shape = [int(sz) for sz in w.shape] 24 | misc.assert_shape(w, shape) 25 | return shape 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 31 | """ 32 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 33 | 34 | # Flip weight if requested. 35 | if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 36 | w = w.flip([2, 3]) 37 | 38 | # Workaround performance pitfall in cuDNN 8.0.5, triggered when using 39 | # 1x1 kernel + memory_format=channels_last + less than 64 channels. 
40 | if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: 41 | if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: 42 | if out_channels <= 4 and groups == 1: 43 | in_shape = x.shape 44 | x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) 45 | x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) 46 | else: 47 | x = x.to(memory_format=torch.contiguous_format) 48 | w = w.to(memory_format=torch.contiguous_format) 49 | x = conv2d_gradfix.conv2d(x, w, groups=groups) 50 | return x.to(memory_format=torch.channels_last) 51 | 52 | # Otherwise => execute using conv2d_gradfix. 53 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 54 | return op(x, w, stride=stride, padding=padding, groups=groups) 55 | 56 | #---------------------------------------------------------------------------- 57 | 58 | @misc.profiled_function 59 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 60 | r"""2D convolution with optional up/downsampling. 61 | 62 | Padding is performed only once at the beginning, not between the operations. 63 | 64 | Args: 65 | x: Input tensor of shape 66 | `[batch_size, in_channels, in_height, in_width]`. 67 | w: Weight tensor of shape 68 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 69 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 70 | calling upfirdn2d.setup_filter(). None = identity (default). 71 | up: Integer upsampling factor (default: 1). 72 | down: Integer downsampling factor (default: 1). 73 | padding: Padding with respect to the upsampled image. Can be a single number 74 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 75 | (default: 0). 76 | groups: Split input channels into N groups (default: 1). 77 | flip_weight: False = convolution, True = correlation (default: True). 78 | flip_filter: False = convolution, True = correlation (default: False). 79 | 80 | Returns: 81 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 82 | """ 83 | # Validate arguments. 84 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 85 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 86 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 87 | assert isinstance(up, int) and (up >= 1) 88 | assert isinstance(down, int) and (down >= 1) 89 | assert isinstance(groups, int) and (groups >= 1) 90 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 91 | fw, fh = _get_filter_size(f) 92 | px0, px1, py0, py1 = _parse_padding(padding) 93 | 94 | # Adjust padding to account for up/downsampling. 95 | if up > 1: 96 | px0 += (fw + up - 1) // 2 97 | px1 += (fw - up) // 2 98 | py0 += (fh + up - 1) // 2 99 | py1 += (fh - up) // 2 100 | if down > 1: 101 | px0 += (fw - down + 1) // 2 102 | px1 += (fw - down) // 2 103 | py0 += (fh - down + 1) // 2 104 | py1 += (fh - down) // 2 105 | 106 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 107 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 108 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 109 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 110 | return x 111 | 112 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 
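# Illustrative standalone sketch (not from the original NVIDIA sources): a 3x3
# convolution fused with 2x downsampling, following the docstring above. Import paths
# and shapes are assumptions for the example.
import torch
from vqvae.modules.loss.stylegan2_discriminator.utils.ops import conv2d_resample, upfirdn2d

x = torch.randn(1, 4, 16, 16)
w = torch.randn(8, 4, 3, 3)
f = upfirdn2d.setup_filter([1, 3, 3, 1])             # 4-tap binomial low-pass filter
y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, down=2, padding=1)
assert y.shape == (1, 8, 8, 8)                       # low-pass filtered, then strided 3x3 conv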
113 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 114 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 115 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 116 | return x 117 | 118 | # Fast path: downsampling only => use strided convolution. 119 | if down > 1 and up == 1: 120 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 121 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 122 | return x 123 | 124 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 125 | if up > 1: 126 | if groups == 1: 127 | w = w.transpose(0, 1) 128 | else: 129 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 130 | w = w.transpose(1, 2) 131 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 132 | px0 -= kw - 1 133 | px1 -= kw - up 134 | py0 -= kh - 1 135 | py1 -= kh - up 136 | pxt = max(min(-px0, -px1), 0) 137 | pyt = max(min(-py0, -py1), 0) 138 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) 139 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) 140 | if down > 1: 141 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 142 | return x 143 | 144 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 145 | if up == 1 and down == 1: 146 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 147 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) 148 | 149 | # Fallback: Generic reference implementation. 150 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 151 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 152 | if down > 1: 153 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 154 | return x 155 | 156 | #---------------------------------------------------------------------------- 157 | -------------------------------------------------------------------------------- /vqvae/modules/loss/stylegan2_discriminator/utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /vqvae/modules/loss/stylegan2_discriminator/utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import warnings 15 | import torch 16 | 17 | # pylint: disable=redefined-builtin 18 | # pylint: disable=arguments-differ 19 | # pylint: disable=protected-access 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | enabled = False # Enable the custom op by setting this to true. 
24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | def grid_sample(input, grid): 28 | if _should_use_custom_op(): 29 | return _GridSample2dForward.apply(input, grid) 30 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def _should_use_custom_op(): 35 | if not enabled: 36 | return False 37 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): 38 | return True 39 | warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().') 40 | return False 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | class _GridSample2dForward(torch.autograd.Function): 45 | @staticmethod 46 | def forward(ctx, input, grid): 47 | assert input.ndim == 4 48 | assert grid.ndim == 4 49 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 50 | ctx.save_for_backward(input, grid) 51 | return output 52 | 53 | @staticmethod 54 | def backward(ctx, grad_output): 55 | input, grid = ctx.saved_tensors 56 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 57 | return grad_input, grad_grid 58 | 59 | #---------------------------------------------------------------------------- 60 | 61 | class _GridSample2dBackward(torch.autograd.Function): 62 | @staticmethod 63 | def forward(ctx, grad_output, input, grid): 64 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 65 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 66 | ctx.save_for_backward(grid) 67 | return grad_input, grad_grid 68 | 69 | @staticmethod 70 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 71 | _ = grad2_grad_grid # unused 72 | grid, = ctx.saved_tensors 73 | grad2_grad_output = None 74 | grad2_input = None 75 | grad2_grid = None 76 | 77 | if ctx.needs_input_grad[0]: 78 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 79 | 80 | assert not ctx.needs_input_grad[2] 81 | return grad2_grad_output, grad2_input, grad2_grid 82 | 83 | #---------------------------------------------------------------------------- 84 | -------------------------------------------------------------------------------- /vqvae/modules/loss/stylegan2_discriminator/utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments. 
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 25 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 26 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 27 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 28 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 29 | 30 | // Create output tensor. 31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 32 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 33 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 34 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 35 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 36 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 37 | 38 | // Initialize CUDA kernel parameters. 39 | upfirdn2d_kernel_params p; 40 | p.x = x.data_ptr(); 41 | p.f = f.data_ptr(); 42 | p.y = y.data_ptr(); 43 | p.up = make_int2(upx, upy); 44 | p.down = make_int2(downx, downy); 45 | p.pad0 = make_int2(padx0, pady0); 46 | p.flip = (flip) ? 1 : 0; 47 | p.gain = gain; 48 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 49 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 50 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 51 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 52 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 53 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 54 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 55 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 56 | 57 | // Choose CUDA kernel. 58 | upfirdn2d_kernel_spec spec; 59 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 60 | { 61 | spec = choose_upfirdn2d_kernel(p); 62 | }); 63 | 64 | // Set looping options. 65 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 66 | p.loopMinor = spec.loopMinor; 67 | p.loopX = spec.loopX; 68 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 69 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 70 | 71 | // Compute grid size. 72 | dim3 blockSize, gridSize; 73 | if (spec.tileOutW < 0) // large 74 | { 75 | blockSize = dim3(4, 32, 1); 76 | gridSize = dim3( 77 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 78 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 79 | p.launchMajor); 80 | } 81 | else // small 82 | { 83 | blockSize = dim3(256, 1, 1); 84 | gridSize = dim3( 85 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 86 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 87 | p.launchMajor); 88 | } 89 | 90 | // Launch CUDA kernel. 
91 | void* args[] = {&p}; 92 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 93 | return y; 94 | } 95 | 96 | //------------------------------------------------------------------------ 97 | 98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 99 | { 100 | m.def("upfirdn2d", &upfirdn2d); 101 | } 102 | 103 | //------------------------------------------------------------------------ 104 | -------------------------------------------------------------------------------- /vqvae/modules/loss/stylegan2_discriminator/utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /vqvae/modules/loss/stylegan2_discriminator/utils/ops/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom PyTorch ops for efficient resampling of 2D images.""" 10 | 11 | import os 12 | import warnings 13 | import numpy as np 14 | import torch 15 | import traceback 16 | 17 | from .. import custom_ops 18 | from .. import misc 19 | from . 
import conv2d_gradfix 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | _inited = False 24 | _plugin = None 25 | 26 | def _init(): 27 | global _inited, _plugin 28 | if not _inited: 29 | sources = ['upfirdn2d.cpp', 'upfirdn2d.cu'] 30 | sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] 31 | try: 32 | _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) 33 | except: 34 | warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) 35 | return _plugin is not None 36 | 37 | def _parse_scaling(scaling): 38 | if isinstance(scaling, int): 39 | scaling = [scaling, scaling] 40 | assert isinstance(scaling, (list, tuple)) 41 | assert all(isinstance(x, int) for x in scaling) 42 | sx, sy = scaling 43 | assert sx >= 1 and sy >= 1 44 | return sx, sy 45 | 46 | def _parse_padding(padding): 47 | if isinstance(padding, int): 48 | padding = [padding, padding] 49 | assert isinstance(padding, (list, tuple)) 50 | assert all(isinstance(x, int) for x in padding) 51 | if len(padding) == 2: 52 | padx, pady = padding 53 | padding = [padx, padx, pady, pady] 54 | padx0, padx1, pady0, pady1 = padding 55 | return padx0, padx1, pady0, pady1 56 | 57 | def _get_filter_size(f): 58 | if f is None: 59 | return 1, 1 60 | assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] 61 | fw = f.shape[-1] 62 | fh = f.shape[0] 63 | with misc.suppress_tracer_warnings(): 64 | fw = int(fw) 65 | fh = int(fh) 66 | misc.assert_shape(f, [fh, fw][:f.ndim]) 67 | assert fw >= 1 and fh >= 1 68 | return fw, fh 69 | 70 | #---------------------------------------------------------------------------- 71 | 72 | def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): 73 | r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. 74 | 75 | Args: 76 | f: Torch tensor, numpy array, or python list of the shape 77 | `[filter_height, filter_width]` (non-separable), 78 | `[filter_taps]` (separable), 79 | `[]` (impulse), or 80 | `None` (identity). 81 | device: Result device (default: cpu). 82 | normalize: Normalize the filter so that it retains the magnitude 83 | for constant input signal (DC)? (default: True). 84 | flip_filter: Flip the filter? (default: False). 85 | gain: Overall scaling factor for signal magnitude (default: 1). 86 | separable: Return a separable filter? (default: select automatically). 87 | 88 | Returns: 89 | Float32 tensor of the shape 90 | `[filter_height, filter_width]` (non-separable) or 91 | `[filter_taps]` (separable). 92 | """ 93 | # Validate. 94 | if f is None: 95 | f = 1 96 | f = torch.as_tensor(f, dtype=torch.float32) 97 | assert f.ndim in [0, 1, 2] 98 | assert f.numel() > 0 99 | if f.ndim == 0: 100 | f = f[np.newaxis] 101 | 102 | # Separable? 103 | if separable is None: 104 | separable = (f.ndim == 1 and f.numel() >= 8) 105 | if f.ndim == 1 and not separable: 106 | f = f.ger(f) 107 | assert f.ndim == (1 if separable else 2) 108 | 109 | # Apply normalize, flip, gain, and device. 
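# Illustrative standalone sketch (not from the original NVIDIA sources): with the common
# 4-tap [1, 3, 3, 1] design, setup_filter() returns a dense 4x4 outer product (fewer than
# 8 taps, so it is not treated as separable) normalized to unit DC gain. Import path
# assumed from this repository's layout.
import torch
from vqvae.modules.loss.stylegan2_discriminator.utils.ops import upfirdn2d

f = upfirdn2d.setup_filter([1, 3, 3, 1])
assert f.shape == (4, 4)                             # outer product of the 4 taps
assert torch.isclose(f.sum(), torch.tensor(1.0))     # retains magnitude for DC input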
110 | if normalize: 111 | f /= f.sum() 112 | if flip_filter: 113 | f = f.flip(list(range(f.ndim))) 114 | f = f * (gain ** (f.ndim / 2)) 115 | f = f.to(device=device) 116 | return f 117 | 118 | #---------------------------------------------------------------------------- 119 | 120 | def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): 121 | r"""Pad, upsample, filter, and downsample a batch of 2D images. 122 | 123 | Performs the following sequence of operations for each channel: 124 | 125 | 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). 126 | 127 | 2. Pad the image with the specified number of zeros on each side (`padding`). 128 | Negative padding corresponds to cropping the image. 129 | 130 | 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it 131 | so that the footprint of all output pixels lies within the input image. 132 | 133 | 4. Downsample the image by keeping every Nth pixel (`down`). 134 | 135 | This sequence of operations bears close resemblance to scipy.signal.upfirdn(). 136 | The fused op is considerably more efficient than performing the same calculation 137 | using standard PyTorch ops. It supports gradients of arbitrary order. 138 | 139 | Args: 140 | x: Float32/float64/float16 input tensor of the shape 141 | `[batch_size, num_channels, in_height, in_width]`. 142 | f: Float32 FIR filter of the shape 143 | `[filter_height, filter_width]` (non-separable), 144 | `[filter_taps]` (separable), or 145 | `None` (identity). 146 | up: Integer upsampling factor. Can be a single int or a list/tuple 147 | `[x, y]` (default: 1). 148 | down: Integer downsampling factor. Can be a single int or a list/tuple 149 | `[x, y]` (default: 1). 150 | padding: Padding with respect to the upsampled image. Can be a single number 151 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 152 | (default: 0). 153 | flip_filter: False = convolution, True = correlation (default: False). 154 | gain: Overall scaling factor for signal magnitude (default: 1). 155 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). 156 | 157 | Returns: 158 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 159 | """ 160 | assert isinstance(x, torch.Tensor) 161 | assert impl in ['ref', 'cuda'] 162 | if impl == 'cuda' and x.device.type == 'cuda' and _init(): 163 | return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) 164 | return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) 165 | 166 | #---------------------------------------------------------------------------- 167 | 168 | @misc.profiled_function 169 | def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): 170 | """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. 171 | """ 172 | # Validate arguments. 173 | assert isinstance(x, torch.Tensor) and x.ndim == 4 174 | if f is None: 175 | f = torch.ones([1, 1], dtype=torch.float32, device=x.device) 176 | assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] 177 | assert f.dtype == torch.float32 and not f.requires_grad 178 | batch_size, num_channels, in_height, in_width = x.shape 179 | upx, upy = _parse_scaling(up) 180 | downx, downy = _parse_scaling(down) 181 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 182 | 183 | # Upsample by inserting zeros. 
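# Illustrative standalone sketch (not from the original NVIDIA sources): 2x upsampling
# with a 4-tap filter, padded as in upsample2d() so the output is exactly 2x the input,
# with gain = up**2 to preserve signal magnitude. Values are assumptions for the example.
import torch
from vqvae.modules.loss.stylegan2_discriminator.utils.ops import upfirdn2d

x = torch.randn(1, 3, 8, 8)
f = upfirdn2d.setup_filter([1, 3, 3, 1])
pad = [(4 + 2 - 1) // 2, (4 - 2) // 2] * 2           # [px0, px1, py0, py1] = [2, 1, 2, 1]
y = upfirdn2d.upfirdn2d(x, f, up=2, padding=pad, gain=4, impl='ref')
assert y.shape == (1, 3, 16, 16)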
184 | x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) 185 | x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) 186 | x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) 187 | 188 | # Pad or crop. 189 | x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) 190 | x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] 191 | 192 | # Setup filter. 193 | f = f * (gain ** (f.ndim / 2)) 194 | f = f.to(x.dtype) 195 | if not flip_filter: 196 | f = f.flip(list(range(f.ndim))) 197 | 198 | # Convolve with the filter. 199 | f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) 200 | if f.ndim == 4: 201 | x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) 202 | else: 203 | x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) 204 | x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) 205 | 206 | # Downsample by throwing away pixels. 207 | x = x[:, :, ::downy, ::downx] 208 | return x 209 | 210 | #---------------------------------------------------------------------------- 211 | 212 | _upfirdn2d_cuda_cache = dict() 213 | 214 | def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): 215 | """Fast CUDA implementation of `upfirdn2d()` using custom ops. 216 | """ 217 | # Parse arguments. 218 | upx, upy = _parse_scaling(up) 219 | downx, downy = _parse_scaling(down) 220 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 221 | 222 | # Lookup from cache. 223 | key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) 224 | if key in _upfirdn2d_cuda_cache: 225 | return _upfirdn2d_cuda_cache[key] 226 | 227 | # Forward op. 228 | class Upfirdn2dCuda(torch.autograd.Function): 229 | @staticmethod 230 | def forward(ctx, x, f): # pylint: disable=arguments-differ 231 | assert isinstance(x, torch.Tensor) and x.ndim == 4 232 | if f is None: 233 | f = torch.ones([1, 1], dtype=torch.float32, device=x.device) 234 | assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] 235 | y = x 236 | if f.ndim == 2: 237 | y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) 238 | else: 239 | y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain)) 240 | y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain)) 241 | ctx.save_for_backward(f) 242 | ctx.x_shape = x.shape 243 | return y 244 | 245 | @staticmethod 246 | def backward(ctx, dy): # pylint: disable=arguments-differ 247 | f, = ctx.saved_tensors 248 | _, _, ih, iw = ctx.x_shape 249 | _, _, oh, ow = dy.shape 250 | fw, fh = _get_filter_size(f) 251 | p = [ 252 | fw - padx0 - 1, 253 | iw * upx - ow * downx + padx0 - upx + 1, 254 | fh - pady0 - 1, 255 | ih * upy - oh * downy + pady0 - upy + 1, 256 | ] 257 | dx = None 258 | df = None 259 | 260 | if ctx.needs_input_grad[0]: 261 | dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) 262 | 263 | assert not ctx.needs_input_grad[1] 264 | return dx, df 265 | 266 | # Add to cache. 267 | _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda 268 | return Upfirdn2dCuda 269 | 270 | #---------------------------------------------------------------------------- 271 | 272 | def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): 273 | r"""Filter a batch of 2D images using the given 2D FIR filter. 
274 | 275 | By default, the result is padded so that its shape matches the input. 276 | User-specified padding is applied on top of that, with negative values 277 | indicating cropping. Pixels outside the image are assumed to be zero. 278 | 279 | Args: 280 | x: Float32/float64/float16 input tensor of the shape 281 | `[batch_size, num_channels, in_height, in_width]`. 282 | f: Float32 FIR filter of the shape 283 | `[filter_height, filter_width]` (non-separable), 284 | `[filter_taps]` (separable), or 285 | `None` (identity). 286 | padding: Padding with respect to the output. Can be a single number or a 287 | list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 288 | (default: 0). 289 | flip_filter: False = convolution, True = correlation (default: False). 290 | gain: Overall scaling factor for signal magnitude (default: 1). 291 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). 292 | 293 | Returns: 294 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 295 | """ 296 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 297 | fw, fh = _get_filter_size(f) 298 | p = [ 299 | padx0 + fw // 2, 300 | padx1 + (fw - 1) // 2, 301 | pady0 + fh // 2, 302 | pady1 + (fh - 1) // 2, 303 | ] 304 | return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) 305 | 306 | #---------------------------------------------------------------------------- 307 | 308 | def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): 309 | r"""Upsample a batch of 2D images using the given 2D FIR filter. 310 | 311 | By default, the result is padded so that its shape is a multiple of the input. 312 | User-specified padding is applied on top of that, with negative values 313 | indicating cropping. Pixels outside the image are assumed to be zero. 314 | 315 | Args: 316 | x: Float32/float64/float16 input tensor of the shape 317 | `[batch_size, num_channels, in_height, in_width]`. 318 | f: Float32 FIR filter of the shape 319 | `[filter_height, filter_width]` (non-separable), 320 | `[filter_taps]` (separable), or 321 | `None` (identity). 322 | up: Integer upsampling factor. Can be a single int or a list/tuple 323 | `[x, y]` (default: 1). 324 | padding: Padding with respect to the output. Can be a single number or a 325 | list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 326 | (default: 0). 327 | flip_filter: False = convolution, True = correlation (default: False). 328 | gain: Overall scaling factor for signal magnitude (default: 1). 329 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). 330 | 331 | Returns: 332 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 333 | """ 334 | upx, upy = _parse_scaling(up) 335 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 336 | fw, fh = _get_filter_size(f) 337 | p = [ 338 | padx0 + (fw + upx - 1) // 2, 339 | padx1 + (fw - upx) // 2, 340 | pady0 + (fh + upy - 1) // 2, 341 | pady1 + (fh - upy) // 2, 342 | ] 343 | return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl) 344 | 345 | #---------------------------------------------------------------------------- 346 | 347 | def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): 348 | r"""Downsample a batch of 2D images using the given 2D FIR filter. 349 | 350 | By default, the result is padded so that its shape is a fraction of the input. 
351 | User-specified padding is applied on top of that, with negative values 352 | indicating cropping. Pixels outside the image are assumed to be zero. 353 | 354 | Args: 355 | x: Float32/float64/float16 input tensor of the shape 356 | `[batch_size, num_channels, in_height, in_width]`. 357 | f: Float32 FIR filter of the shape 358 | `[filter_height, filter_width]` (non-separable), 359 | `[filter_taps]` (separable), or 360 | `None` (identity). 361 | down: Integer downsampling factor. Can be a single int or a list/tuple 362 | `[x, y]` (default: 1). 363 | padding: Padding with respect to the input. Can be a single number or a 364 | list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 365 | (default: 0). 366 | flip_filter: False = convolution, True = correlation (default: False). 367 | gain: Overall scaling factor for signal magnitude (default: 1). 368 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). 369 | 370 | Returns: 371 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 372 | """ 373 | downx, downy = _parse_scaling(down) 374 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 375 | fw, fh = _get_filter_size(f) 376 | p = [ 377 | padx0 + (fw - downx + 1) // 2, 378 | padx1 + (fw - downx) // 2, 379 | pady0 + (fh - downy + 1) // 2, 380 | pady1 + (fh - downy) // 2, 381 | ] 382 | return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) 383 | 384 | #---------------------------------------------------------------------------- 385 | -------------------------------------------------------------------------------- /vqvae/modules/loss/stylegan2_discriminator/utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for pickling Python code alongside other data. 10 | 11 | The pickled code is automatically imported into a separate Python module 12 | during unpickling. This way, any previously exported pickles will remain 13 | usable even if the original code is no longer available, or if the current 14 | version of the code is not consistent with what was originally pickled.""" 15 | 16 | import sys 17 | import pickle 18 | import io 19 | import inspect 20 | import copy 21 | import uuid 22 | import types 23 | from vqvae.modules.loss.stylegan2_discriminator.utils import dnnlib 24 | 25 | # ---------------------------------------------------------------------------- 26 | 27 | _version = 6 # internal version number 28 | _decorators = set() # {decorator_class, ...} 29 | _import_hooks = [] # [hook_function, ...] 30 | _module_to_src_dict = dict() # {module: src, ...} 31 | _src_to_module_dict = dict() # {src: module, ...} 32 | 33 | 34 | # ---------------------------------------------------------------------------- 35 | 36 | def persistent_class(orig_class): 37 | r"""Class decorator that extends a given class to save its source code 38 | when pickled. 
39 | 40 | Example: 41 | 42 | from torch_utils import persistence 43 | 44 | @persistence.persistent_class 45 | class MyNetwork(torch.nn.Module): 46 | def __init__(self, num_inputs, num_outputs): 47 | super().__init__() 48 | self.fc = MyLayer(num_inputs, num_outputs) 49 | ... 50 | 51 | @persistence.persistent_class 52 | class MyLayer(torch.nn.Module): 53 | ... 54 | 55 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 56 | source code alongside other internal state (e.g., parameters, buffers, 57 | and submodules). This way, any previously exported pickle will remain 58 | usable even if the class definitions have been modified or are no 59 | longer available. 60 | 61 | The decorator saves the source code of the entire Python module 62 | containing the decorated class. It does *not* save the source code of 63 | any imported modules. Thus, the imported modules must be available 64 | during unpickling, also including `torch_utils.persistence` itself. 65 | 66 | It is ok to call functions defined in the same module from the 67 | decorated class. However, if the decorated class depends on other 68 | classes defined in the same module, they must be decorated as well. 69 | This is illustrated in the above example in the case of `MyLayer`. 70 | 71 | It is also possible to employ the decorator just-in-time before 72 | calling the constructor. For example: 73 | 74 | cls = MyLayer 75 | if want_to_make_it_persistent: 76 | cls = persistence.persistent_class(cls) 77 | layer = cls(num_inputs, num_outputs) 78 | 79 | As an additional feature, the decorator also keeps track of the 80 | arguments that were used to construct each instance of the decorated 81 | class. The arguments can be queried via `obj.init_args` and 82 | `obj.init_kwargs`, and they are automatically pickled alongside other 83 | object state. 
A typical use case is to first unpickle a previous 84 | instance of a persistent class, and then upgrade it to use the latest 85 | version of the source code: 86 | 87 | with open('old_pickle.pkl', 'rb') as f: 88 | old_net = pickle.load(f) 89 | new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs) 90 | misc.copy_params_and_buffers(old_net, new_net, require_all=True) 91 | """ 92 | assert isinstance(orig_class, type) 93 | if is_persistent(orig_class): 94 | return orig_class 95 | 96 | assert orig_class.__module__ in sys.modules 97 | orig_module = sys.modules[orig_class.__module__] 98 | orig_module_src = _module_to_src(orig_module) 99 | 100 | class Decorator(orig_class): 101 | _orig_module_src = orig_module_src 102 | _orig_class_name = orig_class.__name__ 103 | 104 | def __init__(self, *args, **kwargs): 105 | super().__init__(*args, **kwargs) 106 | self._init_args = copy.deepcopy(args) 107 | self._init_kwargs = copy.deepcopy(kwargs) 108 | assert orig_class.__name__ in orig_module.__dict__ 109 | _check_pickleable(self.__reduce__()) 110 | 111 | @property 112 | def init_args(self): 113 | return copy.deepcopy(self._init_args) 114 | 115 | @property 116 | def init_kwargs(self): 117 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) 118 | 119 | def __reduce__(self): 120 | fields = list(super().__reduce__()) 121 | fields += [None] * max(3 - len(fields), 0) 122 | if fields[0] is not _reconstruct_persistent_obj: 123 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, 124 | class_name=self._orig_class_name, state=fields[2]) 125 | fields[0] = _reconstruct_persistent_obj # reconstruct func 126 | fields[1] = (meta,) # reconstruct args 127 | fields[2] = None # state dict 128 | return tuple(fields) 129 | 130 | Decorator.__name__ = orig_class.__name__ 131 | _decorators.add(Decorator) 132 | return Decorator 133 | 134 | 135 | # ---------------------------------------------------------------------------- 136 | 137 | def is_persistent(obj): 138 | r"""Test whether the given object or class is persistent, i.e., 139 | whether it will save its source code when pickled. 140 | """ 141 | try: 142 | if obj in _decorators: 143 | return True 144 | except TypeError: 145 | pass 146 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck 147 | 148 | 149 | # ---------------------------------------------------------------------------- 150 | 151 | def import_hook(hook): 152 | r"""Register an import hook that is called whenever a persistent object 153 | is being unpickled. A typical use case is to patch the pickled source 154 | code to avoid errors and inconsistencies when the API of some imported 155 | module has changed. 156 | 157 | The hook should have the following signature: 158 | 159 | hook(meta) -> modified meta 160 | 161 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 162 | 163 | type: Type of the persistent object, e.g. `'class'`. 164 | version: Internal version number of `torch_utils.persistence`. 165 | module_src Original source code of the Python module. 166 | class_name: Class name in the original Python module. 167 | state: Internal state of the object. 168 | 169 | Example: 170 | 171 | @persistence.import_hook 172 | def wreck_my_network(meta): 173 | if meta.class_name == 'MyNetwork': 174 | print('MyNetwork is being imported. 
I will wreck it!') 175 | meta.module_src = meta.module_src.replace("True", "False") 176 | return meta 177 | """ 178 | assert callable(hook) 179 | _import_hooks.append(hook) 180 | 181 | 182 | # ---------------------------------------------------------------------------- 183 | 184 | def _reconstruct_persistent_obj(meta): 185 | r"""Hook that is called internally by the `pickle` module to unpickle 186 | a persistent object. 187 | """ 188 | meta = dnnlib.EasyDict(meta) 189 | meta.state = dnnlib.EasyDict(meta.state) 190 | for hook in _import_hooks: 191 | meta = hook(meta) 192 | assert meta is not None 193 | 194 | assert meta.version == _version 195 | module = _src_to_module(meta.module_src) 196 | 197 | assert meta.type == 'class' 198 | orig_class = module.__dict__[meta.class_name] 199 | decorator_class = persistent_class(orig_class) 200 | obj = decorator_class.__new__(decorator_class) 201 | 202 | setstate = getattr(obj, '__setstate__', None) 203 | if callable(setstate): 204 | setstate(meta.state) # pylint: disable=not-callable 205 | else: 206 | obj.__dict__.update(meta.state) 207 | return obj 208 | 209 | 210 | # ---------------------------------------------------------------------------- 211 | 212 | def _module_to_src(module): 213 | r"""Query the source code of a given Python module. 214 | """ 215 | src = _module_to_src_dict.get(module, None) 216 | if src is None: 217 | src = inspect.getsource(module) 218 | _module_to_src_dict[module] = src 219 | _src_to_module_dict[src] = module 220 | return src 221 | 222 | 223 | def _src_to_module(src): 224 | r"""Get or create a Python module for the given source code. 225 | """ 226 | module = _src_to_module_dict.get(src, None) 227 | if module is None: 228 | module_name = "_imported_module_" + uuid.uuid4().hex 229 | module = types.ModuleType(module_name) 230 | sys.modules[module_name] = module 231 | _module_to_src_dict[module] = src 232 | _src_to_module_dict[src] = module 233 | exec(src, module.__dict__) # pylint: disable=exec-used 234 | return module 235 | 236 | 237 | # ---------------------------------------------------------------------------- 238 | 239 | def _check_pickleable(obj): 240 | r"""Check that the given object is pickleable, raising an exception if 241 | it is not. This function is expected to be considerably more efficient 242 | than actually pickling the object. 243 | """ 244 | 245 | def recurse(obj): 246 | if isinstance(obj, (list, tuple, set)): 247 | return [recurse(x) for x in obj] 248 | if isinstance(obj, dict): 249 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 250 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 251 | return None # Python primitive types are pickleable. 252 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']: 253 | return None # NumPy arrays and PyTorch tensors are pickleable. 254 | if is_persistent(obj): 255 | return None # Persistent objects are pickleable, by virtue of the constructor check. 
256 | return obj 257 | 258 | with io.BytesIO() as f: 259 | pickle.dump(recurse(obj), f) 260 | 261 | # ---------------------------------------------------------------------------- 262 | -------------------------------------------------------------------------------- /vqvae/modules/vector_quantizers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from einops import rearrange, einsum 5 | from vqvae.modules.abstract_modules.base_quantizer import BaseVectorQuantizer 6 | 7 | 8 | class VectorQuantizer(BaseVectorQuantizer): 9 | 10 | def __init__(self, num_embeddings: int, embedding_dim: int, commitment_cost: float = 0.25): 11 | 12 | """ 13 | Original VectorQuantizer with straight through gradient estimator (loss is optimized on inputs and codebook) 14 | :param num_embeddings: size of the latent dictionary (num of embedding vectors). 15 | :param embedding_dim: size of a single tensor in dict. 16 | :param commitment_cost: scaling factor for e_loss 17 | """ 18 | 19 | super().__init__(num_embeddings, embedding_dim) 20 | 21 | self.commitment_cost = commitment_cost 22 | 23 | def forward(self, x: torch.Tensor) -> (torch.Tensor, torch.IntTensor, float): 24 | """ 25 | :param x: tensors (output of the Encoder - B,D,H,W). 26 | :return quantized_x (B, D, H, W), detached codes (B, H*W), latent_loss 27 | """ 28 | 29 | b, c, h, w = x.shape 30 | device = x.device 31 | 32 | # Flat input to vectors of embedding dim = C. 33 | flat_x = rearrange(x, 'b c h w -> (b h w) c') 34 | 35 | # Calculate distances of each vector w.r.t the dict 36 | # distances is a matrix (B*H*W, codebook_size) 37 | distances = (torch.sum(flat_x ** 2, dim=1, keepdim=True) 38 | + torch.sum(self.codebook.weight ** 2, dim=1) 39 | - 2 * torch.matmul(flat_x, self.codebook.weight.t())) 40 | 41 | # Get indices of the closest vector in dict, and create a mask on the correct indexes 42 | # encoding_indices = (num_vectors_in_batch, 1) 43 | # Mask = (num_vectors_in_batch, codebook_dim) 44 | encoding_indices = torch.argmin(distances, dim=1) 45 | encodings = torch.zeros(encoding_indices.shape[0], self.num_embeddings, device=device) 46 | encodings.scatter_(1, encoding_indices.unsqueeze(1), 1) 47 | 48 | # Quantize and un-flat 49 | quantized = torch.matmul(encodings, self.codebook.weight) 50 | 51 | # Loss functions 52 | e_loss = self.commitment_cost * F.mse_loss(quantized.detach(), flat_x) 53 | q_loss = F.mse_loss(quantized, flat_x.detach()) 54 | 55 | # during backpropagation quantized = inputs (copy gradient trick) 56 | quantized = flat_x + (quantized - flat_x).detach() 57 | 58 | quantized = rearrange(quantized, '(b h w) c -> b c h w', b=b, h=h, w=w) 59 | encoding_indices = rearrange(encoding_indices, '(b h w)-> b (h w)', b=b, h=h, w=w).detach() 60 | 61 | return quantized, encoding_indices, q_loss + e_loss 62 | 63 | @torch.no_grad() 64 | def vec_to_codes(self, x: torch.Tensor) -> torch.IntTensor: 65 | """ 66 | :param x: tensors (output of the Encoder - B,D,H,W). 67 | :return flat codebook indices (B, H * W) 68 | """ 69 | b, c, h, w = x.shape 70 | 71 | # Flat input to vectors of embedding dim = C. 
72 |         flat_x = rearrange(x, 'b c h w -> (b h w) c')
73 | 
74 |         # Calculate distances of each vector w.r.t the dict
75 |         # distances is a matrix (B*H*W, codebook_size)
76 |         distances = (torch.sum(flat_x ** 2, dim=1, keepdim=True)
77 |                      + torch.sum(self.codebook.weight ** 2, dim=1)
78 |                      - 2 * torch.matmul(flat_x, self.codebook.weight.t()))
79 | 
80 |         # Get indices of the closest vector in dict
81 |         encoding_indices = torch.argmin(distances, dim=1)
82 |         encoding_indices = rearrange(encoding_indices, '(b h w) -> b (h w)', b=b, h=h, w=w)
83 | 
84 |         return encoding_indices
85 | 
86 | 
87 | class EMAVectorQuantizer(BaseVectorQuantizer):
88 | 
89 |     def __init__(self, num_embeddings: int, embedding_dim: int, commitment_cost: float = 0.25, decay: float = 0.95,
90 |                  epsilon: float = 1e-5):
91 | 
92 |         """
93 |         EMA ALGORITHM
94 |         Each codebook entry is updated according to the encoder outputs that selected it.
95 |         The important thing is that the codebook updating is not a loss term anymore.
96 |         Specifically, for every codebook item wi, the mean code mi and usage count Ni are tracked:
97 |             Ni ← Ni · γ + ni · (1 − γ),
98 |             mi ← mi · γ + (Σj e(xj)) · (1 − γ),   (sum over the ni encoder outputs assigned to wi)
99 |             wi ← mi / Ni
100 |         where γ is a discount factor
101 | 
102 |         :param num_embeddings: size of the latent dictionary (num of embedding vectors).
103 |         :param embedding_dim: size of a single tensor in dictionary
104 |         :param commitment_cost: scaling factor for e_loss
105 |         :param decay: decay for EMA updating
106 |         :param epsilon: smoothing parameter for EMA weights
107 |         """
108 | 
109 |         super().__init__(num_embeddings, embedding_dim)
110 | 
111 |         self.commitment_cost = commitment_cost
112 | 
113 |         # EMA does not require grad
114 |         self.codebook.requires_grad_(False)
115 | 
116 |         # ema parameters
117 |         # ema usage count: total count of each embedding through epochs
118 |         self.register_buffer('ema_count', torch.zeros(num_embeddings))
119 | 
120 |         # same size as dict, initialized as codebook
121 |         # the updated means
122 |         self.register_buffer('ema_weight', torch.empty((self.num_embeddings, self.embedding_dim)))
123 |         self.ema_weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
124 | 
125 |         self.decay = decay
126 |         self.epsilon = epsilon
127 | 
128 |     def forward(self, x: torch.Tensor) -> (torch.Tensor, torch.IntTensor, float):
129 |         """
130 |         :param x: tensors (output of the Encoder - B,D,H,W).
131 |         :return quantized_x (B, D, H, W), detached codes (B, H*W), latent_loss
132 |         """
133 | 
134 |         b, c, h, w = x.shape
135 |         device = x.device
136 | 
137 |         # Flat input to vectors of embedding dim = C.
138 | flat_x = rearrange(x, 'b c h w -> (b h w) c') 139 | 140 | # Calculate distances of each vector w.r.t the dict 141 | # distances is a matrix (B*H*W, codebook_size) 142 | distances = (torch.sum(flat_x ** 2, dim=1, keepdim=True) 143 | + torch.sum(self.codebook.weight ** 2, dim=1) 144 | - 2 * torch.matmul(flat_x, self.codebook.weight.t())) 145 | 146 | # Get indices of the closest vector in dict, and create a mask on the correct indexes 147 | # encoding_indices = (num_vectors_in_batch, 1) 148 | # Mask = (num_vectors_in_batch, codebook_dim) 149 | encoding_indices = torch.argmin(distances, dim=1) 150 | encodings = torch.zeros(encoding_indices.shape[0], self.num_embeddings, device=device) 151 | encodings.scatter_(1, encoding_indices.unsqueeze(1), 1) 152 | 153 | # Quantize and un-flat 154 | quantized = torch.matmul(encodings, self.codebook.weight) 155 | 156 | # Use EMA to update the embedding vectors 157 | # Update a codebook vector as the mean of the encoder outputs that are closer to it 158 | # Calculate the usage count of codes and the mean code, then update the codebook vector dividing the two 159 | if self.training: 160 | with torch.no_grad(): 161 | ema_count = self.get_buffer('ema_count') * self.decay + (1 - self.decay) * torch.sum(encodings, 0) 162 | 163 | # Laplace smoothing of the ema count 164 | self.ema_count = (ema_count + self.epsilon) / (b + self.num_embeddings * self.epsilon) * b 165 | 166 | dw = torch.matmul(encodings.t(), flat_x) 167 | self.ema_weight = self.get_buffer('ema_weight') * self.decay + (1 - self.decay) * dw 168 | 169 | self.codebook.weight.data = self.get_buffer('ema_weight') / self.get_buffer('ema_count').unsqueeze(1) 170 | 171 | # Loss function (only the inputs are updated) 172 | e_loss = self.commitment_cost * F.mse_loss(quantized.detach(), flat_x) 173 | 174 | # during backpropagation quantized = inputs (copy gradient trick) 175 | quantized = flat_x + (quantized - flat_x).detach() 176 | 177 | quantized = rearrange(quantized, '(b h w) c -> b c h w', b=b, h=h, w=w) 178 | encoding_indices = rearrange(encoding_indices, '(b h w)-> b (h w)', b=b, h=h, w=w).detach() 179 | 180 | return quantized, encoding_indices, e_loss 181 | 182 | @torch.no_grad() 183 | def vec_to_codes(self, x: torch.Tensor) -> torch.IntTensor: 184 | """ 185 | :param x: tensors (output of the Encoder - B,D,H,W). 186 | :return flat codebook indices (B, H * W) 187 | """ 188 | b, c, h, w = x.shape 189 | 190 | # Flat input to vectors of embedding dim = C. 191 | flat_x = rearrange(x, 'b c h w -> (b h w) c') 192 | 193 | # Calculate distances of each vector w.r.t the dict 194 | # distances is a matrix (B*H*W, codebook_size) 195 | distances = (torch.sum(flat_x ** 2, dim=1, keepdim=True) 196 | + torch.sum(self.codebook.weight ** 2, dim=1) 197 | - 2 * torch.matmul(flat_x, self.codebook.weight.t())) 198 | 199 | # Get indices of the closest vector in dict 200 | encoding_indices = torch.argmin(distances, dim=1) 201 | encoding_indices = rearrange(encoding_indices, '(b h w) -> b (h w)', b=b, h=h, w=w) 202 | 203 | return encoding_indices 204 | 205 | 206 | class GumbelVectorQuantizer(BaseVectorQuantizer): 207 | def __init__(self, num_embeddings: int, embedding_dim: int, straight_through: bool = False, temp: float = 1.0, 208 | kl_cost: float = 5e-4): 209 | """ 210 | :param num_embeddings: size of the latent dictionary (num of embedding vectors). 211 | :param embedding_dim: size of a single tensor in dict. 
212 | :param straight_through: if True, will one-hot quantize, but still differentiate as if it is the soft sample 213 | :param temp: temperature parameter for gumbel softmax 214 | :param kl_cost: cost for kl divergence 215 | """ 216 | super().__init__(num_embeddings, embedding_dim) 217 | 218 | self.x_to_logits = torch.nn.Conv2d(num_embeddings, num_embeddings, 1) 219 | self.straight_through = straight_through 220 | self.temp = temp 221 | self.kl_cost = kl_cost 222 | 223 | def forward(self, x: torch.Tensor) -> (torch.Tensor, torch.IntTensor, float): 224 | """ 225 | :param x: tensors (output of the Encoder - B,N,H,W). Note that N = number of embeddings in dict! 226 | :return quantized_x (B, D, H, W), detached codes (B, H*W), latent_loss 227 | """ 228 | 229 | # deterministic quantization during inference 230 | hard = self.straight_through if self.training else True 231 | 232 | logits = self.x_to_logits(x) 233 | soft_one_hot = F.gumbel_softmax(logits, tau=self.temp, dim=1, hard=hard) 234 | quantized = einsum(soft_one_hot, self.get_codebook(), 'b n h w, n d -> b d h w') 235 | 236 | # + kl divergence to the prior (uniform) loss, increase cb usage 237 | # Note: 238 | # KL(P(x), Q(x)) = sum_x (P(x) * log(P(x) / Q(x))) 239 | # in this case: P(x) is qy, Q(x) is uniform distribution (1 / num_embeddings) 240 | qy = F.softmax(logits, dim=1) 241 | kl_loss = self.kl_cost * torch.sum(qy * torch.log(qy * self.num_embeddings + 1e-10), dim=1).mean() 242 | 243 | encoding_indices = soft_one_hot.argmax(dim=1).detach() 244 | 245 | return quantized, encoding_indices, kl_loss 246 | 247 | def get_consts(self) -> (float, float): 248 | """ 249 | return temp, kl_cost 250 | """ 251 | return self.temp, self.kl_cost 252 | 253 | def set_consts(self, temp: float = None, kl_cost: float = None) -> None: 254 | """ 255 | update values for temp, kl_cost 256 | :param temp: new value for temperature (if not None) 257 | :param kl_cost: new value for kl_cost (if not None) 258 | """ 259 | if temp is not None: 260 | self.temp = temp 261 | 262 | if kl_cost is not None: 263 | self.kl_cost = kl_cost 264 | 265 | @torch.no_grad() 266 | def vec_to_codes(self, x: torch.Tensor) -> torch.IntTensor: 267 | """ 268 | :param x: tensors (output of the Encoder - B,N,H,W). Note that N = number of embeddings in dict! 269 | :return flat codebook indices (B, H * W) 270 | """ 271 | 272 | soft_one_hot = F.gumbel_softmax(x, tau=1.0, dim=1, hard=True) 273 | encoding_indices = soft_one_hot.argmax(dim=1) 274 | return encoding_indices 275 | 276 | 277 | class EntropyVectorQuantizer(BaseVectorQuantizer): 278 | 279 | def __init__(self, num_embeddings: int, embedding_dim: int, ent_loss_ratio: float = 0.1, 280 | ent_temperature: float = 0.01, ent_loss_type: str = 'softmax', commitment_cost: float = 0.25): 281 | 282 | super().__init__(num_embeddings, embedding_dim) 283 | 284 | # hparams 285 | self.ent_loss_ratio = ent_loss_ratio 286 | self.ent_temperature = ent_temperature 287 | self.ent_loss_type = ent_loss_type 288 | self.commitment_cost = commitment_cost 289 | 290 | def forward(self, x: torch.Tensor) -> (torch.Tensor, torch.IntTensor, float): 291 | """ 292 | :param x: tensors (output of the Encoder - B,D,H,W). 293 | :return quantized_x (B, D, H, W), detached codes (B, H*W), latent_loss 294 | """ 295 | 296 | def entropy_loss(affinity: torch.Tensor, temperature: float, loss_type: str = 'softmax'): 297 | """ 298 | Increase codebook usage by maximizing entropy 299 | 300 | affinity: 2D tensor of size B * H * W, n_classes (codebook codes). 
301 | if a vector is close to a codebook entry = higher affinity 302 | """ 303 | 304 | n_classes = affinity.shape[-1] 305 | 306 | affinity = torch.div(affinity, temperature) 307 | probs = F.softmax(affinity, dim=-1) 308 | 309 | if loss_type == "softmax": 310 | target_probs = probs 311 | elif loss_type == "argmax": 312 | codes = torch.argmax(affinity, dim=-1) 313 | one_hots = F.one_hot(codes, n_classes).to(codes) 314 | one_hots = probs - (probs - one_hots).detach() 315 | target_probs = one_hots 316 | else: 317 | raise ValueError("Entropy loss {} not supported".format(loss_type)) 318 | 319 | # compute entropy of the mean batch over codebook. 320 | avg_probs = torch.mean(target_probs, dim=0) 321 | avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5)) 322 | 323 | # entropy of samples 324 | log_probs = F.log_softmax(affinity + 1e-5, dim=-1) 325 | sample_entropy = -torch.sum(target_probs * log_probs, dim=-1) 326 | sample_entropy = torch.mean(sample_entropy) 327 | 328 | return sample_entropy - avg_entropy 329 | 330 | batch_size, c, h, w = x.shape 331 | 332 | # compute distances 333 | flat_x = rearrange(x, 'b c h w -> (b h w) c') 334 | transposed_cb_weights = self.get_codebook().T 335 | 336 | # final distance vector is (B * H * W, N Codebook Codes) 337 | a2 = torch.sum(flat_x ** 2, dim=1, keepdim=True) 338 | b2 = torch.sum(transposed_cb_weights ** 2, dim=0, keepdim=True) 339 | ab = torch.matmul(flat_x, transposed_cb_weights) 340 | distances = a2 - 2 * ab + b2 341 | 342 | # get indices and quantized 343 | encoding_indices = torch.argmin(distances, dim=1) 344 | quantized = self.codebook(encoding_indices) 345 | quantized = rearrange(quantized, '(b h w) c -> b c h w', b=batch_size, h=h, w=w) 346 | encoding_indices = rearrange(encoding_indices, '(b h w)-> b (h w)', b=batch_size, h=h, w=w).detach() 347 | 348 | # compute_loss 349 | e_latent_loss = torch.mean((quantized.detach() - x) ** 2) * self.commitment_cost 350 | q_latent_loss = torch.mean((quantized - x.detach()) ** 2) 351 | ent_loss = entropy_loss(-distances, self.ent_temperature, self.ent_loss_type) * self.ent_loss_ratio 352 | loss = e_latent_loss + q_latent_loss + ent_loss 353 | 354 | quantized = x + (quantized - x).detach() 355 | 356 | return quantized, encoding_indices, loss 357 | 358 | @torch.no_grad() 359 | def vec_to_codes(self, x: torch.Tensor) -> torch.IntTensor: 360 | """ 361 | :param x: tensors (output of the Encoder - B,D,H,W). 
362 | :return flat codebook indices (B, H * W) 363 | """ 364 | 365 | batch_size, c, h, w = x.shape 366 | 367 | # compute distances 368 | flat_x = rearrange(x, 'b c h w -> (b h w) c') 369 | transposed_cb_weights = self.get_codebook().T 370 | 371 | # final distance vector is (B * Latent_Dim, Codebook Dim) 372 | a2 = torch.sum(flat_x ** 2, dim=1, keepdim=True) 373 | b2 = torch.sum(transposed_cb_weights ** 2, dim=0, keepdim=True) 374 | ab = torch.matmul(flat_x, transposed_cb_weights) 375 | distances = a2 - 2 * ab + b2 376 | 377 | # get indices and quantized 378 | encoding_indices = torch.argmin(distances, dim=1) 379 | encoding_indices = rearrange(encoding_indices, '(b h w)-> b (h w)', b=batch_size, h=h, w=w) 380 | 381 | return encoding_indices 382 | -------------------------------------------------------------------------------- /vqvae/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 4 | from pytorch_lightning.loggers.wandb import WandbLogger 5 | from pytorch_lightning.strategies import DDPStrategy 6 | 7 | from vqvae.common_utils import set_matmul_precision, get_model_conf, get_datamodule 8 | from vqvae.model import VQVAE 9 | 10 | import argparse 11 | import math 12 | import os.path 13 | import wandb 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser() 18 | 19 | parser.add_argument('--params_file', type=str, required=True, help='path to yaml file with model params') 20 | parser.add_argument('--dataloader', type=str, choices=['standard', 'ffcv'], default='standard', 21 | help='defines what type of dataloader to use.') 22 | parser.add_argument('--dataset_path', type=str, required=True, 23 | help='path to a dataset folder containing two sub-folders (validation / train) or beton files ' 24 | '(train.beton / validation.beton).') 25 | parser.add_argument('--save_path', type=str, required=True, help='path for checkpointing the model') 26 | parser.add_argument('--save_every_n_epochs', type=int, default=1, help='how often to save a new checkpoint') 27 | parser.add_argument('--run_name', type=str, required=True, 28 | help='name of the run, for wandb logging and checkpointing') 29 | parser.add_argument('--seed', type=int, required=True, help='global random seed for reproducibility') 30 | parser.add_argument('--loading_path', type=str, 31 | help='if passed, will load and continue training of an existing checkpoint', default=None) 32 | parser.add_argument('--logging', help='if passed, wandb logger is used', action='store_true') 33 | parser.add_argument('--wandb_project', type=str, help='project name for wandb logger', default='vqvae') 34 | parser.add_argument('--wandb_id', type=str, 35 | help='wandb id of the run. 
Useful for resuming logging of a model', default=None) 36 | parser.add_argument('--workers', type=int, help='num of parallel workers', default=1) 37 | parser.add_argument('--num_nodes', type=int, help='number of gpu nodes used for training', default=1) 38 | 39 | return parser.parse_args() 40 | 41 | 42 | def main(): 43 | 44 | # only for A100 45 | set_matmul_precision() 46 | 47 | args = parse_args() 48 | conf = get_model_conf(args.params_file) 49 | 50 | # configuration params (assumes some env variables in case of multi-node setup) 51 | gpus = torch.cuda.device_count() 52 | num_nodes = args.num_nodes 53 | rank = int(os.getenv('NODE_RANK')) if os.getenv('NODE_RANK') is not None else 0 54 | is_dist = gpus > 1 or num_nodes > 1 55 | 56 | workers = int(args.workers) 57 | seed = int(args.seed) 58 | 59 | cumulative_batch_size = int(conf['training']['cumulative_bs']) 60 | batch_size_per_device = cumulative_batch_size // (num_nodes * gpus) 61 | 62 | base_learning_rate = float(conf['training']['base_lr']) 63 | learning_rate = base_learning_rate * math.sqrt(cumulative_batch_size / 256) 64 | 65 | max_epochs = int(conf['training']['max_epochs']) 66 | 67 | pl.seed_everything(seed, workers=True) 68 | 69 | # logging stuff, checkpointing and resume 70 | log_to_wandb = bool(args.logging) 71 | project_name = str(args.wandb_project) 72 | wandb_id = args.wandb_id 73 | 74 | run_name = str(args.run_name) 75 | save_checkpoint_dir = f'{args.save_path}{run_name}/' 76 | save_every_n_epochs = int(args.save_every_n_epochs) 77 | 78 | load_checkpoint_path = args.loading_path 79 | resume = load_checkpoint_path is not None 80 | 81 | if rank == 0: # prevents from logging multiple times 82 | logger = WandbLogger(project=project_name, name=run_name, offline=not log_to_wandb, id=wandb_id, 83 | resume='must' if resume else None) 84 | else: 85 | logger = WandbLogger(project=project_name, name=run_name, offline=True) 86 | 87 | # model params 88 | image_size = int(conf['image_size']) 89 | ae_conf = conf['autoencoder'] 90 | q_conf = conf['quantizer'] 91 | l_conf = conf['loss'] if 'loss' in conf.keys() else None 92 | t_conf = {'lr': learning_rate, 93 | 'betas': conf['training']['betas'], 94 | 'eps': conf['training']['eps'], 95 | 'weight_decay': conf['training']['weight_decay'], 96 | 'warmup_epochs': conf['training']['warmup_epochs'] if 'warmup_epochs' in conf['training'].keys() else None, 97 | 'decay_epochs': conf['training']['decay_epochs'] if 'decay_epochs' in conf['training'].keys() else None, 98 | } 99 | 100 | # check if using adversarial loss 101 | use_adversarial = (l_conf is not None 102 | and 'adversarial_params' in l_conf.keys() 103 | and l_conf['adversarial_params'] is not None) 104 | 105 | # get model 106 | if resume: 107 | # image_size: int, ae_conf: dict, q_conf: dict, l_conf: dict, t_conf: dict, init_cb: bool = True, 108 | # load_loss: bool = True 109 | model = VQVAE.load_from_checkpoint(load_checkpoint_path, strict=False, 110 | image_size=image_size, ae_conf=ae_conf, q_conf=q_conf, l_conf=l_conf, 111 | t_conf=t_conf, init_cb=False, load_loss=True) 112 | else: 113 | model = VQVAE(image_size=image_size, ae_conf=ae_conf, q_conf=q_conf, l_conf=l_conf, t_conf=t_conf, 114 | init_cb=True, load_loss=True) 115 | 116 | # data loading (standard pytorch lightning or ffcv) 117 | datamodule = get_datamodule(args.dataloader, args.dataset_path, image_size, batch_size_per_device, 118 | workers, seed, is_dist, mode='train') 119 | 120 | # callbacks 121 | checkpoint_callback = ModelCheckpoint(dirpath=save_checkpoint_dir, 
filename='{epoch:02d}', save_last=True, 122 | save_top_k=-1, every_n_epochs=save_every_n_epochs) 123 | 124 | callbacks = [LearningRateMonitor(), checkpoint_callback] 125 | 126 | # trainer 127 | # set find unused parameters if using vqgan (adversarial training) 128 | trainer = pl.Trainer(strategy=DDPStrategy(find_unused_parameters=use_adversarial, static_graph=not use_adversarial), 129 | accelerator='gpu', num_nodes=num_nodes, devices=gpus, precision='16-mixed', 130 | callbacks=callbacks, deterministic=True, logger=logger, 131 | max_epochs=max_epochs, check_val_every_n_epoch=5) 132 | 133 | print(f"[INFO] workers: {workers}") 134 | print(f"[INFO] batch size per device: {batch_size_per_device}") 135 | print(f"[INFO] cumulative batch size (all devices): {cumulative_batch_size}") 136 | print(f"[INFO] final learning rate: {learning_rate}") 137 | 138 | # check to prevent later error 139 | if use_adversarial and batch_size_per_device % 4 != 0: 140 | raise RuntimeError('batch size per device must be divisible by 4! (due to stylegan discriminator forward pass)') 141 | 142 | trainer.fit(model, datamodule, ckpt_path=load_checkpoint_path) 143 | 144 | # ensure wandb has stopped logging 145 | wandb.finish() 146 | 147 | 148 | if __name__ == '__main__': 149 | main() 150 | --------------------------------------------------------------------------------
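Usage note: the quantizers defined in vqvae/modules/vector_quantizers.py all take an encoder feature map of shape (B, D, H, W) and return the quantized map, the flat codebook indices of shape (B, H*W), and a scalar latent loss. The sketch below is a minimal illustration of that interface, not part of the repository; it assumes that BaseVectorQuantizer initializes self.codebook as an nn.Embedding(num_embeddings, embedding_dim), as the distance computations above imply, and the sizes (512 codes, 64-dim embeddings) are arbitrary.

import torch
from vqvae.modules.vector_quantizers import VectorQuantizer, EMAVectorQuantizer

# Toy encoder output: batch of 2 feature maps with 64 channels on an 8x8 grid.
x = torch.randn(2, 64, 8, 8)

# Straight-through quantizer: latent loss combines codebook and commitment terms.
vq = VectorQuantizer(num_embeddings=512, embedding_dim=64, commitment_cost=0.25)
quantized, codes, latent_loss = vq(x)   # (2, 64, 8, 8), (2, 64), scalar

# EMA quantizer: the codebook is updated in-place while in training mode,
# so the returned loss contains only the commitment term.
ema_vq = EMAVectorQuantizer(num_embeddings=512, embedding_dim=64, decay=0.95)
ema_vq.train()
quantized, codes, e_loss = ema_vq(x)

# Indices alone, e.g. for training an autoregressive prior over the codes.
codes = vq.vec_to_codes(x)              # (2, 64) integer tensor

Because both quantizers use the copy-gradient (straight-through) trick, gradients flow back to the encoder as if quantization were the identity, while the codebook is trained either through q_loss (VectorQuantizer) or through the EMA updates (EMAVectorQuantizer).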