├── LICENSE
├── README.md
├── data
│   ├── __init__.py
│   ├── autoaugment.py
│   ├── cifar10.py
│   ├── cifar100.py
│   ├── cifar100s.py
│   ├── cifar10s.py
│   ├── idbh.py
│   ├── semisup.py
│   ├── svhn.py
│   ├── svhns.py
│   └── tiny_imagenet.py
├── dnnlib
│   ├── __init__.py
│   └── util.py
├── main_aug.py
├── main_simulation_bGMM.py
├── main_train_CDCGAN.py
├── main_train_aug.py
├── models
│   ├── __init__.py
│   ├── cdcgan.py
│   ├── gaussnb.py
│   ├── preact_resnet.py
│   ├── preact_resnetwithswish.py
│   ├── resnet.py
│   ├── stylegan.py
│   ├── ti_preact_resnet.py
│   ├── wideresnet.py
│   └── wideresnetwithswish.py
├── plot.ipynb
├── pytorch.yaml
├── scripts
│   ├── main_aug.sh
│   ├── main_bGMM.sh
│   └── main_train_aug.sh
├── torch_utils
│   ├── __init__.py
│   ├── custom_ops.py
│   ├── misc.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── bias_act.cpp
│   │   ├── bias_act.cu
│   │   ├── bias_act.h
│   │   ├── bias_act.py
│   │   ├── conv2d_gradfix.py
│   │   ├── conv2d_resample.py
│   │   ├── fma.py
│   │   ├── grid_sample_gradfix.py
│   │   ├── upfirdn2d.cpp
│   │   ├── upfirdn2d.cu
│   │   ├── upfirdn2d.h
│   │   └── upfirdn2d.py
│   ├── persistence.py
│   └── training_stats.py
└── utils
    ├── __init__.py
    ├── legacy.py
    ├── logger.py
    ├── parser.py
    ├── rst.py
    ├── tools.py
    ├── train.py
    └── vis.py
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2023 ML Group @ RUC
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Toward Understanding Generative Data Augmentation
 2 | 
 3 | This is the official implementation of [Toward Understanding Generative Data Augmentation](https://arxiv.org/abs/2305.17476).
 4 | 
 5 | ## Dependencies
 6 | 
 7 | ```bash
 8 | conda env create -f pytorch.yaml
 9 | ```
10 | 
11 | ## Simulation experiments on bGMM
12 | 
13 | * Reproduce the simulation results in Figure 1:
14 | 
15 | ```bash
16 | bash scripts/main_bGMM.sh
17 | ```
18 | 
19 | * Use the code in `plot.ipynb` to plot the results.
20 | 
21 | ## Empirical experiments on CIFAR-10
22 | 
23 | ### Obtain weights of deep generative models
24 | 
25 | * cDCGAN:
26 | 
27 | ```bash
28 | python ./main_train_CDCGAN.py # hyperparameters are already set in the script
29 | ```
30 | 
31 | * StyleGAN2-ADA:
32 | 
33 | ```bash
34 | wget https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl
35 | ```
36 | 
37 | ### Generate and store new images
38 | 
39 | * cDCGAN and StyleGAN2-ADA:
40 | 
41 | ```bash
42 | bash scripts/main_aug.sh
43 | ```
44 | 
45 | * EDM:
46 | 
47 | ```bash
48 | wget https://huggingface.co/datasets/P2333/DM-Improves-AT/resolve/main/cifar10/5m.npz
49 | ```
50 | 
51 | ### Train ResNets with GDA (see scripts/main_train_aug.sh)
52 | 
53 | * cDCGAN and StyleGAN2-ADA:
54 | 
55 | ```bash
56 | export CUDA_VISIBLE_DEVICES=0
57 | export PYTHONPATH=$PYTHONPATH:$(pwd)
58 | # --desc and --aux-data-filename depend on m_G; here m_G = 1M (1000k).
59 | # Set --augment none to train without standard data augmentation.
60 | python main_train_aug.py --data-dir ./datasets \
61 |     --log-dir ./log/RN18_cDCGAN_base \
62 |     --desc RN18_cifar10s_lr0p2_epoch100_bs512_1000k \
63 |     --data cifar10s \
64 |     --batch-size 512 \
65 |     --model resnet18 \
66 |     --num-epochs 100 \
67 |     --eval-freq 10 \
68 |     --lr 0.2 \
69 |     --aux-data-filename ./datasets/cDCGAN/1000k.npz \
70 |     --augment base
71 | ```
72 | 
73 | * EDM:
74 | 
75 | ```bash
76 | export CUDA_VISIBLE_DEVICES=0
77 | export PYTHONPATH=$PYTHONPATH:$(pwd)
78 | # --aux-data-filename is the path to the downloaded EDM data; --aux-take-amount is m_G.
79 | # Set --augment none to train without standard data augmentation.
80 | python main_train_aug.py --data-dir ./datasets \
81 |     --log-dir ./log/RN18_EDM_base \
82 |     --desc RN18_cifar10s_lr0p2_epoch100_bs512_1000k \
83 |     --data cifar10s \
84 |     --batch-size 512 \
85 |     --model resnet18 \
86 |     --num-epochs 100 \
87 |     --eval-freq 10 \
88 |     --lr 0.2 \
89 |     --aux-data-filename ../bishe/codes/data/5m.npz \
90 |     --aux-take-amount 1000000 \
91 |     --augment base
92 | ```
93 | 
94 | ## Acknowledgments
95 | 
96 | The code is developed based on the following repositories. We appreciate their nice implementations.
97 | 
98 | | Method              | Repository                                                   |
99 | | :-----------------: | :----------------------------------------------------------: |
100 | | CDCGAN              | https://github.com/znxlwm/pytorch-MNIST-CelebA-cGAN-cDCGAN   |
101 | | StyleGAN            | https://github.com/NVlabs/stylegan2-ada-pytorch              |
102 | | EDM data & training | https://github.com/wzekai99/DM-Improves-AT                   |
103 | | bGMM                | https://github.com/ML-GSAI/Revisiting-Dis-vs-Gen-Classifiers |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import torch
 3 | from torch.utils.data import DataLoader
 4 | 
 5 | from .cifar10 import load_cifar10
 6 | from .cifar100 import load_cifar100
 7 | from .svhn import load_svhn
 8 | from .cifar10s import load_cifar10s
 9 | from .cifar100s import load_cifar100s
10 | from .svhns import load_svhns
11 | from .tiny_imagenet import load_tinyimagenet
12 | 
13 | from .semisup import get_semisup_dataloaders
14 | 
15 | 
16 | SEMISUP_DATASETS = ['cifar10s', 'cifar100s', 'svhns']
17 | DATASETS = ['cifar10', 'svhn', 'cifar100', 'tiny-imagenet'] + SEMISUP_DATASETS
18 | 
19 | _LOAD_DATASET_FN = {
20 |     'cifar10': load_cifar10,
21 |     'cifar100': load_cifar100,
22 |     'svhn': load_svhn,
23 |     'tiny-imagenet': load_tinyimagenet,
24 |     'cifar10s': load_cifar10s,
25 |     'cifar100s': load_cifar100s,
26 |     'svhns': load_svhns,
27 | }
28 | 
29 | 
30 | def get_data_info(data_dir):
31 |     """
32 |     Returns dataset information.
33 |     Arguments:
34 |         data_dir (str): path to data directory.
35 |     """
36 |     dataset = os.path.basename(os.path.normpath(data_dir))
37 |     if 'cifar100' in data_dir:
38 |         from .cifar100 import DATA_DESC
39 |     elif 'cifar10' in data_dir:
40 |         from .cifar10 import DATA_DESC
41 |     elif 'svhn' in data_dir:
42 |         from .svhn import DATA_DESC
43 |     elif 'tiny-imagenet' in data_dir:
44 |         from .tiny_imagenet import DATA_DESC
45 |     else:
46 |         raise ValueError(f'Only data in {DATASETS} are supported!')
47 |     DATA_DESC['data'] = dataset
48 |     return DATA_DESC
49 | 
50 | 
51 | def load_data(data_dir, batch_size=512, batch_size_test=256, num_workers=4, use_augmentation='none', use_consistency=False, shuffle_train=True, 
52 |               aux_data_filename=None, unsup_fraction=None, aux_take_amount=None, validation=False):
53 |     """
54 |     Returns train, test datasets and dataloaders.
55 |     Arguments:
56 |         data_dir (str): path to data directory.
57 |         batch_size (int): batch size for training.
58 |         batch_size_test (int): batch size for validation.
59 |         num_workers (int): number of workers for loading the data.
60 |         use_augmentation (base/none): whether to use augmentations for training set.
61 |         shuffle_train (bool): whether to shuffle training set.
62 |         aux_data_filename (str): path to unlabelled data.
63 |         unsup_fraction (float): fraction of unlabelled data per batch.
64 |         aux_take_amount (int): number of auxiliary (generated) examples to use (if None, use all).
65 |         validation (bool): if True, also returns a validation dataloader for unsupervised cifar10 (as in Gowal et al., 2020).
66 |     """
67 |     dataset = os.path.basename(os.path.normpath(data_dir))
68 |     load_dataset_fn = _LOAD_DATASET_FN[dataset]
69 | 
70 |     if validation:
71 |         assert dataset in SEMISUP_DATASETS, 'Only semi-supervised datasets allow a validation set.'
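    # Usage sketch for GDA training (the npz path below is illustrative; it must
    # first be produced by main_aug.py):
    #
    #     train_set, test_set, train_dl, test_dl = load_data(
    #         './datasets/cifar10s', batch_size=512, use_augmentation='base',
    #         aux_data_filename='./datasets/cDCGAN/1000k.npz')
    #
    # With validation=False this returns (train_dataset, test_dataset,
    # train_dataloader, test_dataloader), exactly as consumed by main_train_aug.py.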
72 |         train_dataset, test_dataset, val_dataset = load_dataset_fn(data_dir=data_dir, use_augmentation=use_augmentation, use_consistency=use_consistency, 
73 |                                                                     aux_data_filename=aux_data_filename, aux_take_amount=aux_take_amount, validation=True)
74 |     else:
75 |         train_dataset, test_dataset = load_dataset_fn(data_dir=data_dir, use_augmentation=use_augmentation, aux_data_filename=aux_data_filename, aux_take_amount=aux_take_amount)
76 | 
77 |     if dataset in SEMISUP_DATASETS:
78 |         train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=False)
79 |         test_dataloader = DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False, num_workers=num_workers, pin_memory=False)
80 | 
81 |         if validation:
82 |             val_dataloader = DataLoader(val_dataset, batch_size=batch_size_test, shuffle=False, num_workers=num_workers, pin_memory=False)
83 | 
84 |         # if validation:
85 |         #     train_dataloader, test_dataloader, val_dataloader = get_semisup_dataloaders(
86 |         #         train_dataset, test_dataset, val_dataset, batch_size, batch_size_test, num_workers, unsup_fraction
87 |         #     )
88 |         # else:
89 |         #     train_dataloader, test_dataloader = get_semisup_dataloaders(
90 |         #         train_dataset, test_dataset, None, batch_size, batch_size_test, num_workers, unsup_fraction
91 |         #     )
92 | 
93 |     else:
94 |         pin_memory = False
95 |         train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle_train, 
96 |                                                        num_workers=num_workers, pin_memory=pin_memory)
97 |         test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False, 
98 |                                                       num_workers=num_workers, pin_memory=pin_memory)
99 | 
100 |     if validation:
101 |         return train_dataset, test_dataset, val_dataset, train_dataloader, test_dataloader, val_dataloader
102 | 
103 |     print(len(train_dataset), len(test_dataset))
104 |     return train_dataset, test_dataset, train_dataloader, test_dataloader
105 | 
--------------------------------------------------------------------------------
/data/cifar10.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | import torchvision
 4 | import torchvision.transforms as transforms
 5 | 
 6 | 
 7 | DATA_DESC = {
 8 |     'data': 'cifar10',
 9 |     'classes': ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'),
10 |     'num_classes': 10,
11 |     'mean': [0.48145466, 0.4578275, 0.40821073],
12 |     'std': [0.26862954, 0.26130258, 0.27577711],
13 | }
14 | 
15 | 
16 | def load_cifar10(data_dir, use_augmentation='base'):
17 |     """
18 |     Returns CIFAR10 train and test datasets.
19 |     Arguments:
20 |         data_dir (str): path to data directory.
21 |         use_augmentation (base/none): whether to use augmentations for training set.
22 |     Returns:
23 |         train dataset, test dataset.
24 |     """
25 |     test_transform = transforms.Compose([transforms.ToTensor()])
26 |     if use_augmentation == 'base':
27 |         train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(0.5), 
28 |                                               transforms.ToTensor()])
29 |     else:
30 |         train_transform = test_transform
31 | 
32 |     train_dataset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)
33 |     test_dataset = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True, transform=test_transform)
34 |     return train_dataset, test_dataset
--------------------------------------------------------------------------------
/data/cifar100.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | import torchvision
 4 | import torchvision.transforms as transforms
 5 | 
 6 | 
 7 | DATA_DESC = {
 8 |     'data': 'cifar100',
 9 |     'classes': tuple(range(0, 100)),
10 |     'num_classes': 100,
11 |     'mean': [0.5071, 0.4865, 0.4409],
12 |     'std': [0.2673, 0.2564, 0.2762],
13 | }
14 | 
15 | 
16 | def load_cifar100(data_dir, use_augmentation='base'):
17 |     """
18 |     Returns CIFAR100 train and test datasets.
19 |     Arguments:
20 |         data_dir (str): path to data directory.
21 |         use_augmentation (base/none): whether to use augmentations for training set.
22 |     Returns:
23 |         train dataset, test dataset.
24 |     """
25 |     test_transform = transforms.Compose([transforms.ToTensor()])
26 |     if use_augmentation == 'base':
27 |         train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(0.5), 
28 |                                               transforms.RandomRotation(15), transforms.ToTensor()])
29 |     else:
30 |         train_transform = test_transform
31 | 
32 |     train_dataset = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)
33 |     test_dataset = torchvision.datasets.CIFAR100(root=data_dir, train=False, download=True, transform=test_transform)
34 |     return train_dataset, test_dataset
--------------------------------------------------------------------------------
/data/cifar100s.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | import torchvision
 4 | import torchvision.transforms as transforms
 5 | 
 6 | import re
 7 | import numpy as np
 8 | 
 9 | from .semisup import SemiSupervisedDataset
10 | 
11 | 
12 | def load_cifar100s(data_dir, use_augmentation='base', use_consistency=False, aux_take_amount=None, 
13 |                    aux_data_filename=None, validation=False):
14 |     """
15 |     Returns semisupervised CIFAR100 train and test datasets (with DDPM images).
16 |     Arguments:
17 |         data_dir (str): path to data directory.
18 |         use_augmentation (base/none): whether to use augmentations for training set.
19 |         aux_take_amount (int): number of semi-supervised examples to use (if None, use all).
20 |         aux_data_filename (str): path to additional data pickle file.
21 |     Returns:
22 |         train dataset, test dataset.
23 |     """
24 |     data_dir = re.sub('cifar100s', 'cifar100', data_dir)
25 |     test_transform = transforms.Compose([transforms.ToTensor()])
26 |     if use_augmentation == 'base':
27 |         train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(0.5), 
28 |                                               transforms.RandomRotation(15), transforms.ToTensor()])
29 |     else:
30 |         train_transform = test_transform
31 | 
32 |     train_dataset = SemiSupervisedCIFAR100(base_dataset='cifar100', root=data_dir, train=True, download=True, 
33 |                                            transform=train_transform, aux_data_filename=aux_data_filename, 
34 |                                            add_aux_labels=True, aux_take_amount=aux_take_amount, validation=validation)
35 |     test_dataset = SemiSupervisedCIFAR100(base_dataset='cifar100', root=data_dir, train=False, download=True, 
36 |                                           transform=test_transform)
37 |     if validation:
38 |         val_dataset = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=test_transform)
39 |         val_dataset = torch.utils.data.Subset(val_dataset, np.arange(0, 1024))
40 |         return train_dataset, test_dataset, val_dataset
41 |     return train_dataset, test_dataset  # match load_cifar10s: load_data expects a 2-tuple when validation=False
42 | 
43 | 
44 | class SemiSupervisedCIFAR100(SemiSupervisedDataset):
45 |     """
46 |     A dataset with auxiliary pseudo-labeled data for CIFAR100.
47 |     """
48 |     def load_base_dataset(self, train=False, **kwargs):
49 |         assert self.base_dataset == 'cifar100', 'Only semi-supervised cifar100 is supported. Please use correct dataset!'
50 |         self.dataset = torchvision.datasets.CIFAR100(train=train, **kwargs)
51 |         self.dataset_size = len(self.dataset)
--------------------------------------------------------------------------------
/data/cifar10s.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | import torchvision
 4 | import torchvision.transforms as transforms
 5 | 
 6 | import re
 7 | import numpy as np
 8 | 
 9 | from .semisup import SemiSupervisedDataset
10 | from .semisup import SemiSupervisedSampler
11 | 
12 | from .cifar10 import DATA_DESC
13 | 
14 | from .autoaugment import CIFAR10Policy
15 | from .idbh import IDBH
16 | # from RandAugment import RandAugment  # pip install git+https://github.com/ildoonet/pytorch-randaugment
17 | 
18 | 
19 | class CutoutDefault(object):
20 |     """
21 |     Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
22 |     """
23 |     def __init__(self, length):
24 |         self.length = length
25 | 
26 |     def __call__(self, img):
27 |         h, w = img.size(1), img.size(2)
28 |         mask = np.ones((h, w), np.float32)
29 |         y = np.random.randint(h)
30 |         x = np.random.randint(w)
31 | 
32 |         y1 = np.clip(y - self.length // 2, 0, h)
33 |         y2 = np.clip(y + self.length // 2, 0, h)
34 |         x1 = np.clip(x - self.length // 2, 0, w)
35 |         x2 = np.clip(x + self.length // 2, 0, w)
36 | 
37 |         mask[y1: y2, x1: x2] = 0.
38 |         mask = torch.from_numpy(mask)
39 |         mask = mask.expand_as(img)
40 |         img *= mask
41 |         return img
42 | 
43 | 
44 | class MultiDataTransform(object):
45 |     def __init__(self, transform):
46 |         self.transform = transform
47 | 
48 |     def __call__(self, sample):
49 |         x1 = self.transform(sample)
50 |         x2 = self.transform(sample)
51 |         return x1, x2
52 | 
53 | 
54 | def load_cifar10s(data_dir, use_augmentation='none', use_consistency=False, aux_take_amount=None, 
55 |                   aux_data_filename='/cluster/scratch/rarade/cifar10s/ti_500K_pseudo_labeled.pickle', 
56 |                   validation=False):
57 |     """
58 |     Returns semisupervised CIFAR10 train and test datasets (with Tiny Images).
59 |     Arguments:
60 |         data_dir (str): path to data directory.
61 |         use_augmentation (str): augmentation policy for the training set (none/base/cutout/autoaugment/randaugment/idbh).
62 |         aux_take_amount (int): number of semi-supervised examples to use (if None, use all).
63 |         aux_data_filename (str): path to additional data pickle file.
64 |     Returns:
65 |         train dataset, test dataset.
66 |     """
67 |     data_dir = re.sub('cifar10s', 'cifar10', data_dir)
68 | 
69 |     test_transform = transforms.Compose([
70 |         transforms.ToTensor(),
71 |         transforms.Normalize(DATA_DESC['mean'], DATA_DESC['std'])
72 |     ])
73 | 
74 |     if use_augmentation == 'none':  # case 1
75 |         train_transform = test_transform
76 |     elif use_augmentation == 'base':  # case 2
77 |         train_transform = transforms.Compose([
78 |             transforms.RandomCrop(32, padding=4),
79 |             transforms.RandomHorizontalFlip(0.5),
80 |             transforms.ToTensor(),
81 |             transforms.Normalize(DATA_DESC['mean'], DATA_DESC['std'])
82 |         ])
83 |     elif use_augmentation == 'cutout':
84 |         train_transform = transforms.Compose([
85 |             transforms.RandomCrop(32, padding=4),
86 |             transforms.RandomHorizontalFlip(0.5),
87 |             transforms.ToTensor(),
88 |         ])
89 |         train_transform.transforms.append(CutoutDefault(18))
90 |     elif use_augmentation == 'autoaugment':
91 |         train_transform = transforms.Compose([
92 |             transforms.RandomCrop(32, padding=4),
93 |             transforms.RandomHorizontalFlip(0.5),
94 |             CIFAR10Policy(),
95 |             transforms.ToTensor(),
96 |         ])
97 |         train_transform.transforms.append(CutoutDefault(18))
98 |     elif use_augmentation == 'randaugment':
99 |         train_transform = transforms.Compose([
100 |             transforms.RandomCrop(32, padding=4),
101 |             transforms.RandomHorizontalFlip(0.5),
102 |             transforms.ToTensor(),
103 |         ])
104 |         # Add RandAugment with hyperparameters N=2, M=14 for wrn-28-10
105 |         # train_transform.transforms.insert(0, RandAugment(2, 14))
106 |     elif use_augmentation == 'idbh':
107 |         train_transform = IDBH('cifar10-weak')
108 | 
109 |     if use_consistency:
110 |         train_transform = MultiDataTransform(train_transform)
111 | 
112 |     train_dataset = SemiSupervisedCIFAR10(base_dataset='cifar10', root=data_dir, train=True, download=True, 
113 |                                           transform=train_transform, aux_data_filename=aux_data_filename, 
114 |                                           add_aux_labels=True, aux_take_amount=aux_take_amount, validation=validation)
115 |     test_dataset = SemiSupervisedCIFAR10(base_dataset='cifar10', root=data_dir, train=False, download=True, 
116 |                                          transform=test_transform)
117 |     if validation:
118 |         val_dataset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=test_transform)
119 |         val_dataset = torch.utils.data.Subset(val_dataset, np.arange(0, 1024))  # split from training set
120 |         return train_dataset, test_dataset, val_dataset
121 |     return train_dataset, test_dataset
122 | 
123 | 
124 | class SemiSupervisedCIFAR10(SemiSupervisedDataset):
125 |     """
126 |     A dataset with auxiliary pseudo-labeled data for CIFAR10.
127 |     """
128 |     def load_base_dataset(self, train=False, **kwargs):
129 |         assert self.base_dataset == 'cifar10', 'Only semi-supervised cifar10 is supported. Please use correct dataset!'
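        # The torchvision dataset is kept internally; SemiSupervisedDataset then
        # appends the generated images from `aux_data_filename` to it and exposes
        # the combined arrays through the data/targets properties (see data/semisup.py).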
130 | self.dataset = torchvision.datasets.CIFAR10(train=train, **kwargs) 131 | self.dataset_size = len(self.dataset) 132 | -------------------------------------------------------------------------------- /data/idbh.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch as tc 3 | import torchvision.transforms as T 4 | from torchvision.transforms import functional as F 5 | from torchvision.transforms import InterpolationMode as Interpolation 6 | 7 | class IDBH(tc.nn.Module): 8 | def __init__(self, version): 9 | super().__init__() 10 | if version == 'cifar10-weak': 11 | layers = [ 12 | T.RandomHorizontalFlip(), 13 | CropShift(0, 11), 14 | ColorShape('color'), 15 | T.ToTensor(), 16 | T.RandomErasing(p=0.5) 17 | ] 18 | elif version == 'cifar10-strong': 19 | layers = [ 20 | T.RandomHorizontalFlip(), 21 | CropShift(0, 11), 22 | ColorShape('color'), 23 | T.ToTensor(), 24 | T.RandomErasing(p=1) 25 | ] 26 | elif version == 'svhn': 27 | layers = [ 28 | T.RandomHorizontalFlip(), 29 | CropShift(0, 9), 30 | ColorShape('shape'), 31 | T.ToTensor(), 32 | T.RandomErasing(p=1, scale=(0.02, 0.5)) 33 | ] 34 | else: 35 | raise Exception("IDBH: invalid version string") 36 | 37 | self.layers = T.Compose(layers) 38 | 39 | def forward(self, img): 40 | return self.layers(img) 41 | 42 | 43 | class ColorShape(tc.nn.Module): 44 | ColorBiased = [ 45 | (0.125, 'color', 0.1, 1.9), 46 | (0.125, 'brightness', 0.5, 1.9), 47 | (0.125, 'contrast', 0.5, 1.9), 48 | (0.125, 'sharpness', 0.1, 1.9), 49 | (0.125, 'autocontrast'), 50 | (0.125, 'equalize'), 51 | (0.125, 'shear', 0.05, 0.15), 52 | (0.125, 'rotate', 1, 11) 53 | ] 54 | ShapeBiased = [ 55 | (0.08, 'color', 0.1, 1.9), 56 | (0.08, 'brightness', 0.5, 1.9), 57 | (0.04, 'contrast', 0.5, 1.9), 58 | (0.08, 'sharpness', 0.1, 1.9), 59 | (0.04, 'autocontrast'), 60 | (0.08, 'equalize'), 61 | (0.3, 'shear', 0.05, 0.35), 62 | (0.3, 'rotate', 1, 31) 63 | ] 64 | 65 | def __init__(self, version='color'): 66 | super().__init__() 67 | 68 | assert version in ['color', 'shape'] 69 | space = self.ColorBiased if version == 'color' else self.ShapeBiased 70 | 71 | self.space = {} 72 | p_accu = 0.0 73 | for trans in space: 74 | p = trans[0] 75 | self.space[(p_accu, p_accu+p)] = trans[1:] 76 | p_accu += p 77 | 78 | def transform(self, img, trans): 79 | if len(trans) == 1: 80 | trans = trans[0] 81 | else: 82 | lower, upper = trans[1:] 83 | trans = trans[0] 84 | if trans == 'rotate': 85 | strength = tc.randint(lower, upper, (1,)).item() 86 | else: 87 | strength = tc.rand(1) * (upper-lower) + lower 88 | 89 | if trans == 'color': 90 | img = F.adjust_saturation(img, strength) 91 | elif trans == 'brightness': 92 | img = F.adjust_brightness(img, strength) 93 | elif trans == 'contrast': 94 | img = F.adjust_contrast(img, strength) 95 | elif trans == 'sharpness': 96 | img = F.adjust_sharpness(img, strength) 97 | elif trans == 'shear': 98 | if tc.randint(2, (1,)): 99 | # random sign 100 | strength *= -1 101 | strength = math.degrees(strength) 102 | strength = [strength, 0.0] if tc.randint(2, (1,)) else [0.0, strength] 103 | img = F.affine(img, 104 | angle=0.0, 105 | translate=[0, 0], 106 | scale=1.0, 107 | shear=strength, 108 | interpolation=Interpolation.NEAREST, 109 | fill=0) 110 | elif trans == 'rotate': 111 | if tc.randint(2, (1,)): 112 | strength *= -1 113 | img = F.rotate(img, angle=strength, interpolation=Interpolation.NEAREST, fill=0) 114 | elif trans == 'autocontrast': 115 | img = F.autocontrast(img) 116 | elif trans == 'equalize': 117 | 
img = F.equalize(img) 118 | 119 | return img 120 | 121 | def forward(self, img): 122 | roll = tc.rand(1) 123 | for (lower, upper), trans in self.space.items(): 124 | if roll <= upper and roll >= lower: 125 | return self.transform(img, trans) 126 | 127 | return img 128 | 129 | class CropShift(tc.nn.Module): 130 | def __init__(self, low, high=None): 131 | super().__init__() 132 | high = low if high is None else high 133 | self.low, self.high = int(low), int(high) 134 | 135 | def sample_top(self, x, y): 136 | x = tc.randint(0, x+1, (1,)).item() 137 | y = tc.randint(0, y+1, (1,)).item() 138 | return x, y 139 | 140 | def forward(self, img): 141 | if self.low == self.high: 142 | strength = self.low 143 | else: 144 | strength = tc.randint(self.low, self.high, (1,)).item() 145 | 146 | w, h = F.get_image_size(img) 147 | crop_x = tc.randint(0, strength+1, (1,)).item() 148 | crop_y = strength - crop_x 149 | crop_w, crop_h = w - crop_x, h - crop_y 150 | 151 | top_x, top_y = self.sample_top(crop_x, crop_y) 152 | 153 | img = F.crop(img, top_y, top_x, crop_h, crop_w) 154 | img = F.pad(img, padding=[crop_x, crop_y], fill=0) 155 | 156 | top_x, top_y = self.sample_top(crop_x, crop_y) 157 | 158 | return F.crop(img, top_y, top_x, h, w) 159 | -------------------------------------------------------------------------------- /data/semisup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | 5 | import torch 6 | 7 | 8 | def get_semisup_dataloaders(train_dataset, test_dataset, val_dataset=None, batch_size=256, batch_size_test=256, num_workers=4, 9 | unsup_fraction=0.5): 10 | """ 11 | Return dataloaders with custom sampling of pseudo-labeled data. 12 | """ 13 | dataset_size = train_dataset.dataset_size 14 | train_batch_sampler = SemiSupervisedSampler(train_dataset.sup_indices, train_dataset.unsup_indices, batch_size, 15 | unsup_fraction, num_batches=int(np.ceil(dataset_size/batch_size))) 16 | epoch_size = len(train_batch_sampler) * batch_size 17 | 18 | # kwargs = {'num_workers': num_workers, 'pin_memory': torch.cuda.is_available() } 19 | kwargs = {'num_workers': num_workers, 'pin_memory': False} 20 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_sampler=train_batch_sampler, **kwargs) 21 | test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False, **kwargs) 22 | 23 | if val_dataset: 24 | val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size_test, shuffle=False, **kwargs) 25 | return train_dataloader, test_dataloader, val_dataloader 26 | return train_dataloader, test_dataloader 27 | 28 | 29 | class SemiSupervisedDataset(torch.utils.data.Dataset): 30 | """ 31 | A dataset with auxiliary pseudo-labeled data. 
32 | """ 33 | def __init__(self, base_dataset='cifar10', take_amount=None, take_amount_seed=13, aux_data_filename=None, 34 | add_aux_labels=False, aux_take_amount=None, train=False, validation=False, **kwargs): 35 | 36 | self.base_dataset = base_dataset 37 | self.load_base_dataset(train, **kwargs) 38 | 39 | 40 | if validation: 41 | self.dataset.data = self.dataset.data[1024:] 42 | self.dataset.targets = self.dataset.targets[1024:] 43 | 44 | self.train = train 45 | 46 | if self.train: 47 | if take_amount is not None: 48 | rng_state = np.random.get_state() 49 | np.random.seed(take_amount_seed) 50 | take_inds = np.random.choice(len(self.sup_indices), take_amount, replace=False) 51 | np.random.set_state(rng_state) 52 | 53 | self.targets = self.targets[take_inds] 54 | self.data = self.data[take_inds] 55 | 56 | self.sup_indices = list(range(len(self.targets))) 57 | self.unsup_indices = [] 58 | 59 | if aux_data_filename is not None: 60 | aux_path = aux_data_filename 61 | print('Loading data from %s' % aux_path) 62 | if os.path.splitext(aux_path)[1] == '.pickle': 63 | # for data from Carmon et al, 2019. 64 | with open(aux_path, 'rb') as f: 65 | aux = pickle.load(f) 66 | aux_data = aux['data'] 67 | aux_targets = aux['extrapolated_targets'] 68 | else: 69 | # for data from Rebuffi et al, 2021. 70 | aux = np.load(aux_path) 71 | aux_data = aux['image'] 72 | print(aux_data.shape) 73 | aux_targets = aux['label'] 74 | 75 | orig_len = len(self.data) 76 | 77 | if aux_take_amount is not None: 78 | rng_state = np.random.get_state() 79 | np.random.seed(take_amount_seed) 80 | take_inds = np.random.choice(len(aux_data), aux_take_amount, replace=False) 81 | np.random.set_state(rng_state) 82 | 83 | aux_data = aux_data[take_inds] 84 | aux_targets = aux_targets[take_inds] 85 | 86 | self.data = np.concatenate((self.data, aux_data), axis=0) 87 | 88 | if not add_aux_labels: 89 | self.targets.extend([-1] * len(aux_data)) 90 | else: 91 | self.targets.extend(aux_targets) 92 | self.unsup_indices.extend(range(orig_len, orig_len+len(aux_data))) 93 | 94 | else: 95 | self.sup_indices = list(range(len(self.targets))) 96 | self.unsup_indices = [] 97 | 98 | def load_base_dataset(self, **kwargs): 99 | raise NotImplementedError() 100 | 101 | @property 102 | def data(self): 103 | return self.dataset.data 104 | 105 | @data.setter 106 | def data(self, value): 107 | self.dataset.data = value 108 | 109 | @property 110 | def targets(self): 111 | return self.dataset.targets 112 | 113 | @targets.setter 114 | def targets(self, value): 115 | self.dataset.targets = value 116 | 117 | def __len__(self): 118 | return len(self.dataset) 119 | 120 | def __getitem__(self, item): 121 | self.dataset.labels = self.targets 122 | return self.dataset[item] 123 | 124 | 125 | class SemiSupervisedDatasetSVHN(torch.utils.data.Dataset): 126 | """ 127 | A dataset with auxiliary pseudo-labeled data. 
128 | """ 129 | def __init__(self, base_dataset='svhn', take_amount=None, take_amount_seed=13, aux_data_filename=None, 130 | add_aux_labels=False, aux_take_amount=None, train=False, validation=False, **kwargs): 131 | 132 | self.base_dataset = base_dataset 133 | self.load_base_dataset(train, **kwargs) 134 | self.dataset.labels = self.dataset.labels.tolist() 135 | 136 | 137 | if validation: 138 | self.dataset.data = self.dataset.data[1024:] 139 | self.dataset.labels = self.dataset.labels[1024:] 140 | 141 | self.train = train 142 | 143 | if self.train: 144 | if take_amount is not None: 145 | rng_state = np.random.get_state() 146 | np.random.seed(take_amount_seed) 147 | take_inds = np.random.choice(len(self.sup_indices), take_amount, replace=False) 148 | np.random.set_state(rng_state) 149 | 150 | self.targets = self.targets[take_inds] 151 | self.data = self.data[take_inds] 152 | 153 | self.sup_indices = list(range(len(self.targets))) 154 | self.unsup_indices = [] 155 | 156 | if aux_data_filename is not None: 157 | aux_path = aux_data_filename 158 | print('Loading data from %s' % aux_path) 159 | if os.path.splitext(aux_path)[1] == '.pickle': 160 | # for data from Carmon et al, 2019. 161 | with open(aux_path, 'rb') as f: 162 | aux = pickle.load(f) 163 | aux_data = aux['data'] 164 | aux_targets = aux['extrapolated_targets'] 165 | else: 166 | # for data from Rebuffi et al, 2021. 167 | aux = np.load(aux_path) 168 | aux_data = aux['image'] 169 | print(aux_data.shape) 170 | aux_targets = aux['label'] 171 | 172 | orig_len = len(self.data) 173 | 174 | if aux_take_amount is not None: 175 | rng_state = np.random.get_state() 176 | np.random.seed(take_amount_seed) 177 | take_inds = np.random.choice(len(aux_data), aux_take_amount, replace=False) 178 | np.random.set_state(rng_state) 179 | 180 | aux_data = aux_data[take_inds] 181 | aux_targets = aux_targets[take_inds] 182 | 183 | self.data = np.concatenate((self.data, aux_data.transpose(0,3,1,2)), axis=0) 184 | 185 | if not add_aux_labels: 186 | self.targets.extend([-1] * len(aux_data)) 187 | else: 188 | self.targets.extend(aux_targets) 189 | self.unsup_indices.extend(range(orig_len, orig_len+len(aux_data))) 190 | 191 | else: 192 | self.sup_indices = list(range(len(self.targets))) 193 | self.unsup_indices = [] 194 | 195 | def load_base_dataset(self, **kwargs): 196 | raise NotImplementedError() 197 | 198 | @property 199 | def data(self): 200 | return self.dataset.data 201 | 202 | @data.setter 203 | def data(self, value): 204 | self.dataset.data = value 205 | 206 | @property 207 | def targets(self): 208 | return self.dataset.labels 209 | 210 | @targets.setter 211 | def targets(self, value): 212 | self.dataset.labels = value 213 | 214 | def __len__(self): 215 | return len(self.dataset) 216 | 217 | def __getitem__(self, item): 218 | self.dataset.labels = self.targets 219 | return self.dataset[item] 220 | 221 | 222 | class SemiSupervisedSampler(torch.utils.data.Sampler): 223 | """ 224 | Balanced sampling from the labeled and unlabeled data. 
225 | """ 226 | def __init__(self, sup_inds, unsup_inds, batch_size, unsup_fraction=0.5, num_batches=None): 227 | if unsup_fraction is None or unsup_fraction < 0: 228 | self.sup_inds = sup_inds + unsup_inds 229 | unsup_fraction = 0.0 230 | else: 231 | self.sup_inds = sup_inds 232 | self.unsup_inds = unsup_inds 233 | 234 | self.batch_size = batch_size 235 | unsup_batch_size = int(batch_size * unsup_fraction) 236 | self.sup_batch_size = batch_size - unsup_batch_size 237 | 238 | if num_batches is not None: 239 | self.num_batches = num_batches 240 | else: 241 | self.num_batches = int(np.ceil(len(self.sup_inds) / self.sup_batch_size)) 242 | super().__init__(None) 243 | 244 | def __iter__(self): 245 | batch_counter = 0 246 | while batch_counter < self.num_batches: 247 | sup_inds_shuffled = [self.sup_inds[i] 248 | for i in torch.randperm(len(self.sup_inds))] 249 | for sup_k in range(0, len(self.sup_inds), self.sup_batch_size): 250 | if batch_counter == self.num_batches: 251 | break 252 | batch = sup_inds_shuffled[sup_k:(sup_k + self.sup_batch_size)] 253 | if self.sup_batch_size < self.batch_size: 254 | batch.extend([self.unsup_inds[i] for i in torch.randint(high=len(self.unsup_inds), 255 | size=(self.batch_size - len(batch),), 256 | dtype=torch.int64)]) 257 | np.random.shuffle(batch) 258 | yield batch 259 | batch_counter += 1 260 | 261 | def __len__(self): 262 | return self.num_batches -------------------------------------------------------------------------------- /data/svhn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | 6 | 7 | DATA_DESC = { 8 | 'data': 'svhn', 9 | 'classes': ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9'), 10 | 'num_classes': 10, 11 | 'mean': [0.4914, 0.4822, 0.4465], 12 | 'std': [0.2023, 0.1994, 0.2010], 13 | } 14 | 15 | 16 | def load_svhn(data_dir, use_augmentation='base'): 17 | """ 18 | Returns SVHN train, test datasets and dataloaders. 19 | Arguments: 20 | data_dir (str): path to data directory. 21 | use_augmentation (base/none): whether to use augmentations for training set. 22 | Returns: 23 | train dataset, test dataset. 24 | """ 25 | test_transform = transforms.Compose([transforms.ToTensor()]) 26 | train_transform = test_transform 27 | 28 | train_dataset = torchvision.datasets.SVHN(root=data_dir, split='train', download=True, transform=train_transform) 29 | test_dataset = torchvision.datasets.SVHN(root=data_dir, split='test', download=True, transform=test_transform) 30 | return train_dataset, test_dataset -------------------------------------------------------------------------------- /data/svhns.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | 6 | import re 7 | import numpy as np 8 | 9 | from .semisup import SemiSupervisedDataset, SemiSupervisedDatasetSVHN 10 | from .semisup import SemiSupervisedSampler 11 | 12 | 13 | def load_svhns(data_dir, use_augmentation='base', use_consistency=False, aux_take_amount=None, 14 | aux_data_filename='/cluster/scratch/rarade/svhns/ti_500K_pseudo_labeled.pickle', 15 | validation=False): 16 | """ 17 | Returns semisupervised SVHN train, test datasets and dataloaders (with Tiny Images). 18 | Arguments: 19 | data_dir (str): path to data directory. 20 | use_augmentation (base/none): whether to use augmentations for training set. 
21 | aux_take_amount (int): number of semi-supervised examples to use (if None, use all). 22 | aux_data_filename (str): path to additional data pickle file. 23 | Returns: 24 | train dataset, test dataset. 25 | """ 26 | data_dir = re.sub('svhns', 'svhn', data_dir) 27 | test_transform = transforms.Compose([transforms.ToTensor()]) 28 | train_transform = test_transform 29 | 30 | train_dataset = SemiSupervisedSVHN(base_dataset='svhn', root=data_dir, train=True, download=True, 31 | transform=train_transform, aux_data_filename=aux_data_filename, 32 | add_aux_labels=True, aux_take_amount=aux_take_amount, validation=validation) 33 | test_dataset = SemiSupervisedSVHN(base_dataset='svhn', root=data_dir, train=False, download=True, 34 | transform=test_transform) 35 | if validation: 36 | val_dataset = torchvision.datasets.SVHN(root=data_dir, split='train', download=True, transform=test_transform) 37 | val_dataset = torch.utils.data.Subset(val_dataset, np.arange(0, 1024)) 38 | return train_dataset, test_dataset, val_dataset 39 | return train_dataset, test_dataset 40 | 41 | 42 | class SemiSupervisedSVHN(SemiSupervisedDatasetSVHN): 43 | """ 44 | A dataset with auxiliary pseudo-labeled data for SVHN. 45 | """ 46 | def load_base_dataset(self, train=False, **kwargs): 47 | assert self.base_dataset == 'svhn', 'Only semi-supervised SVHN is supported. Please use correct dataset!' 48 | if train: 49 | self.dataset = torchvision.datasets.SVHN(split='train', **kwargs) 50 | else: 51 | self.dataset = torchvision.datasets.SVHN(split='test', **kwargs) 52 | self.dataset_size = len(self.dataset) 53 | -------------------------------------------------------------------------------- /data/tiny_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | from torchvision.datasets import ImageFolder 7 | 8 | 9 | DATA_DESC = { 10 | 'data': 'tiny-imagenet', 11 | 'classes': tuple(range(0, 200)), 12 | 'num_classes': 200, 13 | 'mean': [0.4802, 0.4481, 0.3975], 14 | 'std': [0.2302, 0.2265, 0.2262], 15 | } 16 | 17 | 18 | def load_tinyimagenet(data_dir, use_augmentation='base'): 19 | """ 20 | Returns Tiny Imagenet-200 train, test datasets and dataloaders. 21 | Arguments: 22 | data_dir (str): path to data directory. 23 | use_augmentation (base/none): whether to use augmentations for training set. 24 | Returns: 25 | train dataset, test dataset. 26 | """ 27 | test_transform = transforms.Compose([transforms.ToTensor()]) 28 | if use_augmentation == 'base': 29 | train_transform = transforms.Compose([transforms.RandomCrop(64, padding=4), transforms.RandomHorizontalFlip(), 30 | transforms.ToTensor()]) 31 | else: 32 | train_transform = test_transform 33 | 34 | train_dataset = ImageFolder(os.path.join(data_dir, 'train'), transform=train_transform) 35 | test_dataset = ImageFolder(os.path.join(data_dir, 'val'), transform=test_transform) 36 | 37 | return train_dataset, test_dataset 38 | -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or
 6 | # distribution of this software and related documentation without an express
 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
 8 | 
 9 | from .util import EasyDict, make_cache_dir_path
--------------------------------------------------------------------------------
/main_aug.py:
--------------------------------------------------------------------------------
 1 | # Generate samples with cDCGAN and StyleGAN2-ADA
 2 | # and store them in npz files
 3 | import argparse
 4 | import os
 5 | import numpy as np
 6 | 
 7 | import torch
 8 | from models.cdcgan import Generator as CDCGAN
 9 | 
10 | from utils import seed_torch
11 | from utils import legacy
12 | 
13 | from tqdm import tqdm
14 | 
15 | import warnings
16 | warnings.filterwarnings('ignore')
17 | 
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--data-dir', type=str, default='./datasets')
20 | parser.add_argument('--desc', type=str, default='cDCGAN', 
21 |                     help='Name of model. It will be used to name directories.')
22 | parser.add_argument('--sample_list', default=[1e5, 3e5, 5e5, 7e5, 1e6], type=float, nargs='+', 
23 |                     help='sample sizes (m_G) to generate, e.g. --sample_list 1e5 1e6')  # type=float (not list) so CLI values parse correctly
24 | parser.add_argument('--model', default='stylegan', type=str, choices=['CDCGAN', 'stylegan'], 
25 |                     help='pretrained generative model.')
26 | parser.add_argument('--model-path', default='', type=str, help='path of pretrained generative model.')
27 | parser.add_argument("--seed", type=int, default=0, help="seed")
28 | parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
29 | parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
30 | parser.add_argument("--channels", type=int, default=3, help="number of image channels")
31 | parser.add_argument("--width", type=int, default=128, help="number of feature maps")
32 | parser.add_argument("--batch_size", type=int, default=500, help="size of the batches")
33 | parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
34 | 
35 | def main():
36 |     args = parser.parse_args()
37 |     print(args)
38 | 
39 |     args.store_dir = os.path.join(args.data_dir, args.desc)
40 |     if os.path.exists(args.store_dir):
41 |         raise RuntimeError('Generated data already exists at this path, please check!')
42 |     os.makedirs(args.store_dir)
43 | 
44 |     cuda = True if torch.cuda.is_available() else False
45 |     args.device = 'cuda' if cuda else 'cpu'
46 |     seed_torch(args.seed)
47 |     torch.backends.cudnn.benchmark = False
48 |     torch.backends.cudnn.deterministic = True
49 | 
50 |     for sample_size in args.sample_list:
51 |         if args.model == 'CDCGAN':
52 |             sample_CDCGAN(sample_size, args)
53 |         elif args.model == 'stylegan':
54 |             sample_stylegan(sample_size, args)
55 |         else:
56 |             raise NotImplementedError('no such model!')
57 | 
58 | def sample_CDCGAN(sample_size, args):
59 |     model = CDCGAN(
60 |         n_classes=args.n_classes,
61 |         latent_dim=args.latent_dim,
62 |         channels=args.channels,
63 |         width=args.width
64 |     ).to(args.device)
65 | 
66 |     state_dict = torch.load(args.model_path)
67 |     model.load_state_dict(state_dict)
68 |     for _, param in model.named_parameters():
69 |         param.requires_grad = False
70 | 
71 |     # label preprocess
72 |     onehot = torch.zeros(args.n_classes, args.n_classes)
73 |     onehot = onehot.scatter_(1, torch.LongTensor(list(range(args.n_classes))).view(args.n_classes,1), 1)
74 |     onehot = onehot.view(args.n_classes, args.n_classes, 1, 1)
75 | 
76 |     sample_each_class = int(sample_size / args.n_classes)
77 |     iter_each_class = 
int(sample_each_class / args.batch_size) 78 | 79 | total_imgs = [] 80 | total_labels = [] 81 | 82 | for c in tqdm(range(args.n_classes)): 83 | for i in range(iter_each_class): 84 | 85 | z = torch.randn((args.batch_size, args.latent_dim)).to(args.device).float() 86 | z = z.view(-1, args.latent_dim, 1, 1) 87 | 88 | gen_labels = np.array([c] * args.batch_size) 89 | gen_labels_onehot = onehot[gen_labels].to(args.device) 90 | 91 | with torch.no_grad(): 92 | gen_imgs = model(z, gen_labels_onehot) 93 | gen_imgs = gen_imgs * 0.5 + 0.5 94 | gen_imgs = gen_imgs.mul(255).add_(0.5).clamp_(0, 255).permute(0, 2, 3, 1) 95 | gen_imgs = gen_imgs.detach().to("cpu", torch.uint8).numpy() 96 | 97 | total_imgs.append(gen_imgs) 98 | total_labels.append(gen_labels) 99 | 100 | total_imgs = np.concatenate(total_imgs) 101 | total_labels = np.concatenate(total_labels) 102 | 103 | permutation = np.random.permutation(total_labels.shape[0]) 104 | total_imgs = total_imgs[permutation] 105 | total_labels = total_labels[permutation] 106 | 107 | file_name = str(int(sample_size / 1000)) + 'k' 108 | file_path = os.path.join(args.store_dir, file_name+'.npz') 109 | np.savez(file_path, image=total_imgs, label=total_labels) 110 | 111 | def sample_stylegan(sample_size, args): 112 | f = open(args.model_path, 'rb') 113 | model = legacy.load_network_pkl(f)['G_ema'].to(args.device) # type: ignore 114 | for _, param in model.named_parameters(): 115 | param.requires_grad = False 116 | # print(model) 117 | 118 | sample_each_class = int(sample_size / args.n_classes) 119 | iter_each_class = int(sample_each_class / args.batch_size) 120 | 121 | total_imgs = [] 122 | total_labels = [] 123 | 124 | for c in tqdm(range(args.n_classes)): 125 | for i in range(iter_each_class): 126 | # Labels. 127 | gen_labels = np.array([c] * args.batch_size) 128 | label = torch.zeros([args.batch_size, model.c_dim], device=args.device) 129 | label[:, c] = 1 130 | 131 | z = torch.randn(args.batch_size, model.z_dim).to(args.device) 132 | img = model(z, label, truncation_psi=1, noise_mode='random') 133 | img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) 134 | img = img.detach().to("cpu", torch.uint8).numpy() 135 | 136 | total_imgs.append(img) 137 | total_labels.append(gen_labels) 138 | 139 | total_imgs = np.concatenate(total_imgs) 140 | total_labels = np.concatenate(total_labels) 141 | 142 | permutation = np.random.permutation(total_labels.shape[0]) 143 | total_imgs = total_imgs[permutation] 144 | total_labels = total_labels[permutation] 145 | 146 | file_name = str(int(sample_size / 1000)) + 'k' 147 | file_path = os.path.join(args.store_dir, file_name+'.npz') 148 | np.savez(file_path, image=total_imgs, label=total_labels) 149 | 150 | 151 | if __name__ == '__main__': 152 | main() -------------------------------------------------------------------------------- /main_train_CDCGAN.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 4 | import shutil 5 | import json 6 | import time 7 | 8 | import numpy as np 9 | 10 | from models.cdcgan import Generator, Discriminator 11 | 12 | import torchvision.transforms as transforms 13 | 14 | from torch.utils.data import DataLoader 15 | from torchvision import datasets 16 | from torchvision.utils import save_image 17 | 18 | import torch 19 | 20 | from utils import Logger, seed_torch, format_time, weights_init_normal 21 | 22 | def sample_image_grid(generator, z, n_row, batches_done, img_dir): 23 | 
"""Saves a grid of generated digits ranging from 0 to n_classes""" 24 | # Get labels ranging from 0 to n_classes for n rows 25 | labels = np.array([num for _ in range(n_row) for num in range(n_row)]) 26 | y_label = onehot[labels].to(device) 27 | # generator.eval() 28 | with torch.no_grad(): 29 | gen_imgs = generator(z, y_label) 30 | # generator.train() 31 | gen_imgs = gen_imgs * 0.5 + 0.5 32 | image_name = "%d.png" % batches_done 33 | image_path = os.path.join(img_dir, image_name) 34 | save_image(gen_imgs.data, image_path, nrow=n_row, normalize=True) 35 | 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--data_dir', type=str, default='./data/cifar') 38 | parser.add_argument('--log-dir', type=str, default='./log') 39 | parser.add_argument('--desc', type=str, default='cDCGAN', 40 | help='Description of experiment. It will be used to name directories.') 41 | parser.add_argument("--seed", type=int, default=0, help="seed") 42 | parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training") 43 | parser.add_argument("--batch_size", type=int, default=64, help="size of the batches") 44 | parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate") 45 | parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient") 46 | parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient") 47 | parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space") 48 | parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset") 49 | parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension") 50 | parser.add_argument("--channels", type=int, default=3, help="number of image channels") 51 | parser.add_argument("--width", type=int, default=128, help="number of feature maps") 52 | parser.add_argument("--sample_interval", type=int, default=10, help="interval between image sampling") 53 | args = parser.parse_args() 54 | print(args) 55 | 56 | LOG_DIR = os.path.join(args.log_dir, args.desc) 57 | IMAGE_DIR = os.path.join(LOG_DIR, "images") 58 | D_WEIGHTS = os.path.join(LOG_DIR, 'D-weights-last.pt') 59 | G_WEIGHTS = os.path.join(LOG_DIR, 'G-weights-last.pt') 60 | 61 | if os.path.exists(LOG_DIR): 62 | shutil.rmtree(LOG_DIR) 63 | os.makedirs(LOG_DIR) 64 | os.makedirs(IMAGE_DIR) 65 | logger = Logger(os.path.join(LOG_DIR, 'log-train.log')) 66 | 67 | with open(os.path.join(LOG_DIR, 'args.txt'), 'w') as f: 68 | json.dump(args.__dict__, f, indent=4) 69 | 70 | 71 | cuda = True if torch.cuda.is_available() else False 72 | device = 'cuda' if cuda else 'cup' 73 | logger.log('If use cuda: {}'.format(cuda)) 74 | 75 | seed_torch(args.seed) 76 | 77 | fixed_z = torch.randn((args.n_classes ** 2, args.latent_dim)).to(device).float() 78 | fixed_z = fixed_z.view((-1, args.latent_dim, 1, 1)) 79 | 80 | # Loss functions 81 | adversarial_loss = torch.nn.BCELoss().to(device) 82 | 83 | # Initialize generator and discriminator 84 | generator = Generator( 85 | n_classes=args.n_classes, 86 | latent_dim=args.latent_dim, 87 | channels=args.channels, 88 | width=args.width 89 | ).to(device) 90 | 91 | discriminator = Discriminator( 92 | n_classes=args.n_classes, 93 | channels=args.channels, 94 | width=args.width 95 | ).to(device) 96 | 97 | # Initialize weights 98 | generator.apply(weights_init_normal) 99 | discriminator.apply(weights_init_normal) 100 | 101 | # Configure data loader 102 | 
train_transform = transforms.Compose([ 103 | transforms.ToTensor(), 104 | transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5]) 105 | ]) 106 | dataloader = DataLoader( 107 | datasets.CIFAR10( 108 | args.data_dir, 109 | train=True, 110 | download=True, 111 | transform=train_transform 112 | ), 113 | batch_size=args.batch_size, 114 | shuffle=True, 115 | ) 116 | print(len(dataloader)) 117 | 118 | # Optimizers 119 | optimizer_G = torch.optim.Adam(generator.parameters(), lr=args.lr, betas=(args.b1, args.b2)) 120 | optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.b1, args.b2)) 121 | 122 | # label preprocess 123 | onehot = torch.zeros(10, 10) 124 | onehot = onehot.scatter_(1, torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).view(10,1), 1).view(10, 10, 1, 1) 125 | fill = torch.zeros([10, 10, args.img_size, args.img_size]) 126 | for i in range(args.n_classes): 127 | fill[i, i, :, :] = 1 128 | 129 | # ---------- 130 | # Training 131 | # ---------- 132 | 133 | logger.log('Standard training for {} epochs'.format(args.n_epochs)) 134 | 135 | for epoch in range(args.n_epochs): 136 | start = time.time() 137 | logger.log('======= Epoch {} ======='.format(epoch+1)) 138 | 139 | d_loss_avg = 0 140 | g_loss_avg = 0 141 | 142 | for i, (imgs, labels) in enumerate(dataloader): 143 | imgs, labels = imgs.to(device), labels.to(device) 144 | 145 | batch_size = imgs.shape[0] 146 | 147 | # Adversarial ground truths 148 | valid = torch.ones((batch_size), requires_grad=False).to(device) 149 | fake = torch.zeros((batch_size), requires_grad=False).to(device) 150 | 151 | # ----------------- 152 | # Train Generator 153 | # ----------------- 154 | 155 | optimizer_G.zero_grad() 156 | 157 | # Sample noise and labels as generator input 158 | z = torch.randn((batch_size, args.latent_dim)).to(device).float() 159 | z = z.view(-1, args.latent_dim, 1, 1) 160 | gen_labels = np.random.randint(0, args.n_classes, batch_size) 161 | gen_labels_onehot = onehot[gen_labels].to(device) 162 | gen_labels_fill = fill[gen_labels].to(device) 163 | 164 | # Generate a batch of images 165 | gen_imgs = generator(z, gen_labels_onehot) 166 | 167 | # Loss measures generator's ability to fool the discriminator 168 | validity = discriminator(gen_imgs, gen_labels_fill).squeeze() 169 | g_loss = adversarial_loss(validity, valid) 170 | g_loss_avg += g_loss.item() 171 | 172 | g_loss.backward() 173 | optimizer_G.step() 174 | 175 | # --------------------- 176 | # Train Discriminator 177 | # --------------------- 178 | 179 | optimizer_D.zero_grad() 180 | 181 | labels_fill = fill[labels].to(device) 182 | # Loss for real images 183 | real_pred = discriminator(imgs, labels_fill).squeeze() 184 | d_real_loss = adversarial_loss(real_pred, valid) 185 | 186 | # Loss for fake images 187 | fake_pred = discriminator(gen_imgs.detach(), gen_labels_fill).squeeze() 188 | d_fake_loss = adversarial_loss(fake_pred, fake) 189 | 190 | # Total discriminator loss 191 | d_loss = (d_real_loss + d_fake_loss) / 2 192 | d_loss_avg += d_loss.item() 193 | 194 | d_loss.backward() 195 | optimizer_D.step() 196 | 197 | d_loss_avg /= len(dataloader) 198 | g_loss_avg /= len(dataloader) 199 | 200 | logger.log( 201 | "[Epoch %d/%d] [D loss: %.3f] [G loss: %.3f]" 202 | % (epoch + 1, args.n_epochs, d_loss_avg, g_loss_avg) 203 | ) 204 | logger.log('Time taken: {}'.format(format_time(time.time()-start))) 205 | 206 | if (epoch + 1) % args.sample_interval == 0: 207 | sample_image_grid(generator, fixed_z, n_row=10, batches_done=epoch+1, img_dir=IMAGE_DIR) 208 | 
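        # The G/D checkpoints below use fixed filenames ('G-weights-last.pt',
        # 'D-weights-last.pt'), so each sampling interval overwrites the previous
        # weights and only the latest checkpoint is kept on disk.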
torch.save(generator.state_dict(), G_WEIGHTS) 209 | torch.save(discriminator.state_dict(), D_WEIGHTS) 210 | 211 | logger.log('\nTraining completed.') 212 | logger.log('Script Completed.') -------------------------------------------------------------------------------- /main_train_aug.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | import argparse 4 | import shutil 5 | 6 | from tqdm import tqdm 7 | 8 | import os 9 | import numpy as np 10 | import pandas as pd 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | from data import get_data_info 16 | from data import load_data 17 | from data import SEMISUP_DATASETS 18 | 19 | from utils import format_time 20 | from utils import Logger 21 | from utils import parser_train 22 | from utils import Trainer 23 | from utils import seed_torch 24 | 25 | from warnings import simplefilter 26 | simplefilter(action='ignore', category=FutureWarning) 27 | 28 | # Setup 29 | 30 | parse = parser_train() 31 | args = parse.parse_args() 32 | assert args.data in SEMISUP_DATASETS, f'Only data in {SEMISUP_DATASETS} is supported!' 33 | 34 | 35 | DATA_DIR = os.path.join(args.data_dir, args.data) 36 | LOG_DIR = os.path.join(args.log_dir, args.desc) 37 | WEIGHTS = os.path.join(LOG_DIR, 'weights-best.pt') 38 | if os.path.exists(LOG_DIR): 39 | shutil.rmtree(LOG_DIR) 40 | os.makedirs(LOG_DIR) 41 | logger = Logger(os.path.join(LOG_DIR, 'log-train.log')) 42 | 43 | with open(os.path.join(LOG_DIR, 'args.txt'), 'w') as f: 44 | json.dump(args.__dict__, f, indent=4) 45 | 46 | 47 | info = get_data_info(DATA_DIR) 48 | BATCH_SIZE = args.batch_size 49 | BATCH_SIZE_VALIDATION = args.batch_size_validation 50 | NUM_EPOCHS = args.num_epochs 51 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 52 | logger.log('Using device: {}'.format(device)) 53 | if args.debug: 54 | NUM_EPOCHS = 1 55 | 56 | # To speed up training and fix random seed 57 | seed_torch(args.seed) 58 | 59 | 60 | # Load data 61 | train_dataset, test_dataset, train_dataloader, test_dataloader = load_data( 62 | DATA_DIR, BATCH_SIZE, BATCH_SIZE_VALIDATION, use_augmentation=args.augment, use_consistency=args.consistency, shuffle_train=True, 63 | aux_data_filename=args.aux_data_filename, unsup_fraction=args.unsup_fraction, aux_take_amount=args.aux_take_amount, validation=False 64 | ) 65 | print(len(train_dataset)) 66 | del train_dataset, test_dataset 67 | 68 | trainer = Trainer(info, args) 69 | last_lr = args.lr 70 | 71 | 72 | if NUM_EPOCHS > 0: 73 | logger.log('\n\n') 74 | metrics = pd.DataFrame() 75 | logger.log('Standard Accuracy-\tEval: {:2f}%.'.format(trainer.eval(test_dataloader)[1]*100)) 76 | 77 | old_score = 0.0 78 | logger.log('Standard training for {} epochs'.format(NUM_EPOCHS)) 79 | trainer.init_optimizer(args.num_epochs) 80 | eval_acc = 0.0 81 | 82 | if args.resume_path: 83 | start_epoch = trainer.load_model_resume(os.path.join(args.resume_path, 'state-last.pt')) + 1 84 | logger.log(f'Resuming at epoch {start_epoch}') 85 | else: 86 | start_epoch = 1 87 | 88 | for epoch in tqdm(range(start_epoch, NUM_EPOCHS+1)): 89 | start = time.time() 90 | logger.log('======= Epoch {} ======='.format(epoch)) 91 | 92 | if args.scheduler: 93 | last_lr = trainer.scheduler.get_last_lr()[0] 94 | 95 | res = trainer.train(train_dataloader, epoch=epoch, adversarial=False) 96 | 97 | logger.log('Loss: {:.4f}.\tLR: {:.4f}'.format(res['loss'], last_lr)) 98 | logger.log('Mean Accuracy-\tTrain: {:.2f}%.'.format(res['clean_acc']*100)) 99 | 100 | 
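    # Per-epoch statistics are collected in epoch_metrics and flushed to stats.csv
    # at the end of every epoch; eval metrics are only added every args.eval_freq epochs.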
epoch_metrics = {'train_'+k: v for k, v in res.items()} 101 | # epoch_metrics.update({'epoch': epoch, 'lr': last_lr, 'test_clean_acc': test_acc, 'test_adversarial_acc': ''}) 102 | epoch_metrics.update({'epoch': epoch, 'lr': last_lr}) 103 | 104 | if epoch % args.eval_freq == 0 or epoch == NUM_EPOCHS: 105 | train_loss, train_acc = trainer.eval(train_dataloader, adversarial=False) 106 | eval_loss, eval_acc = trainer.eval(test_dataloader, adversarial=False) 107 | logger.log('Loss-\tTrain: {:.4f}.\tTest: {:.4f}.\tGap: {:.4f}.'.format(train_loss, eval_loss, np.abs(train_loss-eval_loss))) 108 | logger.log('Accuracy-\tTrain: {:.2f}%.\tTest: {:.2f}%.\tGap: {:.2f}%.'.format(train_acc*100, eval_acc*100, np.abs(train_acc*100-eval_acc*100))) 109 | epoch_metrics.update({'train_acc': train_acc*100}) 110 | epoch_metrics.update({'eval_acc': eval_acc*100}) 111 | epoch_metrics.update({'train_loss2': train_loss}) 112 | epoch_metrics.update({'eval_loss2': eval_loss}) 113 | epoch_metrics.update({'acc_gap': np.abs(train_acc*100-eval_acc*100)}) 114 | epoch_metrics.update({'loss_gap': np.abs(train_loss-eval_loss)}) 115 | 116 | trainer.save_model(os.path.join(LOG_DIR, 'state-last.pt')) 117 | 118 | 119 | if eval_acc > old_score: 120 | old_score = eval_acc 121 | trainer.save_model(WEIGHTS) 122 | 123 | if epoch % NUM_EPOCHS == 0: 124 | shutil.copyfile(WEIGHTS, os.path.join(LOG_DIR, f'weights-best-epoch{str(epoch)}.pt')) 125 | 126 | logger.log('Time taken: {}'.format(format_time(time.time()-start))) 127 | metrics = pd.concat([metrics, pd.DataFrame(epoch_metrics, index=[0])], ignore_index=True) 128 | metrics.to_csv(os.path.join(LOG_DIR, 'stats.csv'), index=False) 129 | 130 | 131 | 132 | # Record metrics 133 | logger.log('\nTraining completed.') 134 | logger.log('Standard Accuracy-\tBest Test: {:.2f}%.'.format(old_score*100)) 135 | logger.log('Script Completed.') -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .resnet import Normalization 4 | from .preact_resnet import preact_resnet 5 | from .resnet import resnet 6 | from .wideresnet import wideresnet 7 | 8 | from .preact_resnetwithswish import preact_resnetwithswish 9 | from .wideresnetwithswish import wideresnetwithswish 10 | 11 | from data import DATASETS 12 | 13 | 14 | MODELS = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 15 | 'preact-resnet18', 'preact-resnet34', 'preact-resnet50', 'preact-resnet101', 16 | 'wrn-28-10', 'wrn-32-10', 'wrn-34-10', 'wrn-34-20', 17 | 'preact-resnet18-swish', 'preact-resnet34-swish', 18 | 'wrn-28-10-swish', 'wrn-34-20-swish', 'wrn-70-16-swish'] 19 | 20 | 21 | def create_model(name, normalize, info, device): 22 | """ 23 | Returns suitable model from its name. 24 | Arguments: 25 | name (str): name of resnet architecture. 26 | normalize (bool): normalize input. 27 | info (dict): dataset information. 28 | device (str or torch.device): device to work on. 29 | Returns: 30 | torch.nn.Module. 31 | """ 32 | if info['data'] in ['tiny-imagenet']: 33 | assert 'preact-resnet' in name, 'Only preact-resnets are supported for this dataset!'
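# Imported lazily: only this branch needs the Tiny-ImageNet variant, which
# replaces the fixed 4x4 average pooling with adaptive pooling for 64x64 inputs.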
34 | from .ti_preact_resnet import ti_preact_resnet 35 | backbone = ti_preact_resnet(name, num_classes=info['num_classes'], device=device) 36 | 37 | elif info['data'] in DATASETS and info['data'] not in ['tiny-imagenet']: 38 | if 'preact-resnet' in name and 'swish' not in name: 39 | backbone = preact_resnet(name, num_classes=info['num_classes'], pretrained=False, device=device) 40 | elif 'preact-resnet' in name and 'swish' in name: 41 | backbone = preact_resnetwithswish(name, dataset=info['data'], num_classes=info['num_classes']) 42 | elif 'resnet' in name and 'preact' not in name: 43 | backbone = resnet(name, num_classes=info['num_classes'], pretrained=False, device=device) 44 | elif 'wrn' in name and 'swish' not in name: 45 | backbone = wideresnet(name, num_classes=info['num_classes'], device=device) 46 | elif 'wrn' in name and 'swish' in name: 47 | backbone = wideresnetwithswish(name, dataset=info['data'], num_classes=info['num_classes'], device=device) 48 | else: 49 | raise ValueError('Invalid model name {}!'.format(name)) 50 | 51 | else: 52 | raise ValueError('Models for {} not yet supported!'.format(info['data'])) 53 | 54 | if normalize: 55 | model = torch.nn.Sequential(Normalization(info['mean'], info['std']), backbone) 56 | else: 57 | model = torch.nn.Sequential(backbone) 58 | 59 | model = torch.nn.DataParallel(model) 60 | model = model.to(device) 61 | return model 62 | -------------------------------------------------------------------------------- /models/cdcgan.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | # Generator Code 5 | 6 | class Generator(nn.Module): 7 | def __init__(self, n_classes=10, latent_dim=100, channels=3, width=64): 8 | super(Generator, self).__init__() 9 | 10 | 11 | self.layer_z = nn.Sequential( 12 | # input is Z, going into a convolution 13 | nn.ConvTranspose2d(latent_dim, width * 2, 4, 1, 0, bias=False), 14 | nn.BatchNorm2d(width * 2), 15 | nn.ReLU(True) 16 | ) 17 | 18 | self.layer_label = nn.Sequential( 19 | # input is label, going into a convolution 20 | nn.ConvTranspose2d(n_classes, width * 2, 4, 1, 0, bias=False), 21 | nn.BatchNorm2d(width * 2), 22 | nn.ReLU(True) 23 | ) 24 | 25 | self.layer = nn.Sequential( 26 | # state size. ``(ngf*4) x 4 x 4`` 27 | nn.ConvTranspose2d(width * 4, width * 2, 4, 2, 1, bias=False), 28 | nn.BatchNorm2d(width * 2), 29 | nn.ReLU(True), 30 | # state size. ``(ngf*2) x 8 x 8`` 31 | nn.ConvTranspose2d(width * 2, width, 4, 2, 1, bias=False), 32 | nn.BatchNorm2d(width), 33 | nn.ReLU(True), 34 | # state size. ``(ngf) x 16 x 16`` 35 | nn.ConvTranspose2d( width, channels, 4, 2, 1, bias=False), 36 | nn.Tanh() 37 | # state size. 
``(nc) x 32 x 32`` 38 | ) 39 | 40 | def forward(self, noise, labels): 41 | x = self.layer_z(noise) 42 | # y = self.label_emb(labels) 43 | # y = y.view(y.shape[0], -1, 1, 1) 44 | y = self.layer_label(labels) 45 | x = torch.cat([x, y], 1) 46 | return self.layer(x) 47 | 48 | 49 | class Discriminator(nn.Module): 50 | def __init__(self, n_classes=10, channels=3, width=64): 51 | super(Discriminator, self).__init__() 52 | 53 | self.layer_img = nn.Sequential( 54 | # input is img, going into a convolution 55 | nn.Conv2d(channels, width // 2, 4, 2, 1, bias=False), 56 | nn.LeakyReLU(0.2, inplace=True) 57 | ) 58 | 59 | self.layer_label = nn.Sequential( 60 | # input is label, going into a convolution 61 | nn.Conv2d(n_classes, width // 2, 4, 2, 1, bias=False), 62 | nn.LeakyReLU(0.2, inplace=True) 63 | ) 64 | 65 | self.layer = nn.Sequential( 66 | # state size. ``(ndf) x 16 x 16`` 67 | nn.Conv2d(width, width * 2, 4, 2, 1, bias=False), 68 | nn.BatchNorm2d(width * 2), 69 | nn.LeakyReLU(0.2, inplace=True), 70 | # state size. ``(ndf*2) x 8 x 8`` 71 | nn.Conv2d(width * 2, width * 4, 4, 2, 1, bias=False), 72 | nn.BatchNorm2d(width * 4), 73 | nn.LeakyReLU(0.2, inplace=True), 74 | # state size. ``(ndf*4) x 4 x 4`` 75 | nn.Conv2d(width * 4, 1, 4, 1, 0, bias=False), 76 | nn.Sigmoid() 77 | ) 78 | 79 | def forward(self, img, labels): 80 | x = self.layer_img(img) 81 | y = self.layer_label(labels) 82 | x = torch.cat([x, y], 1) 83 | return self.layer(x) -------------------------------------------------------------------------------- /models/gaussnb.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import pandas as pd 4 | from math import pi 5 | import sys 6 | import time 7 | 8 | class GaussianNB(): 9 | 10 | def __init__(self, val_epsilon=1e-9): 11 | self.val_epsilon = val_epsilon 12 | self.features = None 13 | self.K = None 14 | self.mus = None 15 | self.vars_groupby = None 16 | self.prior = None 17 | self.counts = None 18 | 19 | def get_vars(self, samples:pd.DataFrame): 20 | samples = samples.iloc[:,:-1] 21 | vars = np.var(samples, axis=0, ddof=1) + self.val_epsilon 22 | return vars 23 | 24 | def fit(self, X_train:np.ndarray, y_train:np.ndarray): 25 | X_train = pd.DataFrame(X_train) 26 | y_train = pd.DataFrame(y_train) 27 | if self.features is None: 28 | self.features = X_train.shape[1] 29 | if self.K is None: 30 | self.K = len(set(y_train.values.squeeze())) 31 | # print(self.K) 32 | 33 | train_set = pd.concat((X_train, y_train), axis=1, ignore_index=True) 34 | 35 | if self.vars_groupby is None and self.mus is None and self.prior is None and self.counts is None: 36 | grouped = train_set.groupby([train_set.shape[1]-1]) 37 | self.mus = grouped.agg('mean').values 38 | self.counts = grouped.agg('count').loc[:,0].values 39 | self.prior = self.counts / train_set.shape[0] 40 | self.vars_groupby = grouped.apply(self.get_vars).values 41 | 42 | else: 43 | grouped = train_set.groupby([train_set.shape[1]-1]) 44 | mus_new = grouped.agg('mean').values 45 | counts_new = grouped.agg('count').loc[:,0].values 46 | vars_groupby_new = grouped.apply(self.get_vars).values 47 | 48 | counts_total = self.counts + counts_new 49 | mus_total = (self.counts.reshape((-1,1)) * self.mus + counts_new.reshape((-1,1)) * mus_new) / counts_total.reshape((-1,1)) 50 | 51 | old_ssd = self.counts.reshape((-1,1)) * self.vars_groupby 52 | new_ssd = counts_new.reshape((-1,1)) * vars_groupby_new 53 | total_ssd = old_ssd + new_ssd + (counts_new * self.counts / 
counts_total).reshape((-1,1)) * (self.mus - mus_new) ** 2 54 | total_var_groupby = total_ssd / counts_total.reshape((-1,1)) 55 | 56 | self.mus = mus_total 57 | self.vars_groupby = total_var_groupby 58 | self.counts = counts_total 59 | self.prior = self.counts / np.sum(self.counts) 60 | 61 | 62 | def predict_gaussian_likehood(self, X_test: np.ndarray): 63 | vars = np.sum(self.vars_groupby * self.prior.reshape(-1,1), axis=0) 64 | joint_log_likelihood = [] 65 | # n_ij = 0 66 | for i in range(self.K): 67 | n_ij = -0.5 * np.sum(((X_test - self.mus[i, :]) ** 2) / (vars.reshape(1,-1)) + np.log(2*pi*vars.reshape(1,-1)), 1) 68 | joint_log_likelihood.append(n_ij) 69 | joint_log_likelihood = np.array(joint_log_likelihood).T 70 | 71 | return joint_log_likelihood 72 | 73 | def predict(self, X_test: np.ndarray): 74 | vars = np.sum(self.vars_groupby * self.prior.reshape(-1,1), axis=0) 75 | joint_log_likelihood = [] 76 | # n_ij = 0 77 | for i in range(self.K): 78 | jointi = np.log(self.prior[i]) 79 | n_ij = -0.5 * np.sum(((X_test - self.mus[i, :]) ** 2) / (vars.reshape(1,-1)), 1) 80 | joint_log_likelihood.append(jointi + n_ij) 81 | joint_log_likelihood = np.array(joint_log_likelihood).T 82 | 83 | return np.argmax(joint_log_likelihood, axis=1) 84 | 85 | def score(self, X_test, y_test): 86 | preds = self.predict(X_test) 87 | # print(preds) 88 | score = np.sum(preds == y_test) / len(y_test) 89 | return score 90 | 91 | 92 | class GaussianDA(): 93 | 94 | def __init__(self, val_epsilon=1e-9): 95 | self.val_epsilon = val_epsilon 96 | self.features = None 97 | self.K = None 98 | self.mus = None 99 | self.prior = None 100 | self.counts = None 101 | self.first_time = True 102 | self.covs = None 103 | 104 | def get_covs(self, samples:pd.DataFrame): 105 | samples = samples.iloc[:,:-1] 106 | covs = np.cov(samples, rowvar=False, ddof=0) 107 | return covs 108 | 109 | def fit(self, X_train:np.ndarray, y_train:np.ndarray): 110 | X_train = pd.DataFrame(X_train) 111 | y_train = pd.DataFrame(y_train) 112 | if self.features is None: 113 | self.features = X_train.shape[1] 114 | if self.K is None: 115 | self.K = len(set(y_train.values.squeeze())) 116 | # print(self.K) 117 | 118 | train_set = pd.concat((X_train, y_train), axis=1, ignore_index=True) 119 | 120 | if self.first_time == True: 121 | self.first_time = False 122 | grouped = train_set.groupby([train_set.shape[1]-1]) 123 | 124 | 125 | self.mus = grouped.agg('mean').values # K * n 126 | self.counts = grouped.agg('count').loc[:,0].values 127 | self.prior = self.counts / train_set.shape[0] 128 | covs_groupby = grouped.apply(self.get_covs).values 129 | 130 | self.covs = 0 131 | for k in range(self.prior.shape[0]): 132 | self.covs += covs_groupby[k] * self.prior[k] 133 | self.covs += np.eye(self.features) * self.val_epsilon 134 | 135 | 136 | else: 137 | print('false. 
no implementation') 138 | 139 | def predict(self, X_test: np.ndarray): 140 | joint_log_likelihood = [] 141 | inv = np.linalg.inv(self.covs) 142 | # n_ij = 0 143 | for i in range(self.K): 144 | jointi = np.log(self.prior[i]) 145 | # m * n * n * n * n * m = m * m --> m 146 | n_ij = -0.5 * np.diag((X_test - self.mus[i,:]) @ inv @ (X_test - self.mus[i,:]).T) 147 | joint_log_likelihood.append(jointi + n_ij) 148 | joint_log_likelihood = np.array(joint_log_likelihood).T 149 | return np.argmax(joint_log_likelihood, axis=1) 150 | 151 | def score(self, X_test, y_test): 152 | preds = self.predict(X_test) 153 | # print(preds) 154 | score = np.sum(preds == y_test) / len(y_test) 155 | return score 156 | 157 | class GaussianNB_puls_low_rank(): 158 | 159 | def __init__(self, val_epsilon=1e-9, rank=1): 160 | self.val_epsilon = val_epsilon 161 | self.features = None 162 | self.K = None 163 | self.mus = None 164 | self.prior = None 165 | self.counts = None 166 | self.first_time = True 167 | self.covs = None 168 | self.rank = rank 169 | 170 | def get_covs(self, samples:pd.DataFrame): 171 | samples = samples.iloc[:,:-1] 172 | covs = np.cov(samples, rowvar=False) + np.eye(samples.shape[1]) * self.val_epsilon 173 | return covs 174 | 175 | def get_vars(self, samples:pd.DataFrame): 176 | samples = samples.iloc[:,:-1] 177 | vars = np.var(samples, axis=0) + self.val_epsilon 178 | return vars 179 | 180 | def fit(self, X_train:np.ndarray, y_train:np.ndarray): 181 | X_train = pd.DataFrame(X_train) 182 | y_train = pd.DataFrame(y_train) 183 | if self.features is None: 184 | self.features = X_train.shape[1] 185 | if self.K is None: 186 | self.K = len(set(y_train.values.squeeze())) 187 | # print(self.K) 188 | 189 | train_set = pd.concat((X_train, y_train), axis=1, ignore_index=True) 190 | 191 | if self.first_time == True: 192 | self.first_time = False 193 | grouped = train_set.groupby([train_set.shape[1]-1]) 194 | 195 | # for group in grouped: 196 | # print(group) 197 | 198 | self.mus = grouped.agg('mean').values 199 | self.counts = grouped.agg('count').loc[:,0].values 200 | self.prior = self.counts / train_set.shape[0] 201 | 202 | covs_groupby = grouped.apply(self.get_covs).values 203 | covs = 0 204 | for k in range(self.prior.shape[0]): 205 | covs += covs_groupby[k] * self.prior[k] 206 | 207 | vars_groupby = grouped.apply(self.get_vars).values 208 | vars = np.sum(vars_groupby * self.prior.reshape(-1,1), axis=0) 209 | vars = np.diag(vars) 210 | 211 | low_rank = covs - vars 212 | u, s, vh = np.linalg.svd(low_rank) 213 | s[self.rank:] = 0 214 | low_rank = u @ np.diag(s) @ vh 215 | 216 | self.covs = vars + low_rank 217 | 218 | 219 | else: 220 | print('false. 
no implementation') 221 | 222 | def predict(self, X_test: np.ndarray): 223 | joint_log_likelihood = [] 224 | inv = np.linalg.inv(self.covs) 225 | # n_ij = 0 226 | for i in range(self.K): 227 | jointi = np.log(self.prior[i]) 228 | n_ij = -0.5 * np.diag((X_test - self.mus[i,:]) @ inv @ (X_test - self.mus[i,:]).T) 229 | joint_log_likelihood.append(jointi + n_ij) 230 | joint_log_likelihood = np.array(joint_log_likelihood).T 231 | return np.argmax(joint_log_likelihood, axis=1) 232 | 233 | def score(self, X_test, y_test): 234 | preds = self.predict(X_test) 235 | # print(preds) 236 | score = np.sum(preds == y_test) / len(y_test) 237 | return score 238 | 239 | def main(): 240 | from sklearn.datasets import make_blobs, make_classification 241 | from sklearn.model_selection import train_test_split 242 | features = 5 243 | X, y = make_blobs(n_samples=100000, n_features=features, centers=[features*[-1], features*[1]], 244 | cluster_std=[np.sqrt(features), 0.01*np.sqrt(features)], random_state=0) 245 | X_train, X_test, y_train, y_test = train_test_split(X,y, train_size=90000) 246 | 247 | X_train_1, X_train_2, y_train_1, y_train_2 = train_test_split(X_train, y_train, train_size=0.5) 248 | 249 | t1 = time.time() 250 | nb = GaussianNB() 251 | 252 | 253 | nb.fit(X_train_1, y_train_1) 254 | print(nb.vars_groupby) 255 | nb.fit(X_train_2, y_train_2) 256 | print(nb.vars_groupby) 257 | print(nb.score(X_test, y_test)) 258 | 259 | if __name__ == '__main__': 260 | main() -------------------------------------------------------------------------------- /models/preact_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class PreActBlock(nn.Module): 7 | """ 8 | Pre-activation version of the BasicBlock for Resnets. 9 | Arguments: 10 | in_planes (int): number of input planes. 11 | planes (int): number of output filters. 12 | stride (int): stride of convolution. 13 | """ 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1): 17 | super(PreActBlock, self).__init__() 18 | self.bn1 = nn.BatchNorm2d(in_planes) 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 26 | ) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.bn1(x)) 30 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 31 | out = self.conv1(out) 32 | out = self.conv2(F.relu(self.bn2(out))) 33 | out += shortcut 34 | return out 35 | 36 | 37 | class PreActBottleneck(nn.Module): 38 | """ 39 | Pre-activation version of the original Bottleneck module for Resnets. 40 | Arguments: 41 | in_planes (int): number of input planes. 42 | planes (int): number of output filters. 43 | stride (int): stride of convolution. 
44 | """ 45 | expansion = 4 46 | 47 | def __init__(self, in_planes, planes, stride=1): 48 | super(PreActBottleneck, self).__init__() 49 | self.bn1 = nn.BatchNorm2d(in_planes) 50 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 53 | self.bn3 = nn.BatchNorm2d(planes) 54 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 55 | 56 | if stride != 1 or in_planes != self.expansion*planes: 57 | self.shortcut = nn.Sequential( 58 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 59 | ) 60 | 61 | def forward(self, x): 62 | out = F.relu(self.bn1(x)) 63 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 64 | out = self.conv1(out) 65 | out = self.conv2(F.relu(self.bn2(out))) 66 | out = self.conv3(F.relu(self.bn3(out))) 67 | out += shortcut 68 | return out 69 | 70 | 71 | class PreActResNet(nn.Module): 72 | """ 73 | Pre-activation Resnet model 74 | """ 75 | def __init__(self, block, num_blocks, num_classes=10): 76 | super(PreActResNet, self).__init__() 77 | self.in_planes = 64 78 | 79 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 80 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 81 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 82 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 83 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 84 | self.bn = nn.BatchNorm2d(512 * block.expansion) 85 | self.linear = nn.Linear(512*block.expansion, num_classes) 86 | 87 | def _make_layer(self, block, planes, num_blocks, stride): 88 | strides = [stride] + [1]*(num_blocks-1) 89 | layers = [] 90 | for stride in strides: 91 | layers.append(block(self.in_planes, planes, stride)) 92 | self.in_planes = planes * block.expansion 93 | return nn.Sequential(*layers) 94 | 95 | def forward(self, x): 96 | out = self.conv1(x) 97 | out = self.layer1(out) 98 | out = self.layer2(out) 99 | out = self.layer3(out) 100 | out = self.layer4(out) 101 | out = F.relu(self.bn(out)) 102 | out = F.avg_pool2d(out, 4) 103 | out = out.view(out.size(0), -1) 104 | out = self.linear(out) 105 | return out 106 | 107 | 108 | def preact_resnet(name, num_classes=10, pretrained=False, device='cpu'): 109 | """ 110 | Returns suitable Resnet model from its name. 111 | Arguments: 112 | name (str): name of resnet architecture. 113 | num_classes (int): number of target classes. 114 | pretrained (bool): whether to use a pretrained model. 115 | device (str or torch.device): device to work on. 116 | Returns: 117 | torch.nn.Module. 
118 | """ 119 | if name == 'preact-resnet18': 120 | return PreActResNet(PreActBlock, [2, 2, 2, 2], num_classes=num_classes) 121 | elif name == 'preact-resnet34': 122 | return PreActResNet(PreActBlock, [3, 4, 6, 3], num_classes=num_classes) 123 | elif name == 'preact-resnet50': 124 | return PreActResNet(PreActBottleneck, [3, 4, 6, 3], num_classes=num_classes) 125 | elif name == 'preact-resnet101': 126 | return PreActResNet(PreActBottleneck, [3, 4, 23, 3], num_classes=num_classes) 127 | raise ValueError('Only preact-resnet18, preact-resnet34, preact-resnet50 and preact-resnet101 are supported!') 128 | return 129 | -------------------------------------------------------------------------------- /models/preact_resnetwithswish.py: -------------------------------------------------------------------------------- 1 | # Code borrowed from https://github.com/deepmind/deepmind-research/blob/master/adversarial_robustness/pytorch/model_zoo.py 2 | # (Rebuffi et al 2021) 3 | 4 | from typing import Tuple, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | CIFAR10_MEAN = (0.4914, 0.4822, 0.4465) 12 | CIFAR10_STD = (0.2471, 0.2435, 0.2616) 13 | CIFAR100_MEAN = (0.5071, 0.4865, 0.4409) 14 | CIFAR100_STD = (0.2673, 0.2564, 0.2762) 15 | SVHN_MEAN = (0.5, 0.5, 0.5) 16 | SVHN_STD = (0.5, 0.5, 0.5) 17 | 18 | _ACTIVATION = { 19 | 'relu': nn.ReLU, 20 | 'swish': nn.SiLU, 21 | } 22 | 23 | 24 | class _PreActBlock(nn.Module): 25 | """ 26 | PreAct ResNet Block. 27 | Arguments: 28 | in_planes (int): number of input planes. 29 | out_planes (int): number of output filters. 30 | stride (int): stride of convolution. 31 | activation_fn (nn.Module): activation function. 32 | """ 33 | def __init__(self, in_planes, out_planes, stride, activation_fn=nn.ReLU): 34 | super().__init__() 35 | self._stride = stride 36 | self.batchnorm_0 = nn.BatchNorm2d(in_planes, momentum=0.01) 37 | self.relu_0 = activation_fn() 38 | # We manually pad to obtain the same effect as `SAME` (necessary when 39 | # `stride` is different than 1). 40 | self.conv_2d_1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, 41 | stride=stride, padding=0, bias=False) 42 | self.batchnorm_1 = nn.BatchNorm2d(out_planes, momentum=0.01) 43 | self.relu_1 = activation_fn() 44 | self.conv_2d_2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 45 | padding=1, bias=False) 46 | self.has_shortcut = stride != 1 or in_planes != out_planes 47 | if self.has_shortcut: 48 | self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=3, 49 | stride=stride, padding=0, bias=False) 50 | 51 | def _pad(self, x): 52 | if self._stride == 1: 53 | x = F.pad(x, (1, 1, 1, 1)) 54 | elif self._stride == 2: 55 | x = F.pad(x, (0, 1, 0, 1)) 56 | else: 57 | raise ValueError('Unsupported `stride`.') 58 | return x 59 | 60 | def forward(self, x): 61 | out = self.relu_0(self.batchnorm_0(x)) 62 | shortcut = self.shortcut(self._pad(x)) if self.has_shortcut else x 63 | out = self.conv_2d_1(self._pad(out)) 64 | out = self.conv_2d_2(self.relu_1(self.batchnorm_1(out))) 65 | return out + shortcut 66 | 67 | 68 | class PreActResNet(nn.Module): 69 | """ 70 | PreActResNet model 71 | Arguments: 72 | num_classes (int): number of output classes. 73 | depth (int): number of layers. 74 | width (int): width factor. 75 | activation_fn (nn.Module): activation function. 76 | mean (tuple): mean of dataset. 77 | std (tuple): standard deviation of dataset. 78 | padding (int): padding. 79 | num_input_channels (int): number of channels in the input. 
80 | """ 81 | 82 | def __init__(self, 83 | num_classes: int = 10, 84 | depth: int = 18, 85 | width: int = 0, # Used to make the constructor consistent. 86 | activation_fn: nn.Module = nn.ReLU, 87 | mean: Union[Tuple[float, ...], float] = CIFAR10_MEAN, 88 | std: Union[Tuple[float, ...], float] = CIFAR10_STD, 89 | padding: int = 0, 90 | num_input_channels: int = 3): 91 | 92 | super().__init__() 93 | if width != 0: 94 | raise ValueError('Unsupported `width`.') 95 | self.mean = torch.tensor(mean).view(num_input_channels, 1, 1) 96 | self.std = torch.tensor(std).view(num_input_channels, 1, 1) 97 | self.mean_cuda = None 98 | self.std_cuda = None 99 | self.padding = padding 100 | self.conv_2d = nn.Conv2d(num_input_channels, 64, kernel_size=3, stride=1, 101 | padding=1, bias=False) 102 | if depth == 18: 103 | num_blocks = (2, 2, 2, 2) 104 | elif depth == 34: 105 | num_blocks = (3, 4, 6, 3) 106 | else: 107 | raise ValueError('Unsupported `depth`.') 108 | self.layer_0 = self._make_layer(64, 64, num_blocks[0], 1, activation_fn) 109 | self.layer_1 = self._make_layer(64, 128, num_blocks[1], 2, activation_fn) 110 | self.layer_2 = self._make_layer(128, 256, num_blocks[2], 2, activation_fn) 111 | self.layer_3 = self._make_layer(256, 512, num_blocks[3], 2, activation_fn) 112 | self.batchnorm = nn.BatchNorm2d(512, momentum=0.01) 113 | self.relu = activation_fn() 114 | self.logits = nn.Linear(512, num_classes) 115 | 116 | def _make_layer(self, in_planes, out_planes, num_blocks, stride, 117 | activation_fn): 118 | layers = [] 119 | for i, stride in enumerate([stride] + [1] * (num_blocks - 1)): 120 | layers.append(_PreActBlock(i == 0 and in_planes or out_planes, 121 | out_planes, 122 | stride, 123 | activation_fn)) 124 | return nn.Sequential(*layers) 125 | 126 | def forward(self, x): 127 | if self.padding > 0: 128 | x = F.pad(x, (self.padding,) * 4) 129 | if x.is_cuda: 130 | if self.mean_cuda is None: 131 | self.mean_cuda = self.mean.cuda() 132 | self.std_cuda = self.std.cuda() 133 | out = (x - self.mean_cuda) / self.std_cuda 134 | else: 135 | out = (x - self.mean) / self.std 136 | out = self.conv_2d(out) 137 | out = self.layer_0(out) 138 | out = self.layer_1(out) 139 | out = self.layer_2(out) 140 | out = self.layer_3(out) 141 | out = self.relu(self.batchnorm(out)) 142 | out = F.avg_pool2d(out, 4) 143 | out = out.view(out.size(0), -1) 144 | return self.logits(out) 145 | 146 | 147 | def preact_resnetwithswish(name, dataset='cifar10', num_classes=10): 148 | """ 149 | Returns suitable PreActResNet model with Swish activation function from its name. 150 | Arguments: 151 | name (str): name of resnet architecture. 152 | num_classes (int): number of target classes. 153 | dataset (str): dataset to use. 154 | Returns: 155 | torch.nn.Module. 
156 | """ 157 | name_parts = name.split('-') 158 | name = '-'.join(name_parts[:-1]) 159 | act_fn = name_parts[-1] 160 | depth = int(name[-2:]) 161 | 162 | if 'cifar100' in dataset: 163 | return PreActResNet(num_classes=num_classes, depth=depth, width=0, activation_fn=_ACTIVATION[act_fn], 164 | mean=CIFAR100_MEAN, std=CIFAR100_STD) 165 | elif 'svhn' in dataset: 166 | return PreActResNet(num_classes=num_classes, depth=depth, width=0, activation_fn=_ACTIVATION[act_fn], 167 | mean=SVHN_MEAN, std=SVHN_STD) 168 | return PreActResNet(num_classes=num_classes, depth=depth, width=0, activation_fn=_ACTIVATION[act_fn]) 169 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Normalization(nn.Module): 7 | """ 8 | Standardizes the input data. 9 | Arguments: 10 | mean (list): mean. 11 | std (float): standard deviation. 12 | device (str or torch.device): device to be used. 13 | Returns: 14 | (input - mean) / std 15 | """ 16 | def __init__(self, mean, std): 17 | super(Normalization, self).__init__() 18 | num_channels = len(mean) 19 | self.mean = torch.FloatTensor(mean).view(1, num_channels, 1, 1) 20 | self.sigma = torch.FloatTensor(std).view(1, num_channels, 1, 1) 21 | self.mean_cuda, self.sigma_cuda = None, None 22 | 23 | def forward(self, x): 24 | if x.is_cuda: 25 | if self.mean_cuda is None: 26 | self.mean_cuda = self.mean.cuda() 27 | self.sigma_cuda = self.sigma.cuda() 28 | out = (x - self.mean_cuda) / self.sigma_cuda 29 | else: 30 | out = (x - self.mean) / self.sigma 31 | return out 32 | 33 | 34 | class BasicBlock(nn.Module): 35 | """ 36 | Implements a basic block module for Resnets. 37 | Arguments: 38 | in_planes (int): number of input planes. 39 | out_planes (int): number of output filters. 40 | stride (int): stride of convolution. 41 | """ 42 | expansion = 1 43 | 44 | def __init__(self, in_planes, planes, stride=1): 45 | super(BasicBlock, self).__init__() 46 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn1 = nn.BatchNorm2d(planes) 48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 49 | self.bn2 = nn.BatchNorm2d(planes) 50 | 51 | self.shortcut = nn.Sequential() 52 | if stride != 1 or in_planes != self.expansion * planes: 53 | self.shortcut = nn.Sequential( 54 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 55 | nn.BatchNorm2d(self.expansion * planes) 56 | ) 57 | 58 | def forward(self, x): 59 | out = F.relu(self.bn1(self.conv1(x))) 60 | out = self.bn2(self.conv2(out)) 61 | out += self.shortcut(x) 62 | out = F.relu(out) 63 | return out 64 | 65 | 66 | class Bottleneck(nn.Module): 67 | """ 68 | Implements a basic block module with bottleneck for Resnets. 69 | Arguments: 70 | in_planes (int): number of input planes. 71 | out_planes (int): number of output filters. 72 | stride (int): stride of convolution. 
73 | """ 74 | expansion = 4 75 | 76 | def __init__(self, in_planes, planes, stride=1): 77 | super(Bottleneck, self).__init__() 78 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 79 | self.bn1 = nn.BatchNorm2d(planes) 80 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 81 | self.bn2 = nn.BatchNorm2d(planes) 82 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 83 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 84 | 85 | self.shortcut = nn.Sequential() 86 | if stride != 1 or in_planes != self.expansion * planes: 87 | self.shortcut = nn.Sequential( 88 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 89 | nn.BatchNorm2d(self.expansion * planes) 90 | ) 91 | 92 | def forward(self, x): 93 | out = F.relu(self.bn1(self.conv1(x))) 94 | out = F.relu(self.bn2(self.conv2(out))) 95 | out = self.bn3(self.conv3(out)) 96 | out += self.shortcut(x) 97 | out = F.relu(out) 98 | return out 99 | 100 | 101 | class ResNet(nn.Module): 102 | """ 103 | ResNet model 104 | Arguments: 105 | block (BasicBlock or Bottleneck): type of basic block to be used. 106 | num_blocks (list): number of blocks in each sub-module. 107 | num_classes (int): number of output classes. 108 | device (torch.device or str): device to work on. 109 | """ 110 | def __init__(self, block, num_blocks, num_classes=10, device='cpu'): 111 | super(ResNet, self).__init__() 112 | self.in_planes = 64 113 | 114 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 115 | self.bn1 = nn.BatchNorm2d(64) 116 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 117 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 118 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 119 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 120 | self.linear = nn.Linear(512 * block.expansion, num_classes) 121 | 122 | def _make_layer(self, block, planes, num_blocks, stride): 123 | strides = [stride] + [1] * (num_blocks - 1) 124 | layers = [] 125 | for stride in strides: 126 | layers.append(block(self.in_planes, planes, stride)) 127 | self.in_planes = planes * block.expansion 128 | return nn.Sequential(*layers) 129 | 130 | def forward(self, x): 131 | out = F.relu(self.bn1(self.conv1(x))) 132 | out = self.layer1(out) 133 | out = self.layer2(out) 134 | out = self.layer3(out) 135 | out = self.layer4(out) 136 | out = F.avg_pool2d(out, 4) 137 | out = out.view(out.size(0), -1) 138 | out = self.linear(out) 139 | return out 140 | 141 | 142 | def resnet(name, num_classes=10, pretrained=False, device='cpu'): 143 | """ 144 | Returns suitable Resnet model from its name. 145 | Arguments: 146 | name (str): name of resnet architecture. 147 | num_classes (int): number of target classes. 148 | pretrained (bool): whether to use a pretrained model. 149 | device (str or torch.device): device to work on. 150 | Returns: 151 | torch.nn.Module. 
152 | """ 153 | if name == 'resnet18': 154 | return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, device=device) 155 | elif name == 'resnet34': 156 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, device=device) 157 | elif name == 'resnet50': 158 | return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, device=device) 159 | elif name == 'resnet101': 160 | return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, device=device) 161 | 162 | raise ValueError('Only resnet18, resnet34, resnet50 and resnet101 are supported!') 163 | return 164 | -------------------------------------------------------------------------------- /models/ti_preact_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class PreActBlock(nn.Module): 7 | """ 8 | Pre-activation version of the BasicBlock. 9 | Arguments: 10 | in_planes (int): number of input planes. 11 | planes (int): number of output filters. 12 | stride (int): stride of convolution. 13 | """ 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1): 17 | super(PreActBlock, self).__init__() 18 | self.bn1 = nn.BatchNorm2d(in_planes) 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 26 | ) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.bn1(x)) 30 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 31 | out = self.conv1(out) 32 | out = self.conv2(F.relu(self.bn2(out))) 33 | out += shortcut 34 | return out 35 | 36 | 37 | class PreActBottleneck(nn.Module): 38 | """ 39 | Pre-activation version of the original Bottleneck module. 40 | Arguments: 41 | in_planes (int): number of input planes. 42 | planes (int): number of output filters. 43 | stride (int): stride of convolution. 44 | """ 45 | expansion = 4 46 | 47 | def __init__(self, in_planes, planes, stride=1): 48 | super(PreActBottleneck, self).__init__() 49 | self.bn1 = nn.BatchNorm2d(in_planes) 50 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 53 | self.bn3 = nn.BatchNorm2d(planes) 54 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 55 | 56 | if stride != 1 or in_planes != self.expansion*planes: 57 | self.shortcut = nn.Sequential( 58 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 59 | ) 60 | 61 | def forward(self, x): 62 | out = F.relu(self.bn1(x)) 63 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 64 | out = self.conv1(out) 65 | out = self.conv2(F.relu(self.bn2(out))) 66 | out = self.conv3(F.relu(self.bn3(out))) 67 | out += shortcut 68 | return out 69 | 70 | 71 | class PreActResNet(nn.Module): 72 | """ 73 | Pre-activation Resnet model for TI-200 dataset. 
74 | """ 75 | def __init__(self, block, num_blocks, num_classes=200): 76 | super(PreActResNet, self).__init__() 77 | self.in_planes = 64 78 | 79 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 80 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 81 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 82 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 83 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 84 | self.bn = nn.BatchNorm2d(512 * block.expansion) 85 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 86 | self.linear = nn.Linear(512*block.expansion, num_classes) 87 | 88 | def _make_layer(self, block, planes, num_blocks, stride): 89 | strides = [stride] + [1]*(num_blocks-1) 90 | layers = [] 91 | for stride in strides: 92 | layers.append(block(self.in_planes, planes, stride)) 93 | self.in_planes = planes * block.expansion 94 | return nn.Sequential(*layers) 95 | 96 | def forward(self, x): 97 | out = self.conv1(x) 98 | out = self.layer1(out) 99 | out = self.layer2(out) 100 | out = self.layer3(out) 101 | out = self.layer4(out) 102 | out = F.relu(self.bn(out)) 103 | out = self.avgpool(out) 104 | out = out.view(out.size(0), -1) 105 | out = self.linear(out) 106 | return out 107 | 108 | 109 | def ti_preact_resnet(name, num_classes=200, pretrained=False, device='cpu'): 110 | """ 111 | Returns suitable PreAct Resnet model from its name (only for TI-200 dataset). 112 | Arguments: 113 | name (str): name of resnet architecture. 114 | num_classes (int): number of target classes. 115 | pretrained (bool): whether to use a pretrained model. 116 | device (str or torch.device): device to work on. 117 | Returns: 118 | torch.nn.Module. 119 | """ 120 | if name == 'preact-resnet18': 121 | return PreActResNet(PreActBlock, [2, 2, 2, 2], num_classes=num_classes) 122 | elif name == 'preact-resnet34': 123 | return PreActResNet(PreActBlock, [3, 4, 6, 3], num_classes=num_classes) 124 | elif name == 'preact-resnet50': 125 | return PreActResNet(PreActBottleneck, [3, 4, 6, 3], num_classes=num_classes) 126 | elif name == 'preact-resnet101': 127 | return PreActResNet(PreActBottleneck, [3, 4, 23, 3], num_classes=num_classes) 128 | else: 129 | raise ValueError('Only preact-resnet18, preact-resnet34, preact-resnet50 and preact-resnet101 are supported!') 130 | return 131 | -------------------------------------------------------------------------------- /models/wideresnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | """ 9 | Implements a basic block module for WideResNets. 10 | Arguments: 11 | in_planes (int): number of input planes. 12 | out_planes (int): number of output filters. 13 | stride (int): stride of convolution. 14 | dropRate (float): dropout rate. 
15 | """ 16 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 17 | super(BasicBlock, self).__init__() 18 | self.bn1 = nn.BatchNorm2d(in_planes) 19 | self.relu1 = nn.ReLU(inplace=True) 20 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(out_planes) 23 | self.relu2 = nn.ReLU(inplace=True) 24 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 25 | padding=1, bias=False) 26 | self.droprate = dropRate 27 | self.equalInOut = (in_planes == out_planes) 28 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 29 | padding=0, bias=False) or None 30 | 31 | def forward(self, x): 32 | if not self.equalInOut: 33 | x = self.relu1(self.bn1(x)) 34 | else: 35 | out = self.relu1(self.bn1(x)) 36 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 37 | if self.droprate > 0: 38 | out = F.dropout(out, p=self.droprate, training=self.training) 39 | out = self.conv2(out) 40 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 41 | 42 | 43 | class NetworkBlock(nn.Module): 44 | """ 45 | Implements a network block module for WideResnets. 46 | Arguments: 47 | nb_layers (int): number of layers. 48 | in_planes (int): number of input planes. 49 | out_planes (int): number of output filters. 50 | block (BasicBlock): type of basic block to be used. 51 | stride (int): stride of convolution. 52 | dropRate (float): dropout rate. 53 | """ 54 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 55 | super(NetworkBlock, self).__init__() 56 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 57 | 58 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 59 | layers = [] 60 | for i in range(int(nb_layers)): 61 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 62 | return nn.Sequential(*layers) 63 | 64 | def forward(self, x): 65 | return self.layer(x) 66 | 67 | 68 | class WideResNet(nn.Module): 69 | """ 70 | WideResNet model 71 | Arguments: 72 | depth (int): number of layers. 73 | num_classes (int): number of output classes. 74 | widen_factor (int): width factor. 75 | dropRate (float): dropout rate. 76 | """ 77 | def __init__(self, depth=34, num_classes=10, widen_factor=10, dropRate=0.0): 78 | super(WideResNet, self).__init__() 79 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 80 | assert ((depth - 4) % 6 == 0) 81 | n = (depth - 4) / 6 82 | block = BasicBlock 83 | # 1st conv before any network block 84 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 85 | padding=1, bias=False) 86 | # 1st block 87 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 88 | # 2nd block 89 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 90 | # 3rd block 91 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 92 | # global average pooling and classifier 93 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.fc = nn.Linear(nChannels[3], num_classes) 96 | self.nChannels = nChannels[3] 97 | 98 | for m in self.modules(): 99 | if isinstance(m, nn.Conv2d): 100 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 101 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 102 | elif isinstance(m, nn.BatchNorm2d): 103 | m.weight.data.fill_(1) 104 | m.bias.data.zero_() 105 | elif isinstance(m, nn.Linear): 106 | m.bias.data.zero_() 107 | 108 | def forward(self, x): 109 | out = self.conv1(x) 110 | out = self.block1(out) 111 | out = self.block2(out) 112 | out = self.block3(out) 113 | out = self.relu(self.bn1(out)) 114 | out = F.avg_pool2d(out, 8) 115 | out = out.view(-1, self.nChannels) 116 | return self.fc(out) 117 | 118 | 119 | def wideresnet(name, num_classes=10, device='cpu'): 120 | """ 121 | Returns suitable Wideresnet model from its name. 122 | Arguments: 123 | name (str): name of resnet architecture. 124 | num_classes (int): number of target classes. 125 | device (str or torch.device): device to work on. 126 | Returns: 127 | torch.nn.Module. 128 | """ 129 | name_parts = name.split('-') 130 | depth = int(name_parts[1]) 131 | widen = int(name_parts[2]) 132 | return WideResNet(depth=depth, num_classes=num_classes, widen_factor=widen) 133 | -------------------------------------------------------------------------------- /models/wideresnetwithswish.py: -------------------------------------------------------------------------------- 1 | # Code borrowed from https://github.com/deepmind/deepmind-research/blob/master/adversarial_robustness/pytorch/model_zoo.py 2 | # (Gowal et al 2020) 3 | 4 | from typing import Tuple, Union 5 | 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | CIFAR10_MEAN = (0.4914, 0.4822, 0.4465) 13 | CIFAR10_STD = (0.2471, 0.2435, 0.2616) 14 | CIFAR100_MEAN = (0.5071, 0.4865, 0.4409) 15 | CIFAR100_STD = (0.2673, 0.2564, 0.2762) 16 | SVHN_MEAN = (0.5, 0.5, 0.5) 17 | SVHN_STD = (0.5, 0.5, 0.5) 18 | 19 | _ACTIVATION = { 20 | 'relu': nn.ReLU, 21 | 'swish': nn.SiLU, 22 | } 23 | 24 | 25 | class _Block(nn.Module): 26 | """ 27 | WideResNet Block. 28 | Arguments: 29 | in_planes (int): number of input planes. 30 | out_planes (int): number of output filters. 31 | stride (int): stride of convolution. 32 | activation_fn (nn.Module): activation function. 33 | """ 34 | def __init__(self, in_planes, out_planes, stride, activation_fn=nn.ReLU): 35 | super().__init__() 36 | self.batchnorm_0 = nn.BatchNorm2d(in_planes, momentum=0.01) 37 | self.relu_0 = activation_fn(inplace=True) 38 | # We manually pad to obtain the same effect as `SAME` (necessary when `stride` is different than 1). 
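# Concretely, for the 3x3 convolutions below: stride 1 pads (1, 1, 1, 1) and
# stride 2 pads only right/bottom with (0, 1, 0, 1), so e.g. a 32x32 feature
# map yields floor((32 + 1 - 3) / 2) + 1 = 16x16, matching TF 'SAME' -- see
# the F.pad calls in forward().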
39 | self.conv_0 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 40 | padding=0, bias=False) 41 | self.batchnorm_1 = nn.BatchNorm2d(out_planes, momentum=0.01) 42 | self.relu_1 = activation_fn(inplace=True) 43 | self.conv_1 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 44 | padding=1, bias=False) 45 | self.has_shortcut = in_planes != out_planes 46 | if self.has_shortcut: 47 | self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1, 48 | stride=stride, padding=0, bias=False) 49 | else: 50 | self.shortcut = None 51 | self._stride = stride 52 | 53 | def forward(self, x): 54 | if self.has_shortcut: 55 | x = self.relu_0(self.batchnorm_0(x)) 56 | else: 57 | out = self.relu_0(self.batchnorm_0(x)) 58 | v = x if self.has_shortcut else out 59 | if self._stride == 1: 60 | v = F.pad(v, (1, 1, 1, 1)) 61 | elif self._stride == 2: 62 | v = F.pad(v, (0, 1, 0, 1)) 63 | else: 64 | raise ValueError('Unsupported `stride`.') 65 | out = self.conv_0(v) 66 | out = self.relu_1(self.batchnorm_1(out)) 67 | out = self.conv_1(out) 68 | out = torch.add(self.shortcut(x) if self.has_shortcut else x, out) 69 | return out 70 | 71 | 72 | class _BlockGroup(nn.Module): 73 | """ 74 | WideResNet block group. 75 | Arguments: 76 | in_planes (int): number of input planes. 77 | out_planes (int): number of output filters. 78 | stride (int): stride of convolution. 79 | activation_fn (nn.Module): activation function. 80 | """ 81 | def __init__(self, num_blocks, in_planes, out_planes, stride, activation_fn=nn.ReLU): 82 | super().__init__() 83 | block = [] 84 | for i in range(num_blocks): 85 | block.append( 86 | _Block(i == 0 and in_planes or out_planes, 87 | out_planes, 88 | i == 0 and stride or 1, 89 | activation_fn=activation_fn) 90 | ) 91 | self.block = nn.Sequential(*block) 92 | 93 | def forward(self, x): 94 | return self.block(x) 95 | 96 | 97 | class WideResNet(nn.Module): 98 | """ 99 | WideResNet model 100 | Arguments: 101 | num_classes (int): number of output classes. 102 | depth (int): number of layers. 103 | width (int): width factor. 104 | activation_fn (nn.Module): activation function. 105 | mean (tuple): mean of dataset. 106 | std (tuple): standard deviation of dataset. 107 | padding (int): padding. 108 | num_input_channels (int): number of channels in the input. 
109 | """ 110 | def __init__(self, 111 | num_classes: int = 10, 112 | depth: int = 28, 113 | width: int = 10, 114 | activation_fn: nn.Module = nn.ReLU, 115 | mean: Union[Tuple[float, ...], float] = CIFAR10_MEAN, 116 | std: Union[Tuple[float, ...], float] = CIFAR10_STD, 117 | padding: int = 0, 118 | num_input_channels: int = 3): 119 | super().__init__() 120 | self.mean = torch.tensor(mean).view(num_input_channels, 1, 1) 121 | self.std = torch.tensor(std).view(num_input_channels, 1, 1) 122 | self.mean_cuda = None 123 | self.std_cuda = None 124 | self.padding = padding 125 | num_channels = [16, 16 * width, 32 * width, 64 * width] 126 | assert (depth - 4) % 6 == 0 127 | num_blocks = (depth - 4) // 6 128 | self.init_conv = nn.Conv2d(num_input_channels, num_channels[0], 129 | kernel_size=3, stride=1, padding=1, bias=False) 130 | self.layer = nn.Sequential( 131 | _BlockGroup(num_blocks, num_channels[0], num_channels[1], 1, 132 | activation_fn=activation_fn), 133 | _BlockGroup(num_blocks, num_channels[1], num_channels[2], 2, 134 | activation_fn=activation_fn), 135 | _BlockGroup(num_blocks, num_channels[2], num_channels[3], 2, 136 | activation_fn=activation_fn)) 137 | self.batchnorm = nn.BatchNorm2d(num_channels[3], momentum=0.01) 138 | self.relu = activation_fn(inplace=True) 139 | self.logits = nn.Linear(num_channels[3], num_classes) 140 | self.num_channels = num_channels[3] 141 | 142 | for m in self.modules(): 143 | if isinstance(m, nn.Conv2d): 144 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 145 | m.weight.data.normal_(0, math.sqrt(2. / n)) 146 | elif isinstance(m, nn.BatchNorm2d): 147 | m.weight.data.fill_(1) 148 | m.bias.data.zero_() 149 | elif isinstance(m, nn.Linear): 150 | m.bias.data.zero_() 151 | 152 | def forward(self, x): 153 | if self.padding > 0: 154 | x = F.pad(x, (self.padding,) * 4) 155 | if x.is_cuda: 156 | if self.mean_cuda is None: 157 | self.mean_cuda = self.mean.cuda() 158 | self.std_cuda = self.std.cuda() 159 | out = (x - self.mean_cuda) / self.std_cuda 160 | else: 161 | out = (x - self.mean) / self.std 162 | 163 | out = self.init_conv(out) 164 | out = self.layer(out) 165 | out = self.relu(self.batchnorm(out)) 166 | out = F.avg_pool2d(out, 8) 167 | out = out.view(-1, self.num_channels) 168 | return self.logits(out) 169 | 170 | 171 | def wideresnetwithswish(name, dataset='cifar10', num_classes=10, device='cpu'): 172 | """ 173 | Returns suitable Wideresnet model with Swish activation function from its name. 174 | Arguments: 175 | name (str): name of resnet architecture. 176 | num_classes (int): number of target classes. 177 | device (str or torch.device): device to work on. 178 | dataset (str): dataset to use. 179 | Returns: 180 | torch.nn.Module. 
181 | """ 182 | # if 'cifar10' not in dataset: 183 | # raise ValueError('WideResNets with Swish activation only support CIFAR-10 and CIFAR-100!') 184 | 185 | name_parts = name.split('-') 186 | depth = int(name_parts[1]) 187 | widen = int(name_parts[2]) 188 | act_fn = name_parts[3] 189 | 190 | print (f'WideResNet-{depth}-{widen}-{act_fn} uses normalization.') 191 | if 'cifar100' in dataset: 192 | return WideResNet(num_classes=num_classes, depth=depth, width=widen, activation_fn=_ACTIVATION[act_fn], 193 | mean=CIFAR100_MEAN, std=CIFAR100_STD) 194 | elif 'svhn' in dataset: 195 | return WideResNet(num_classes=num_classes, depth=depth, width=widen, activation_fn=_ACTIVATION[act_fn], 196 | mean=SVHN_MEAN, std=SVHN_STD) 197 | return WideResNet(num_classes=num_classes, depth=depth, width=widen, activation_fn=_ACTIVATION[act_fn]) -------------------------------------------------------------------------------- /pytorch.yaml: -------------------------------------------------------------------------------- 1 | name: pytorch 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - blas=1.0=mkl 10 | - bottleneck=1.3.5=py310ha9d4c09_0 11 | - brotli=1.0.9=h5eee18b_7 12 | - brotli-bin=1.0.9=h5eee18b_7 13 | - brotlipy=0.7.0=py310h7f8727e_1002 14 | - bzip2=1.0.8=h7b6447c_0 15 | - ca-certificates=2022.07.19=h06a4308_0 16 | - certifi=2022.9.24=py310h06a4308_0 17 | - cffi=1.15.1=py310h74dc2b5_0 18 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 19 | - cryptography=37.0.1=py310h9ce1e76_0 20 | - cudatoolkit=11.3.1=h2bc3f7f_2 21 | - cycler=0.11.0=pyhd3eb1b0_0 22 | - dbus=1.13.18=hb2f20db_0 23 | - expat=2.4.9=h6a678d5_0 24 | - ffmpeg=4.3=hf484d3e_0 25 | - fftw=3.3.9=h27cfd23_1 26 | - fontconfig=2.13.1=h6c09931_0 27 | - fonttools=4.25.0=pyhd3eb1b0_0 28 | - freetype=2.11.0=h70c0345_0 29 | - ftfy=5.8=py_0 30 | - giflib=5.2.1=h7b6447c_0 31 | - glib=2.69.1=h4ff587b_1 32 | - gmp=6.2.1=h295c915_3 33 | - gnutls=3.6.15=he1e5248_0 34 | - gst-plugins-base=1.14.0=h8213a91_2 35 | - gstreamer=1.14.0=h28cd5cc_2 36 | - icu=58.2=he6710b0_3 37 | - idna=3.4=py310h06a4308_0 38 | - intel-openmp=2021.4.0=h06a4308_3561 39 | - joblib=1.1.0=pyhd3eb1b0_0 40 | - jpeg=9e=h7f8727e_0 41 | - kiwisolver=1.4.2=py310h295c915_0 42 | - krb5=1.19.2=hac12032_0 43 | - lame=3.100=h7b6447c_0 44 | - lcms2=2.12=h3be6417_0 45 | - ld_impl_linux-64=2.38=h1181459_1 46 | - lerc=3.0=h295c915_0 47 | - libbrotlicommon=1.0.9=h5eee18b_7 48 | - libbrotlidec=1.0.9=h5eee18b_7 49 | - libbrotlienc=1.0.9=h5eee18b_7 50 | - libclang=10.0.1=default_hb85057a_2 51 | - libdeflate=1.8=h7f8727e_5 52 | - libedit=3.1.20210910=h7f8727e_0 53 | - libevent=2.1.12=h8f2d780_0 54 | - libffi=3.3=he6710b0_2 55 | - libgcc-ng=11.2.0=h1234567_1 56 | - libgfortran-ng=11.2.0=h00389a5_1 57 | - libgfortran5=11.2.0=h1234567_1 58 | - libgomp=11.2.0=h1234567_1 59 | - libiconv=1.16=h7f8727e_2 60 | - libidn2=2.3.2=h7f8727e_0 61 | - libllvm10=10.0.1=hbcb73fb_5 62 | - libpng=1.6.37=hbc83047_0 63 | - libpq=12.9=h16c4e8d_3 64 | - libstdcxx-ng=11.2.0=h1234567_1 65 | - libtasn1=4.16.0=h27cfd23_0 66 | - libtiff=4.4.0=hecacb30_0 67 | - libunistring=0.9.10=h27cfd23_0 68 | - libuuid=1.0.3=h7f8727e_2 69 | - libwebp=1.2.4=h11a3e52_0 70 | - libwebp-base=1.2.4=h5eee18b_0 71 | - libxcb=1.15=h7f8727e_0 72 | - libxkbcommon=1.0.1=hfa300c1_0 73 | - libxml2=2.9.14=h74e7548_0 74 | - libxslt=1.1.35=h4e12654_0 75 | - lz4-c=1.9.3=h295c915_1 76 | - matplotlib=3.5.2=py310h06a4308_0 77 | - matplotlib-base=3.5.2=py310hf590b9c_0 78 | - 
mkl=2021.4.0=h06a4308_640 79 | - mkl-service=2.4.0=py310h7f8727e_0 80 | - mkl_fft=1.3.1=py310hd6ae3a3_0 81 | - mkl_random=1.2.2=py310h00e6091_0 82 | - munkres=1.1.4=py_0 83 | - ncurses=6.3=h5eee18b_3 84 | - nettle=3.7.3=hbbd107a_1 85 | - nspr=4.33=h295c915_0 86 | - nss=3.74=h0370c37_0 87 | - numexpr=2.8.3=py310hcea2de6_0 88 | - numpy=1.23.3 89 | - numpy-base=1.23.3 90 | - openh264=2.1.1=h4ff587b_0 91 | - openssl=1.1.1q=h7f8727e_0 92 | - packaging=21.3=pyhd3eb1b0_0 93 | - pandas=1.4.4=py310h6a678d5_0 94 | - patsy=0.5.3=pyhd8ed1ab_0 95 | - pcre=8.45=h295c915_0 96 | - pillow=9.2.0=py310hace64e9_1 97 | - pip=22.2.2=py310h06a4308_0 98 | - ply=3.11=py310h06a4308_0 99 | - pycparser=2.21=pyhd3eb1b0_0 100 | - pyopenssl=22.0.0=pyhd3eb1b0_0 101 | - pyparsing=3.0.9=py310h06a4308_0 102 | - pyqt=5.15.7=py310h6a678d5_1 103 | - pysocks=1.7.1=py310h06a4308_0 104 | - python=3.10.6=haa1d7c7_0 105 | - python-dateutil=2.8.2=pyhd3eb1b0_0 106 | - python_abi=3.10=2_cp310 107 | - pytorch=1.12.1=py3.10_cuda11.3_cudnn8.3.2_0 108 | - pytorch-mutex=1.0=cuda 109 | - pytz=2022.1=py310h06a4308_0 110 | - qt-main=5.15.2=h327a75a_7 111 | - qt-webengine=5.15.9=hd2b0992_4 112 | - qtwebkit=5.212=h4eab89a_4 113 | - readline=8.1.2=h7f8727e_1 114 | - regex=2022.7.9=py310h5eee18b_0 115 | - requests=2.28.1=py310h06a4308_0 116 | - scikit-learn=1.1.2=py310h6a678d5_0 117 | - scipy=1.9.1=py310hd5efca6_0 118 | - seaborn=0.12.1=hd8ed1ab_0 119 | - seaborn-base=0.12.1=pyhd8ed1ab_0 120 | - setuptools=63.4.1=py310h06a4308_0 121 | - sip=6.6.2=py310h6a678d5_0 122 | - six=1.16.0=pyhd3eb1b0_1 123 | - sqlite=3.39.3=h5082296_0 124 | - statsmodels=0.13.2=py310h96516ba_0 125 | - threadpoolctl=2.2.0=pyh0d69192_0 126 | - tk=8.6.12=h1ccaba5_0 127 | - toml=0.10.2=pyhd3eb1b0_0 128 | - torchaudio=0.12.1=py310_cu113 129 | - torchvision=0.13.1=py310_cu113 130 | - tornado=6.2=py310h5eee18b_0 131 | - tqdm=4.64.1=py310h06a4308_0 132 | - typing_extensions=4.3.0=py310h06a4308_0 133 | - tzdata=2022e=h04d1e81_0 134 | - urllib3=1.26.11=py310h06a4308_0 135 | - wcwidth=0.2.5=pyhd3eb1b0_0 136 | - wheel=0.37.1=pyhd3eb1b0_0 137 | - xz=5.2.6=h5eee18b_0 138 | - zlib=1.2.12=h5eee18b_3 139 | - zstd=1.5.2=ha4553b6_0 140 | - pip: 141 | - pyqt5-sip==12.11.0 142 | prefix: /home/zhengchenyu/miniconda3/envs/pytorch 143 | -------------------------------------------------------------------------------- /scripts/main_aug.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | export PYTHONPATH=$PYTHONPATH:`pwd` 3 | 4 | python main_aug.py \ 5 | --desc CDCGAN \ 6 | --model CDCGAN \ 7 | --model-path "./log/cDCGAN/G-weights-last.pt" 8 | 9 | python main_aug.py \ 10 | --desc stylegan \ 11 | --model stylegan \ 12 | --model-path ./log/stylegan/cifar10.pkl -------------------------------------------------------------------------------- /scripts/main_bGMM.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=$PYTHONPATH:`pwd` 2 | 3 | # d=1 4 | python ./main_simulation_bGMM.py \ 5 | --train_size_list 20,50,100,200,500 \ 6 | --gamma_list 0,1,2,5,10,20,50 \ 7 | --d_list 1 \ 8 | --mode m 9 | 10 | # m_S = 10 11 | python ./main_simulation_bGMM.py \ 12 | --train_size_list 10 \ 13 | --gamma_list 0,1,2,5,10,20,50 \ 14 | --d_list 2,10,20,50,100 \ 15 | --mode d 16 | 17 | # d = 1, m_S = 40 18 | python ./main_simulation_bGMM.py \ 19 | --train_size_list 40 \ 20 | --gamma_list 0,1,2,5,10,20,50,100 \ 21 | --d_list 1 \ 22 | --mode gamma 23 | 24 | # d = 50, m_S = 10 25 | python 
./main_simulation_bGMM.py \ 26 | --train_size_list 10 \ 27 | --gamma_list 0,1,2,5,10,20,50,100 \ 28 | --d_list 50 \ 29 | --mode gamma 30 | -------------------------------------------------------------------------------- /scripts/main_train_aug.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=3 2 | export PYTHONPATH=$PYTHONPATH:`pwd` 3 | python main_train_aug.py --data-dir ./datasets \ 4 | --log-dir ./log/RN18_cDCGAN_base \ 5 | --desc RN18_cifar10s_lr0p2_epoch100_bs512_1000k \ 6 | --data cifar10s \ 7 | --batch-size 512 \ 8 | --model resnet18 \ 9 | --num-epochs 100 \ 10 | --eval-freq 10 \ 11 | --lr 0.2 \ 12 | --augment base \ 13 | --aux-data-filename ./datasets/cDCGAN/1000k.npz 14 | 15 | export CUDA_VISIBLE_DEVICES=3 16 | export PYTHONPATH=$PYTHONPATH:`pwd` 17 | python main_train_aug.py --data-dir ./datasets \ 18 | --log-dir ./log/RN18_stylegan_base \ 19 | --desc RN18_cifar10s_lr0p2_epoch100_bs512_100k \ 20 | --data cifar10s \ 21 | --batch-size 512 \ 22 | --model resnet18 \ 23 | --num-epochs 100 \ 24 | --eval-freq 10 \ 25 | --lr 0.2 \ 26 | --augment base \ 27 | --aux-data-filename ./datasets/stylegan/100k.npz 28 | 29 | export CUDA_VISIBLE_DEVICES=1 30 | export PYTHONPATH=$PYTHONPATH:`pwd` 31 | python main_train_aug.py --data-dir ./datasets \ 32 | --log-dir ./log/RN18_stylegan \ 33 | --desc RN18_cifar10s_lr0p2_epoch100_bs512_100k \ 34 | --data cifar10s \ 35 | --batch-size 512 \ 36 | --model resnet18 \ 37 | --num-epochs 100 \ 38 | --eval-freq 10 \ 39 | --lr 0.2 \ 40 | --aux-data-filename ./datasets/stylegan/100k.npz 41 | 42 | 43 | export CUDA_VISIBLE_DEVICES=1 44 | export PYTHONPATH=$PYTHONPATH:`pwd` 45 | python main_train_aug.py --data-dir ./datasets \ 46 | --log-dir ./log/RN18_EDM \ 47 | --desc RN18_cifar10s_lr0p2_epoch100_bs512_1000k \ 48 | --data cifar10s \ 49 | --batch-size 512 \ 50 | --model resnet18 \ 51 | --num-epochs 100 \ 52 | --eval-freq 10 \ 53 | --lr 0.2 \ 54 | --aux-data-filename ../bishe/codes/data/5m.npz \ 55 | --aux-take-amount 1000000 -------------------------------------------------------------------------------- /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
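# Compiles and caches the C++/CUDA extension modules under torch_utils/ops (bias_act, upfirdn2d) at runtime via torch.utils.cpp_extension; get_plugin() below is the single entry point.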
8 | 9 | import os 10 | import glob 11 | import torch 12 | import torch.utils.cpp_extension 13 | import importlib 14 | import hashlib 15 | import shutil 16 | from pathlib import Path 17 | 18 | from torch.utils.file_baton import FileBaton 19 | 20 | #---------------------------------------------------------------------------- 21 | # Global options. 22 | 23 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 24 | 25 | #---------------------------------------------------------------------------- 26 | # Internal helper funcs. 27 | 28 | def _find_compiler_bindir(): 29 | patterns = [ 30 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 33 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', 34 | ] 35 | for pattern in patterns: 36 | matches = sorted(glob.glob(pattern)) 37 | if len(matches): 38 | return matches[-1] 39 | return None 40 | 41 | #---------------------------------------------------------------------------- 42 | # Main entry point for compiling and loading C++/CUDA plugins. 43 | 44 | _cached_plugins = dict() 45 | 46 | def get_plugin(module_name, sources, **build_kwargs): 47 | assert verbosity in ['none', 'brief', 'full'] 48 | 49 | # Already cached? 50 | if module_name in _cached_plugins: 51 | return _cached_plugins[module_name] 52 | 53 | # Print status. 54 | if verbosity == 'full': 55 | print(f'Setting up PyTorch plugin "{module_name}"...') 56 | elif verbosity == 'brief': 57 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) 58 | 59 | try: # pylint: disable=too-many-nested-blocks 60 | # Make sure we can find the necessary compiler binaries. 61 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 62 | compiler_bindir = _find_compiler_bindir() 63 | if compiler_bindir is None: 64 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') 65 | os.environ['PATH'] += ';' + compiler_bindir 66 | 67 | # Compile and load. 68 | verbose_build = (verbosity == 'full') 69 | 70 | # Incremental build md5sum trickery. Copies all the input source files 71 | # into a cached build directory under a combined md5 digest of the input 72 | # source files. Copying is done only if the combined digest has changed. 73 | # This keeps input file timestamps and filenames the same as in previous 74 | # extension builds, allowing for fast incremental rebuilds. 75 | # 76 | # This optimization is done only in case all the source files reside in 77 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 78 | # environment variable is set (we take this as a signal that the user 79 | # actually cares about this.) 80 | source_dirs_set = set(os.path.dirname(source) for source in sources) 81 | if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): 82 | all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) 83 | 84 | # Compute a combined hash digest for all source files in the same 85 | # custom op directory (usually .cu, .cpp, .py and .h files). 
86 | hash_md5 = hashlib.md5() 87 | for src in all_source_files: 88 | with open(src, 'rb') as f: 89 | hash_md5.update(f.read()) 90 | build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access 91 | digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) 92 | 93 | if not os.path.isdir(digest_build_dir): 94 | os.makedirs(digest_build_dir, exist_ok=True) 95 | baton = FileBaton(os.path.join(digest_build_dir, 'lock')) 96 | if baton.try_acquire(): 97 | try: 98 | for src in all_source_files: 99 | shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) 100 | finally: 101 | baton.release() 102 | else: 103 | # Someone else is copying source files under the digest dir, 104 | # wait until done and continue. 105 | baton.wait() 106 | digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] 107 | torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, 108 | verbose=verbose_build, sources=digest_sources, **build_kwargs) 109 | else: 110 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 111 | module = importlib.import_module(module_name) 112 | 113 | except: 114 | if verbosity == 'brief': 115 | print('Failed!') 116 | raise 117 | 118 | # Print status and add to cache. 119 | if verbosity == 'full': 120 | print(f'Done setting up PyTorch plugin "{module_name}".') 121 | elif verbosity == 'brief': 122 | print('Done.') 123 | _cached_plugins[module_name] = module 124 | return module 125 | 126 | #---------------------------------------------------------------------------- 127 | -------------------------------------------------------------------------------- /torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 
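// C++ binding for the fused bias + activation op: validates argument shapes and layout, fills bias_act_kernel_params, and launches the kernel returned by choose_bias_act_kernel() (see bias_act.cu).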
8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/cuda/CUDAGuard.h> 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments. 35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel<scalar_t>(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel.
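// One block covers p.loopX * blockSize consecutive elements, so gridSize blocks span all p.sizeX elements.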
84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <c10/util/Half.h> 10 | #include "bias_act.h" 11 | 12 | //------------------------------------------------------------------------ 13 | // Helpers. 14 | 15 | template <class T> struct InternalType; 16 | template <> struct InternalType<double> { typedef double scalar_t; }; 17 | template <> struct InternalType<float> { typedef float scalar_t; }; 18 | template <> struct InternalType<c10::Half> { typedef float scalar_t; }; 19 | 20 | //------------------------------------------------------------------------ 21 | // CUDA kernel. 22 | 23 | template <class T, int A> 24 | __global__ void bias_act_kernel(bias_act_kernel_params p) 25 | { 26 | typedef typename InternalType<T>::scalar_t scalar_t; 27 | int G = p.grad; 28 | scalar_t alpha = (scalar_t)p.alpha; 29 | scalar_t gain = (scalar_t)p.gain; 30 | scalar_t clamp = (scalar_t)p.clamp; 31 | scalar_t one = (scalar_t)1; 32 | scalar_t two = (scalar_t)2; 33 | scalar_t expRange = (scalar_t)80; 34 | scalar_t halfExpRange = (scalar_t)40; 35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 37 | 38 | // Loop over elements. 39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 41 | { 42 | // Load. 43 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 48 | scalar_t yy = (gain != 0) ? yref / gain : 0; 49 | scalar_t y = 0; 50 | 51 | // Apply bias. 52 | ((G == 0) ? x : xref) += b; 53 | 54 | // linear 55 | if (A == 1) 56 | { 57 | if (G == 0) y = x; 58 | if (G == 1) y = x; 59 | } 60 | 61 | // relu 62 | if (A == 2) 63 | { 64 | if (G == 0) y = (x > 0) ? x : 0; 65 | if (G == 1) y = (yy > 0) ? x : 0; 66 | } 67 | 68 | // lrelu 69 | if (A == 3) 70 | { 71 | if (G == 0) y = (x > 0) ? x : x * alpha; 72 | if (G == 1) y = (yy > 0) ? x : x * alpha; 73 | } 74 | 75 | // tanh 76 | if (A == 4) 77 | { 78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ?
one : (c - d) / (c + d); } 79 | if (G == 1) y = x * (one - yy * yy); 80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 81 | } 82 | 83 | // sigmoid 84 | if (A == 5) 85 | { 86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); 87 | if (G == 1) y = x * yy * (one - yy); 88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 89 | } 90 | 91 | // elu 92 | if (A == 6) 93 | { 94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 97 | } 98 | 99 | // selu 100 | if (A == 7) 101 | { 102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); 104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); 105 | } 106 | 107 | // softplus 108 | if (A == 8) 109 | { 110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); 111 | if (G == 1) y = x * (one - exp(-yy)); 112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 113 | } 114 | 115 | // swish 116 | if (A == 9) 117 | { 118 | if (G == 0) 119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 120 | else 121 | { 122 | scalar_t c = exp(xref); 123 | scalar_t d = c + one; 124 | if (G == 1) 125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 126 | else 127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 129 | } 130 | } 131 | 132 | // Apply gain. 133 | y *= gain * dy; 134 | 135 | // Clamp. 136 | if (clamp >= 0) 137 | { 138 | if (G == 0) 139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; 140 | else 141 | y = (yref > -clamp & yref < clamp) ? y : 0; 142 | } 143 | 144 | // Store. 145 | ((T*)p.y)[xi] = (T)y; 146 | } 147 | } 148 | 149 | //------------------------------------------------------------------------ 150 | // CUDA kernel selection. 151 | 152 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p) 153 | { 154 | if (p.act == 1) return (void*)bias_act_kernel<T, 1>; 155 | if (p.act == 2) return (void*)bias_act_kernel<T, 2>; 156 | if (p.act == 3) return (void*)bias_act_kernel<T, 3>; 157 | if (p.act == 4) return (void*)bias_act_kernel<T, 4>; 158 | if (p.act == 5) return (void*)bias_act_kernel<T, 5>; 159 | if (p.act == 6) return (void*)bias_act_kernel<T, 6>; 160 | if (p.act == 7) return (void*)bias_act_kernel<T, 7>; 161 | if (p.act == 8) return (void*)bias_act_kernel<T, 8>; 162 | if (p.act == 9) return (void*)bias_act_kernel<T, 9>; 163 | return NULL; 164 | } 165 | 166 | //------------------------------------------------------------------------ 167 | // Template specializations. 168 | 169 | template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p); 170 | template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p); 171 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p); 172 | 173 | //------------------------------------------------------------------------ 174 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto.
Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom PyTorch ops for efficient bias and activation.""" 10 | 11 | import os 12 | import warnings 13 | import numpy as np 14 | import torch 15 | import dnnlib 16 | import traceback 17 | 18 | from .. import custom_ops 19 | from ..
import misc 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | activation_funcs = { 24 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), 25 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), 26 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), 27 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), 28 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), 29 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), 30 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), 31 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), 32 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), 33 | } 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | _inited = False 38 | _plugin = None 39 | _null_tensor = torch.empty([0]) 40 | 41 | def _init(): 42 | global _inited, _plugin 43 | if not _inited: 44 | _inited = True 45 | sources = ['bias_act.cpp', 'bias_act.cu'] 46 | sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] 47 | try: 48 | _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) 49 | except: 50 | warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) 51 | return _plugin is not None 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): 56 | r"""Fused bias and activation function. 57 | 58 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`, 59 | and scales the result by `gain`. Each of the steps is optional. In most cases, 60 | the fused op is considerably more efficient than performing the same calculation 61 | using standard PyTorch ops. It supports first and second order gradients, 62 | but not third order gradients. 63 | 64 | Args: 65 | x: Input activation tensor. Can be of any shape. 66 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 67 | as `x`. The shape must be known, and it must match the dimension of `x` 68 | corresponding to `dim`. 69 | dim: The dimension in `x` corresponding to the elements of `b`. 70 | The value of `dim` is ignored if `b` is not specified. 71 | act: Name of the activation function to evaluate, or `"linear"` to disable. 72 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. 73 | See `activation_funcs` for a full list. `None` is not allowed. 74 | alpha: Shape parameter for the activation function, or `None` to use the default. 75 | gain: Scaling factor for the output tensor, or `None` to use default. 
76 | See `activation_funcs` for the default scaling of each activation function. 77 | If unsure, consider specifying 1. 78 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable 79 | the clamping (default). 80 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 81 | 82 | Returns: 83 | Tensor of the same shape and datatype as `x`. 84 | """ 85 | assert isinstance(x, torch.Tensor) 86 | assert impl in ['ref', 'cuda'] 87 | if impl == 'cuda' and x.device.type == 'cuda' and _init(): 88 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) 89 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) 90 | 91 | #---------------------------------------------------------------------------- 92 | 93 | @misc.profiled_function 94 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): 95 | """Slow reference implementation of `bias_act()` using standard TensorFlow ops. 96 | """ 97 | assert isinstance(x, torch.Tensor) 98 | assert clamp is None or clamp >= 0 99 | spec = activation_funcs[act] 100 | alpha = float(alpha if alpha is not None else spec.def_alpha) 101 | gain = float(gain if gain is not None else spec.def_gain) 102 | clamp = float(clamp if clamp is not None else -1) 103 | 104 | # Add bias. 105 | if b is not None: 106 | assert isinstance(b, torch.Tensor) and b.ndim == 1 107 | assert 0 <= dim < x.ndim 108 | assert b.shape[0] == x.shape[dim] 109 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) 110 | 111 | # Evaluate activation function. 112 | alpha = float(alpha) 113 | x = spec.func(x, alpha=alpha) 114 | 115 | # Scale by gain. 116 | gain = float(gain) 117 | if gain != 1: 118 | x = x * gain 119 | 120 | # Clamp. 121 | if clamp >= 0: 122 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type 123 | return x 124 | 125 | #---------------------------------------------------------------------------- 126 | 127 | _bias_act_cuda_cache = dict() 128 | 129 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): 130 | """Fast CUDA implementation of `bias_act()` using custom ops. 131 | """ 132 | # Parse arguments. 133 | assert clamp is None or clamp >= 0 134 | spec = activation_funcs[act] 135 | alpha = float(alpha if alpha is not None else spec.def_alpha) 136 | gain = float(gain if gain is not None else spec.def_gain) 137 | clamp = float(clamp if clamp is not None else -1) 138 | 139 | # Lookup from cache. 140 | key = (dim, act, alpha, gain, clamp) 141 | if key in _bias_act_cuda_cache: 142 | return _bias_act_cuda_cache[key] 143 | 144 | # Forward op. 
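# BiasActCuda fuses y = act(x + b) * gain (with optional clamping) into a single kernel call and saves only the tensors required by the chosen activation's gradient formula (spec.ref), keeping activation memory low.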
145 | class BiasActCuda(torch.autograd.Function): 146 | @staticmethod 147 | def forward(ctx, x, b): # pylint: disable=arguments-differ 148 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format 149 | x = x.contiguous(memory_format=ctx.memory_format) 150 | b = b.contiguous() if b is not None else _null_tensor 151 | y = x 152 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: 153 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) 154 | ctx.save_for_backward( 155 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 156 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 157 | y if 'y' in spec.ref else _null_tensor) 158 | return y 159 | 160 | @staticmethod 161 | def backward(ctx, dy): # pylint: disable=arguments-differ 162 | dy = dy.contiguous(memory_format=ctx.memory_format) 163 | x, b, y = ctx.saved_tensors 164 | dx = None 165 | db = None 166 | 167 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 168 | dx = dy 169 | if act != 'linear' or gain != 1 or clamp >= 0: 170 | dx = BiasActCudaGrad.apply(dy, x, b, y) 171 | 172 | if ctx.needs_input_grad[1]: 173 | db = dx.sum([i for i in range(dx.ndim) if i != dim]) 174 | 175 | return dx, db 176 | 177 | # Backward op. 178 | class BiasActCudaGrad(torch.autograd.Function): 179 | @staticmethod 180 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ 181 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format 182 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) 183 | ctx.save_for_backward( 184 | dy if spec.has_2nd_grad else _null_tensor, 185 | x, b, y) 186 | return dx 187 | 188 | @staticmethod 189 | def backward(ctx, d_dx): # pylint: disable=arguments-differ 190 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format) 191 | dy, x, b, y = ctx.saved_tensors 192 | d_dy = None 193 | d_x = None 194 | d_b = None 195 | d_y = None 196 | 197 | if ctx.needs_input_grad[0]: 198 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) 199 | 200 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): 201 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) 202 | 203 | if spec.has_2nd_grad and ctx.needs_input_grad[2]: 204 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) 205 | 206 | return d_dy, d_x, d_b, d_y 207 | 208 | # Add to cache. 209 | _bias_act_cuda_cache[key] = BiasActCuda 210 | return BiasActCuda 211 | 212 | #---------------------------------------------------------------------------- 213 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Custom replacement for `torch.nn.functional.conv2d` that supports 10 | arbitrarily high order gradients with zero performance penalty.""" 11 | 12 | import warnings 13 | import contextlib 14 | import torch 15 | 16 | # pylint: disable=redefined-builtin 17 | # pylint: disable=arguments-differ 18 | # pylint: disable=protected-access 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | enabled = False # Enable the custom op by setting this to true. 23 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 24 | 25 | @contextlib.contextmanager 26 | def no_weight_gradients(): 27 | global weight_gradients_disabled 28 | old = weight_gradients_disabled 29 | weight_gradients_disabled = True 30 | yield 31 | weight_gradients_disabled = old 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 36 | if _should_use_custom_op(input): 37 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) 38 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 39 | 40 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): 41 | if _should_use_custom_op(input): 42 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) 43 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) 44 | 45 | #---------------------------------------------------------------------------- 46 | 47 | def _should_use_custom_op(input): 48 | assert isinstance(input, torch.Tensor) 49 | if (not enabled) or (not torch.backends.cudnn.enabled): 50 | return False 51 | if input.device.type != 'cuda': 52 | return False 53 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): 54 | return True 55 | warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') 56 | return False 57 | 58 | def _tuple_of_ints(xs, ndim): 59 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 60 | assert len(xs) == ndim 61 | assert all(isinstance(x, int) for x in xs) 62 | return xs 63 | 64 | #---------------------------------------------------------------------------- 65 | 66 | _conv2d_gradfix_cache = dict() 67 | 68 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): 69 | # Parse arguments. 70 | ndim = 2 71 | weight_shape = tuple(weight_shape) 72 | stride = _tuple_of_ints(stride, ndim) 73 | padding = _tuple_of_ints(padding, ndim) 74 | output_padding = _tuple_of_ints(output_padding, ndim) 75 | dilation = _tuple_of_ints(dilation, ndim) 76 | 77 | # Lookup from cache. 78 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 79 | if key in _conv2d_gradfix_cache: 80 | return _conv2d_gradfix_cache[key] 81 | 82 | # Validate arguments. 
83 | assert groups >= 1 84 | assert len(weight_shape) == ndim + 2 85 | assert all(stride[i] >= 1 for i in range(ndim)) 86 | assert all(padding[i] >= 0 for i in range(ndim)) 87 | assert all(dilation[i] >= 0 for i in range(ndim)) 88 | if not transpose: 89 | assert all(output_padding[i] == 0 for i in range(ndim)) 90 | else: # transpose 91 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) 92 | 93 | # Helpers. 94 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) 95 | def calc_output_padding(input_shape, output_shape): 96 | if transpose: 97 | return [0, 0] 98 | return [ 99 | input_shape[i + 2] 100 | - (output_shape[i + 2] - 1) * stride[i] 101 | - (1 - 2 * padding[i]) 102 | - dilation[i] * (weight_shape[i + 2] - 1) 103 | for i in range(ndim) 104 | ] 105 | 106 | # Forward & backward. 107 | class Conv2d(torch.autograd.Function): 108 | @staticmethod 109 | def forward(ctx, input, weight, bias): 110 | assert weight.shape == weight_shape 111 | if not transpose: 112 | output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 113 | else: # transpose 114 | output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) 115 | ctx.save_for_backward(input, weight) 116 | return output 117 | 118 | @staticmethod 119 | def backward(ctx, grad_output): 120 | input, weight = ctx.saved_tensors 121 | grad_input = None 122 | grad_weight = None 123 | grad_bias = None 124 | 125 | if ctx.needs_input_grad[0]: 126 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 127 | grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None) 128 | assert grad_input.shape == input.shape 129 | 130 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 131 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 132 | assert grad_weight.shape == weight_shape 133 | 134 | if ctx.needs_input_grad[2]: 135 | grad_bias = grad_output.sum([0, 2, 3]) 136 | 137 | return grad_input, grad_weight, grad_bias 138 | 139 | # Gradient with respect to the weights. 
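# Conv2dGradWeight calls cuDNN's backward-weight kernel directly and defines its own backward(), which is what allows gradients of arbitrary order with respect to the weights.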
140 | class Conv2dGradWeight(torch.autograd.Function): 141 | @staticmethod 142 | def forward(ctx, grad_output, input): 143 | op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') 144 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] 145 | grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) 146 | assert grad_weight.shape == weight_shape 147 | ctx.save_for_backward(grad_output, input) 148 | return grad_weight 149 | 150 | @staticmethod 151 | def backward(ctx, grad2_grad_weight): 152 | grad_output, input = ctx.saved_tensors 153 | grad2_grad_output = None 154 | grad2_input = None 155 | 156 | if ctx.needs_input_grad[0]: 157 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) 158 | assert grad2_grad_output.shape == grad_output.shape 159 | 160 | if ctx.needs_input_grad[1]: 161 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 162 | grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) 163 | assert grad2_input.shape == input.shape 164 | 165 | return grad2_grad_output, grad2_input 166 | 167 | _conv2d_gradfix_cache[key] = Conv2d 168 | return Conv2d 169 | 170 | #---------------------------------------------------------------------------- 171 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """2D convolution with optional up/downsampling.""" 10 | 11 | import torch 12 | 13 | from .. import misc 14 | from . import conv2d_gradfix 15 | from . import upfirdn2d 16 | from .upfirdn2d import _parse_padding 17 | from .upfirdn2d import _get_filter_size 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def _get_weight_shape(w): 22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 23 | shape = [int(sz) for sz in w.shape] 24 | misc.assert_shape(w, shape) 25 | return shape 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 31 | """ 32 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 33 | 34 | # Flip weight if requested. 35 | if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 36 | w = w.flip([2, 3]) 37 | 38 | # Workaround performance pitfall in cuDNN 8.0.5, triggered when using 39 | # 1x1 kernel + memory_format=channels_last + less than 64 channels. 
40 | if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: 41 | if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: 42 | if out_channels <= 4 and groups == 1: 43 | in_shape = x.shape 44 | x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) 45 | x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) 46 | else: 47 | x = x.to(memory_format=torch.contiguous_format) 48 | w = w.to(memory_format=torch.contiguous_format) 49 | x = conv2d_gradfix.conv2d(x, w, groups=groups) 50 | return x.to(memory_format=torch.channels_last) 51 | 52 | # Otherwise => execute using conv2d_gradfix. 53 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 54 | return op(x, w, stride=stride, padding=padding, groups=groups) 55 | 56 | #---------------------------------------------------------------------------- 57 | 58 | @misc.profiled_function 59 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 60 | r"""2D convolution with optional up/downsampling. 61 | 62 | Padding is performed only once at the beginning, not between the operations. 63 | 64 | Args: 65 | x: Input tensor of shape 66 | `[batch_size, in_channels, in_height, in_width]`. 67 | w: Weight tensor of shape 68 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 69 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 70 | calling upfirdn2d.setup_filter(). None = identity (default). 71 | up: Integer upsampling factor (default: 1). 72 | down: Integer downsampling factor (default: 1). 73 | padding: Padding with respect to the upsampled image. Can be a single number 74 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 75 | (default: 0). 76 | groups: Split input channels into N groups (default: 1). 77 | flip_weight: False = convolution, True = correlation (default: True). 78 | flip_filter: False = convolution, True = correlation (default: False). 79 | 80 | Returns: 81 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 82 | """ 83 | # Validate arguments. 84 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 85 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 86 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 87 | assert isinstance(up, int) and (up >= 1) 88 | assert isinstance(down, int) and (down >= 1) 89 | assert isinstance(groups, int) and (groups >= 1) 90 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 91 | fw, fh = _get_filter_size(f) 92 | px0, px1, py0, py1 = _parse_padding(padding) 93 | 94 | # Adjust padding to account for up/downsampling. 95 | if up > 1: 96 | px0 += (fw + up - 1) // 2 97 | px1 += (fw - up) // 2 98 | py0 += (fh + up - 1) // 2 99 | py1 += (fh - up) // 2 100 | if down > 1: 101 | px0 += (fw - down + 1) // 2 102 | px1 += (fw - down) // 2 103 | py0 += (fh - down + 1) // 2 104 | py1 += (fh - down) // 2 105 | 106 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 107 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 108 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 109 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 110 | return x 111 | 112 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 
113 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 114 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 115 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 116 | return x 117 | 118 | # Fast path: downsampling only => use strided convolution. 119 | if down > 1 and up == 1: 120 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 121 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 122 | return x 123 | 124 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 125 | if up > 1: 126 | if groups == 1: 127 | w = w.transpose(0, 1) 128 | else: 129 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 130 | w = w.transpose(1, 2) 131 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 132 | px0 -= kw - 1 133 | px1 -= kw - up 134 | py0 -= kh - 1 135 | py1 -= kh - up 136 | pxt = max(min(-px0, -px1), 0) 137 | pyt = max(min(-py0, -py1), 0) 138 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) 139 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) 140 | if down > 1: 141 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 142 | return x 143 | 144 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 145 | if up == 1 and down == 1: 146 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 147 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) 148 | 149 | # Fallback: Generic reference implementation. 150 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 151 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 152 | if down > 1: 153 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 154 | return x 155 | 156 | #---------------------------------------------------------------------------- 157 | -------------------------------------------------------------------------------- /torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
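# Broadcasting-aware a * b + c; the hand-written backward below un-broadcasts the incoming gradient directly, giving slightly cheaper gradients than autograd through torch.addcmul().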
8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import warnings 15 | import torch 16 | 17 | # pylint: disable=redefined-builtin 18 | # pylint: disable=arguments-differ 19 | # pylint: disable=protected-access 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | enabled = False # Enable the custom op by setting this to true. 
24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | def grid_sample(input, grid): 28 | if _should_use_custom_op(): 29 | return _GridSample2dForward.apply(input, grid) 30 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def _should_use_custom_op(): 35 | if not enabled: 36 | return False 37 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): 38 | return True 39 | warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().') 40 | return False 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | class _GridSample2dForward(torch.autograd.Function): 45 | @staticmethod 46 | def forward(ctx, input, grid): 47 | assert input.ndim == 4 48 | assert grid.ndim == 4 49 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 50 | ctx.save_for_backward(input, grid) 51 | return output 52 | 53 | @staticmethod 54 | def backward(ctx, grad_output): 55 | input, grid = ctx.saved_tensors 56 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 57 | return grad_input, grad_grid 58 | 59 | #---------------------------------------------------------------------------- 60 | 61 | class _GridSample2dBackward(torch.autograd.Function): 62 | @staticmethod 63 | def forward(ctx, grad_output, input, grid): 64 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 65 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 66 | ctx.save_for_backward(grid) 67 | return grad_input, grad_grid 68 | 69 | @staticmethod 70 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 71 | _ = grad2_grad_grid # unused 72 | grid, = ctx.saved_tensors 73 | grad2_grad_output = None 74 | grad2_input = None 75 | grad2_grid = None 76 | 77 | if ctx.needs_input_grad[0]: 78 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 79 | 80 | assert not ctx.needs_input_grad[2] 81 | return grad2_grad_output, grad2_input, grad2_grid 82 | 83 | #---------------------------------------------------------------------------- 84 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/cuda/CUDAGuard.h> 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments.
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 25 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 26 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 27 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 28 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 29 | 30 | // Create output tensor. 31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 32 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 33 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 34 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 35 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 36 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 37 | 38 | // Initialize CUDA kernel parameters. 39 | upfirdn2d_kernel_params p; 40 | p.x = x.data_ptr(); 41 | p.f = f.data_ptr<float>(); 42 | p.y = y.data_ptr(); 43 | p.up = make_int2(upx, upy); 44 | p.down = make_int2(downx, downy); 45 | p.pad0 = make_int2(padx0, pady0); 46 | p.flip = (flip) ? 1 : 0; 47 | p.gain = gain; 48 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 49 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 50 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 51 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 52 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 53 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 54 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 55 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 56 | 57 | // Choose CUDA kernel. 58 | upfirdn2d_kernel_spec spec; 59 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 60 | { 61 | spec = choose_upfirdn2d_kernel<scalar_t>(p); 62 | }); 63 | 64 | // Set looping options. 65 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 66 | p.loopMinor = spec.loopMinor; 67 | p.loopX = spec.loopX; 68 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 69 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 70 | 71 | // Compute grid size. 72 | dim3 blockSize, gridSize; 73 | if (spec.tileOutW < 0) // large 74 | { 75 | blockSize = dim3(4, 32, 1); 76 | gridSize = dim3( 77 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 78 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 79 | p.launchMajor); 80 | } 81 | else // small 82 | { 83 | blockSize = dim3(256, 1, 1); 84 | gridSize = dim3( 85 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 86 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 87 | p.launchMajor); 88 | } 89 | 90 | // Launch CUDA kernel.
91 | void* args[] = {&p}; 92 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 93 | return y; 94 | } 95 | 96 | //------------------------------------------------------------------------ 97 | 98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 99 | { 100 | m.def("upfirdn2d", &upfirdn2d); 101 | } 102 | 103 | //------------------------------------------------------------------------ 104 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <cuda_runtime.h> 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for pickling Python code alongside other data. 10 | 11 | The pickled code is automatically imported into a separate Python module 12 | during unpickling.
/torch_utils/persistence.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 
9 | """Facilities for pickling Python code alongside other data.
10 | 
11 | The pickled code is automatically imported into a separate Python module
12 | during unpickling. This way, any previously exported pickles will remain
13 | usable even if the original code is no longer available, or if the current
14 | version of the code is not consistent with what was originally pickled."""
15 | 
16 | import sys
17 | import pickle
18 | import io
19 | import inspect
20 | import copy
21 | import uuid
22 | import types
23 | import dnnlib
24 | 
25 | #----------------------------------------------------------------------------
26 | 
27 | _version = 6 # internal version number
28 | _decorators = set() # {decorator_class, ...}
29 | _import_hooks = [] # [hook_function, ...]
30 | _module_to_src_dict = dict() # {module: src, ...}
31 | _src_to_module_dict = dict() # {src: module, ...}
32 | 
33 | #----------------------------------------------------------------------------
34 | 
35 | def persistent_class(orig_class):
36 |     r"""Class decorator that extends a given class to save its source code
37 |     when pickled.
38 | 
39 |     Example:
40 | 
41 |         from torch_utils import persistence
42 | 
43 |         @persistence.persistent_class
44 |         class MyNetwork(torch.nn.Module):
45 |             def __init__(self, num_inputs, num_outputs):
46 |                 super().__init__()
47 |                 self.fc = MyLayer(num_inputs, num_outputs)
48 |                 ...
49 | 
50 |         @persistence.persistent_class
51 |         class MyLayer(torch.nn.Module):
52 |             ...
53 | 
54 |     When pickled, any instance of `MyNetwork` and `MyLayer` will save its
55 |     source code alongside other internal state (e.g., parameters, buffers,
56 |     and submodules). This way, any previously exported pickle will remain
57 |     usable even if the class definitions have been modified or are no
58 |     longer available.
59 | 
60 |     The decorator saves the source code of the entire Python module
61 |     containing the decorated class. It does *not* save the source code of
62 |     any imported modules. Thus, the imported modules must be available
63 |     during unpickling, including `torch_utils.persistence` itself.
64 | 
65 |     It is OK to call functions defined in the same module from the
66 |     decorated class. However, if the decorated class depends on other
67 |     classes defined in the same module, they must be decorated as well.
68 |     This is illustrated in the above example in the case of `MyLayer`.
69 | 
70 |     It is also possible to employ the decorator just-in-time before
71 |     calling the constructor. For example:
72 | 
73 |         cls = MyLayer
74 |         if want_to_make_it_persistent:
75 |             cls = persistence.persistent_class(cls)
76 |         layer = cls(num_inputs, num_outputs)
77 | 
78 |     As an additional feature, the decorator also keeps track of the
79 |     arguments that were used to construct each instance of the decorated
80 |     class. The arguments can be queried via `obj.init_args` and
81 |     `obj.init_kwargs`, and they are automatically pickled alongside other
82 |     object state. A typical use case is to first unpickle a previous
83 |     instance of a persistent class, and then upgrade it to use the latest
84 |     version of the source code:
85 | 
86 |         with open('old_pickle.pkl', 'rb') as f:
87 |             old_net = pickle.load(f)
88 |         new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
89 |         misc.copy_params_and_buffers(old_net, new_net, require_all=True)
90 |     """
91 |     assert isinstance(orig_class, type)
92 |     if is_persistent(orig_class):
93 |         return orig_class
94 | 
95 |     assert orig_class.__module__ in sys.modules
96 |     orig_module = sys.modules[orig_class.__module__]
97 |     orig_module_src = _module_to_src(orig_module)
98 | 
99 |     class Decorator(orig_class):
100 |         _orig_module_src = orig_module_src
101 |         _orig_class_name = orig_class.__name__
102 | 
103 |         def __init__(self, *args, **kwargs):
104 |             super().__init__(*args, **kwargs)
105 |             self._init_args = copy.deepcopy(args)
106 |             self._init_kwargs = copy.deepcopy(kwargs)
107 |             assert orig_class.__name__ in orig_module.__dict__
108 |             _check_pickleable(self.__reduce__())
109 | 
110 |         @property
111 |         def init_args(self):
112 |             return copy.deepcopy(self._init_args)
113 | 
114 |         @property
115 |         def init_kwargs(self):
116 |             return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
117 | 
118 |         def __reduce__(self):
119 |             fields = list(super().__reduce__())
120 |             fields += [None] * max(3 - len(fields), 0)
121 |             if fields[0] is not _reconstruct_persistent_obj:
122 |                 meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
123 |                 fields[0] = _reconstruct_persistent_obj # reconstruct func
124 |                 fields[1] = (meta,) # reconstruct args
125 |                 fields[2] = None # state dict
126 |             return tuple(fields)
127 | 
128 |     Decorator.__name__ = orig_class.__name__
129 |     _decorators.add(Decorator)
130 |     return Decorator
131 | 
132 | #----------------------------------------------------------------------------
133 | 
134 | def is_persistent(obj):
135 |     r"""Test whether the given object or class is persistent, i.e.,
136 |     whether it will save its source code when pickled.
137 |     """
138 |     try:
139 |         if obj in _decorators:
140 |             return True
141 |     except TypeError:
142 |         pass
143 |     return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
144 | 
145 | #----------------------------------------------------------------------------
146 | 
147 | def import_hook(hook):
148 |     r"""Register an import hook that is called whenever a persistent object
149 |     is being unpickled. A typical use case is to patch the pickled source
150 |     code to avoid errors and inconsistencies when the API of some imported
151 |     module has changed.
152 | 
153 |     The hook should have the following signature:
154 | 
155 |         hook(meta) -> modified meta
156 | 
157 |     `meta` is an instance of `dnnlib.EasyDict` with the following fields:
158 | 
159 |         type:       Type of the persistent object, e.g. `'class'`.
160 |         version:    Internal version number of `torch_utils.persistence`.
161 |         module_src: Original source code of the Python module.
162 |         class_name: Class name in the original Python module.
163 |         state:      Internal state of the object.
164 | 
165 |     Example:
166 | 
167 |         @persistence.import_hook
168 |         def wreck_my_network(meta):
169 |             if meta.class_name == 'MyNetwork':
170 |                 print('MyNetwork is being imported. I will wreck it!')
171 |                 meta.module_src = meta.module_src.replace("True", "False")
172 |             return meta
173 |     """
174 |     assert callable(hook)
175 |     _import_hooks.append(hook)
176 | 
177 | #----------------------------------------------------------------------------
178 | 
179 | def _reconstruct_persistent_obj(meta):
180 |     r"""Hook that is called internally by the `pickle` module to unpickle
181 |     a persistent object.
182 |     """
183 |     meta = dnnlib.EasyDict(meta)
184 |     meta.state = dnnlib.EasyDict(meta.state)
185 |     for hook in _import_hooks:
186 |         meta = hook(meta)
187 |         assert meta is not None
188 | 
189 |     assert meta.version == _version
190 |     module = _src_to_module(meta.module_src)
191 | 
192 |     assert meta.type == 'class'
193 |     orig_class = module.__dict__[meta.class_name]
194 |     decorator_class = persistent_class(orig_class)
195 |     obj = decorator_class.__new__(decorator_class)
196 | 
197 |     setstate = getattr(obj, '__setstate__', None)
198 |     if callable(setstate):
199 |         setstate(meta.state) # pylint: disable=not-callable
200 |     else:
201 |         obj.__dict__.update(meta.state)
202 |     return obj
203 | 
204 | #----------------------------------------------------------------------------
205 | 
206 | def _module_to_src(module):
207 |     r"""Query the source code of a given Python module.
208 |     """
209 |     src = _module_to_src_dict.get(module, None)
210 |     if src is None:
211 |         src = inspect.getsource(module)
212 |         _module_to_src_dict[module] = src
213 |         _src_to_module_dict[src] = module
214 |     return src
215 | 
216 | def _src_to_module(src):
217 |     r"""Get or create a Python module for the given source code.
218 |     """
219 |     module = _src_to_module_dict.get(src, None)
220 |     if module is None:
221 |         module_name = "_imported_module_" + uuid.uuid4().hex
222 |         module = types.ModuleType(module_name)
223 |         sys.modules[module_name] = module
224 |         _module_to_src_dict[module] = src
225 |         _src_to_module_dict[src] = module
226 |         exec(src, module.__dict__) # pylint: disable=exec-used
227 |     return module
228 | 
229 | #----------------------------------------------------------------------------
230 | 
231 | def _check_pickleable(obj):
232 |     r"""Check that the given object is pickleable, raising an exception if
233 |     it is not. This function is expected to be considerably more efficient
234 |     than actually pickling the object.
235 |     """
236 |     def recurse(obj):
237 |         if isinstance(obj, (list, tuple, set)):
238 |             return [recurse(x) for x in obj]
239 |         if isinstance(obj, dict):
240 |             return [[recurse(x), recurse(y)] for x, y in obj.items()]
241 |         if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
242 |             return None # Python primitive types are pickleable.
243 |         if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
244 |             return None # NumPy arrays and PyTorch tensors are pickleable.
245 |         if is_persistent(obj):
246 |             return None # Persistent objects are pickleable, by virtue of the constructor check.
247 |         return obj
248 |     with io.BytesIO() as f:
249 |         pickle.dump(recurse(obj), f)
250 | 
251 | #----------------------------------------------------------------------------
252 | 
--------------------------------------------------------------------------------
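Taken together, `persistent_class`, `_module_to_src`, and `_reconstruct_persistent_obj` give a pickle round trip that embeds source code. A minimal sketch of the intended usage (not part of the repository; `TinyNet` is hypothetical, and the snippet assumes it is defined in an importable script file so that `inspect.getsource` can retrieve its module):

```python
import pickle
import torch
from torch_utils import persistence

@persistence.persistent_class
class TinyNet(torch.nn.Module):  # hypothetical example class
    def __init__(self, num_inputs, num_outputs):
        super().__init__()
        self.fc = torch.nn.Linear(num_inputs, num_outputs)

net = TinyNet(8, 2)
blob = pickle.dumps(net)       # module source is embedded alongside the state
restored = pickle.loads(blob)  # still loads if TinyNet is later renamed or edited
print(restored.init_args)      # (8, 2) -- constructor arguments were recorded
```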
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .logger import Logger
2 | from .tools import seed_torch, format_time, weights_init_normal, sample_image_grid
3 | from .parser import parser_train
4 | from .train import Trainer
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | 
4 | class Logger(object):
5 |     """
6 |     Helper class for logging.
7 |     Arguments:
8 |         path (str): Path to log file.
9 |     """
10 |     def __init__(self, path):
11 |         self.logger = logging.getLogger()
12 |         self.path = path
13 |         self.setup_file_logger()
14 |         print('Logging to file: ', self.path)
15 | 
16 |     def setup_file_logger(self):
17 |         hdlr = logging.FileHandler(self.path, 'w+')
18 |         self.logger.addHandler(hdlr)
19 |         self.logger.setLevel(logging.INFO)
20 | 
21 |     def log(self, message):
22 |         print(message)
23 |         self.logger.info(message)
24 | 
--------------------------------------------------------------------------------
/utils/parser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | from data import DATASETS
4 | from models import MODELS
5 | from .train import SCHEDULERS
6 | 
7 | from .tools import str2bool, str2float
8 | 
9 | 
10 | def parser_train():
11 |     """
12 |     Parse input arguments (train.py).
13 |     """
14 |     parser = argparse.ArgumentParser(description='Standard Training.')
15 | 
16 |     parser.add_argument('--augment', type=str, default='none', choices=['none', 'base', 'cutout', 'autoaugment', 'randaugment', 'idbh'], help='Augment training set.')
17 | 
18 |     parser.add_argument('--batch-size', type=int, default=512, help='Batch size for training.')
19 |     parser.add_argument('--batch-size-validation', type=int, default=256, help='Batch size for testing.')
20 | 
21 |     parser.add_argument('--data-dir', type=str, default='./datasets')
22 |     parser.add_argument('--log-dir', type=str, default='./log')
23 | 
24 |     parser.add_argument('--data', type=str, default='cifar10s', choices=DATASETS, help='Data to use.')
25 |     parser.add_argument('--desc', type=str, required=True,
26 |                         help='Description of experiment. It will be used to name directories.')
27 | 
28 |     parser.add_argument('--model', choices=MODELS, default='resnet18', help='Model architecture to be used.')
29 |     parser.add_argument('--normalize', type=str2bool, default=False, help='Normalize input.')
30 |     parser.add_argument('--pretrained-file', type=str, default=None, help='Pretrained weights file name.')
31 | 
32 |     parser.add_argument('--num-epochs', type=int, default=100, help='Number of training epochs.')
33 |     parser.add_argument('--eval-freq', type=int, default=10, help='Evaluation frequency (in epochs).')
34 | 
35 |     parser.add_argument('--beta', default=None, type=float, help='Stability regularization, i.e., 1/lambda in TRADES.')
36 | 
37 |     parser.add_argument('--lr', type=float, default=0.2, help='Learning rate for optimizer (SGD).')
38 |     parser.add_argument('--weight-decay', type=float, default=5e-4, help='Optimizer (SGD) weight decay.')
39 |     parser.add_argument('--scheduler', choices=SCHEDULERS, default='cosinew', help='Type of scheduler.')
40 |     parser.add_argument('--nesterov', type=str2bool, default=False, help='Use Nesterov momentum.')
41 |     parser.add_argument('--clip-grad', type=float, default=None, help='Gradient norm clipping.')
42 | 
43 |     parser.add_argument('--debug', action='store_true', default=False,
44 |                         help='Debug code. Run 1 epoch of training and evaluation.')
45 | 
46 |     parser.add_argument('--unsup-fraction', type=float, default=0.7, help='Fraction of each training batch drawn from the auxiliary (generated) data.')
47 |     parser.add_argument('--aux-data-filename', type=str, help='Path to auxiliary (generated) data in .npz format.',
48 |                         default=None)
49 |     parser.add_argument('--aux-take-amount', type=int, default=None, help='Number of auxiliary samples to use (m_G).')
50 | 
51 |     parser.add_argument('--seed', type=int, default=1, help='Random seed.')
52 | 
53 |     ### Consistency
54 |     parser.add_argument('--consistency', action='store_true', default=False, help='Use consistency regularization.')
55 |     parser.add_argument('--cons_lambda', type=float, default=1.0, help='Lambda for consistency regularization.')
56 |     parser.add_argument('--cons_tem', type=float, default=0.5, help='Temperature for consistency regularization.')
57 | 
58 |     ### Resume
59 |     parser.add_argument('--resume_path', default='', type=str)
60 | 
61 |     ### Our methods
62 |     parser.add_argument('--LSE', action='store_true', default=False, help='LSE training.')
63 |     parser.add_argument('--ls', type=float, default=0., help='Label smoothing.')
64 |     parser.add_argument('--clip_value', default=0, type=float)
65 |     parser.add_argument('--CutMix', action='store_true', default=False, help='Use CutMix.')
66 |     return parser
67 | 
68 | 
69 | def parser_eval():
70 |     """
71 |     Parse input arguments (eval-adv.py, eval-corr.py, eval-aa.py).
72 |     """
73 |     parser = argparse.ArgumentParser(description='Robustness evaluation.')
74 | 
75 |     parser.add_argument('--data-dir', type=str, default='/cluster/home/rarade/data/')
76 |     parser.add_argument('--log-dir', type=str, default='/cluster/scratch/rarade/test/')
77 | 
78 |     parser.add_argument('--desc', type=str, required=True, help='Description of model to be evaluated.')
79 |     parser.add_argument('--num-samples', type=int, default=1000, help='Number of test samples.')
80 | 
81 |     # eval-aa.py
82 |     parser.add_argument('--train', action='store_true', default=False, help='Evaluate on training set.')
83 |     parser.add_argument('-v', '--version', type=str, default='standard', choices=['custom', 'plus', 'standard'],
84 |                         help='Version of AA.')
85 | 
86 |     # eval-adv.py
87 |     parser.add_argument('--source', type=str, default=None, help='Path to source model for black-box evaluation.')
88 |     parser.add_argument('--wb', action='store_true', default=False, help='Perform white-box PGD evaluation.')
89 | 
90 |     # eval-rb.py
91 |     parser.add_argument('--threat', type=str, default='corruptions', choices=['corruptions', 'Linf', 'L2'],
92 |                         help='Threat model for RobustBench evaluation.')
93 | 
94 |     parser.add_argument('--seed', type=int, default=1, help='Random seed.')
95 |     return parser
96 | 
97 | 
--------------------------------------------------------------------------------
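For orientation, a sketch of how `parser_train` feeds an entry point such as `main_train_aug.py` (illustrative only; the explicit argument list stands in for a real command line):

```python
from utils import parser_train

parser = parser_train()
args = parser.parse_args([
    '--desc', 'RN18_cifar10s_demo',  # required; used to name the log directory
    '--data', 'cifar10s',
    '--model', 'resnet18',
    '--aux-take-amount', '1000000',  # number of generated samples (m_G)
])
print(args.lr, args.scheduler, args.unsup_fraction)  # defaults: 0.2 cosinew 0.7
```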
72 | """ 73 | parser = argparse.ArgumentParser(description='Robustness evaluation.') 74 | 75 | parser.add_argument('--data-dir', type=str, default='/cluster/home/rarade/data/') 76 | parser.add_argument('--log-dir', type=str, default='/cluster/scratch/rarade/test/') 77 | 78 | parser.add_argument('--desc', type=str, required=True, help='Description of model to be evaluated.') 79 | parser.add_argument('--num-samples', type=int, default=1000, help='Number of test samples.') 80 | 81 | # eval-aa.py 82 | parser.add_argument('--train', action='store_true', default=False, help='Evaluate on training set.') 83 | parser.add_argument('-v', '--version', type=str, default='standard', choices=['custom', 'plus', 'standard'], 84 | help='Version of AA.') 85 | 86 | # eval-adv.py 87 | parser.add_argument('--source', type=str, default=None, help='Path to source model for black-box evaluation.') 88 | parser.add_argument('--wb', action='store_true', default=False, help='Perform white-box PGD evaluation.') 89 | 90 | # eval-rb.py 91 | parser.add_argument('--threat', type=str, default='corruptions', choices=['corruptions', 'Linf', 'L2'], 92 | help='Threat model for RobustBench evaluation.') 93 | 94 | parser.add_argument('--seed', type=int, default=1, help='Random seed.') 95 | return parser 96 | 97 | -------------------------------------------------------------------------------- /utils/rst.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | 5 | 6 | class CosineLR(torch.optim.lr_scheduler._LRScheduler): 7 | """ 8 | Cosine annealing LR schedule (used in Carmon et al, 2019). 9 | """ 10 | def __init__(self, optimizer, max_lr, epochs, last_epoch=-1): 11 | self.max_lr = max_lr 12 | self.epochs = epochs 13 | self._reset() 14 | super(CosineLR, self).__init__(optimizer, last_epoch) 15 | 16 | def _reset(self): 17 | self.current_lr = self.max_lr 18 | self.current_epoch = 1 19 | 20 | def step(self): 21 | self.current_lr = self.max_lr * 0.5 * (1 + np.cos((self.current_epoch - 1) / self.epochs * np.pi)) 22 | for param_group in self.optimizer.param_groups: 23 | param_group['lr'] = self.current_lr 24 | self.current_epoch += 1 25 | self._last_lr = [group['lr'] for group in self.optimizer.param_groups] 26 | 27 | def get_lr(self): 28 | return self.current_lr 29 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | import datetime 5 | import random 6 | import numpy as np 7 | import torch 8 | from scipy import interpolate 9 | 10 | from tqdm import tqdm 11 | import argparse 12 | 13 | from torchvision.utils import save_image 14 | 15 | def format_time(elapsed): 16 | """ 17 | Format time for displaying. 18 | Arguments: 19 | elapsed: time interval in seconds. 
20 | """ 21 | elapsed_rounded = int(round((elapsed))) 22 | return str(datetime.timedelta(seconds=elapsed_rounded)) 23 | 24 | 25 | def get_console_file_logger(name, level=logging.INFO, logdir='./baseline'): 26 | logger = logging.Logger(name) 27 | logger.setLevel(level=level) 28 | logger.handlers = [] 29 | BASIC_FORMAT = "%(asctime)s, %(levelname)s:%(name)s:%(message)s" 30 | DATE_FORMAT = '%Y-%m-%d %H:%M:%S' 31 | formatter = logging.Formatter(BASIC_FORMAT, DATE_FORMAT) 32 | chlr = logging.StreamHandler() 33 | chlr.setFormatter(formatter) 34 | chlr.setLevel(level=level) 35 | 36 | fhlr = logging.FileHandler(os.path.join(logdir, str(time.time()) + '.log')) 37 | fhlr.setFormatter(formatter) 38 | logger.addHandler(chlr) 39 | logger.addHandler(fhlr) 40 | 41 | return logger 42 | 43 | def adjust_learning_rate(optimizer, epoch, args): 44 | """Decay the learning rate based on schedule""" 45 | lr = args.lr 46 | for milestone in args.schedule: 47 | lr *= 0.1 if epoch >= milestone else 1. 48 | for param_group in optimizer.param_groups: 49 | param_group['lr'] = lr 50 | 51 | def seed_torch(seed=2333): 52 | random.seed(seed) 53 | os.environ['PYTHONHASHSEED'] = str(seed) 54 | # np.random.seed(seed) 55 | torch.manual_seed(seed) 56 | torch.cuda.manual_seed(seed) 57 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 58 | torch.backends.cudnn.benchmark = True 59 | torch.backends.cudnn.enabled = True 60 | torch.backends.cudnn.deterministic = False 61 | 62 | def weights_init_normal(m): 63 | classname = m.__class__.__name__ 64 | if classname.find("Conv") != -1: 65 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 66 | elif classname.find("BatchNorm2d") != -1: 67 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 68 | torch.nn.init.constant_(m.bias.data, 0.0) 69 | 70 | def sample_image_grid(generator, z, n_row, batches_done, img_dir): 71 | """Saves a grid of generated digits ranging from 0 to n_classes""" 72 | # Get labels ranging from 0 to n_classes for n rows 73 | labels = np.array([num for _ in range(n_row) for num in range(n_row)]) 74 | labels = torch.from_numpy(labels).cuda() 75 | # generator.eval() 76 | with torch.no_grad(): 77 | gen_imgs = generator(z, labels) 78 | # generator.train() 79 | gen_imgs = gen_imgs * 0.5 + 0.5 80 | image_name = "%d.png" % batches_done 81 | image_path = os.path.join(img_dir, image_name) 82 | save_image(gen_imgs.data, image_path, nrow=n_row, normalize=True) 83 | 84 | def str2bool(v): 85 | """ 86 | Parse boolean using argument parser. 87 | """ 88 | if isinstance(v, bool): 89 | return v 90 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 91 | return True 92 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 93 | return False 94 | else: 95 | raise argparse.ArgumentTypeError('Boolean value expected.') 96 | 97 | def str2float(x): 98 | """ 99 | Parse float and fractions using argument parser. 100 | """ 101 | if '/' in x: 102 | n, d = x.split('/') 103 | return float(n)/float(d) 104 | else: 105 | try: 106 | return float(x) 107 | except: 108 | raise argparse.ArgumentTypeError('Fraction or float value expected.') 109 | 110 | def accuracy(true, preds): 111 | """ 112 | Computes multi-class accuracy. 113 | Arguments: 114 | true (torch.Tensor): true labels. 115 | preds (torch.Tensor): predicted labels. 116 | Returns: 117 | Multi-class accuracy. 
118 | """ 119 | accuracy = (torch.softmax(preds, dim=1).argmax(dim=1) == true).sum().float()/float(true.size(0)) 120 | return accuracy.item() 121 | 122 | def interpolate_pos_embed(model, checkpoint_model): 123 | if 'pos_embed' in checkpoint_model: 124 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 125 | embedding_size = pos_embed_checkpoint.shape[-1] 126 | num_patches = model.patch_embed.num_patches 127 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 128 | # height (== width) for the checkpoint position embedding 129 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 130 | # height (== width) for the new position embedding 131 | new_size = int(num_patches ** 0.5) 132 | # class_token and dist_token are kept unchanged 133 | if orig_size != new_size: 134 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 135 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 136 | # only the position tokens are interpolated 137 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 138 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 139 | pos_tokens = torch.nn.functional.interpolate( 140 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 141 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 142 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 143 | checkpoint_model['pos_embed'] = new_pos_embed 144 | 145 | def load_pretrained(model, args): 146 | print(f">>>>>>>>>> Linear eval from {args.backbone_path} ..........") 147 | checkpoint = torch.load(args.backbone_path, map_location='cpu') 148 | checkpoint_model = checkpoint['model'] 149 | 150 | if any([True if 'encoder.' in k else False for k in checkpoint_model.keys()]): 151 | checkpoint_model = {k.replace('encoder.', ''): v for k, v in checkpoint_model.items() if k.startswith('encoder.')} 152 | print('Detect pre-trained model, remove [encoder.] 
153 |     else:
154 |         print('Detected a non-pre-trained checkpoint; passing it through unchanged.')
155 | 
156 |     assert args.backbone == 'simmim'
157 |     print(f">>>>>>>>>> Remapping pre-trained keys for VIT ..........")
158 |     checkpoint = remap_pretrained_keys_vit(model, checkpoint_model)
159 | 
160 |     msg = model.load_state_dict(checkpoint_model, strict=False)
161 |     print(msg)
162 | 
163 |     del checkpoint
164 |     torch.cuda.empty_cache()
165 |     print(f">>>>>>>>>> loaded successfully from '{args.backbone_path}'")
166 | 
167 | def remap_pretrained_keys_vit(model, checkpoint_model):
168 |     # Duplicate shared rel_pos_bias to each layer
169 |     if getattr(model, 'use_rel_pos_bias', False) and "rel_pos_bias.relative_position_bias_table" in checkpoint_model:
170 |         print("Expand the shared relative position embedding to each transformer block.")
171 |         num_layers = model.get_num_layers()
172 |         rel_pos_bias = checkpoint_model["rel_pos_bias.relative_position_bias_table"]
173 |         for i in range(num_layers):
174 |             checkpoint_model["blocks.%d.attn.relative_position_bias_table" % i] = rel_pos_bias.clone()
175 |         checkpoint_model.pop("rel_pos_bias.relative_position_bias_table")
176 | 
177 |     # Geometric interpolation when the pre-trained patch size mismatches the fine-tuned patch size
178 |     all_keys = list(checkpoint_model.keys())
179 |     for key in all_keys:
180 |         if "relative_position_index" in key:
181 |             checkpoint_model.pop(key)
182 | 
183 |         if "relative_position_bias_table" in key:
184 |             rel_pos_bias = checkpoint_model[key]
185 |             src_num_pos, num_attn_heads = rel_pos_bias.size()
186 |             dst_num_pos, _ = model.state_dict()[key].size()
187 |             dst_patch_shape = model.patch_embed.patch_shape
188 |             if dst_patch_shape[0] != dst_patch_shape[1]:
189 |                 raise NotImplementedError()
190 |             num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
191 |             src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
192 |             dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
193 |             if src_size != dst_size:
194 |                 print("Position interpolate for %s from %dx%d to %dx%d" % (key, src_size, src_size, dst_size, dst_size))
195 |                 extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
196 |                 rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
197 | 
198 |                 def geometric_progression(a, r, n):
199 |                     return a * (1.0 - r ** n) / (1.0 - r)
200 | 
201 |                 left, right = 1.01, 1.5
202 |                 while right - left > 1e-6:
203 |                     q = (left + right) / 2.0
204 |                     gp = geometric_progression(1, q, src_size // 2)
205 |                     if gp > dst_size // 2:
206 |                         right = q
207 |                     else:
208 |                         left = q
209 | 
210 |                 # if q > 1.090307:
211 |                 #     q = 1.090307
212 | 
213 |                 dis = []
214 |                 cur = 1
215 |                 for i in range(src_size // 2):
216 |                     dis.append(cur)
217 |                     cur += q ** (i + 1)
218 | 
219 |                 r_ids = [-_ for _ in reversed(dis)]
220 | 
221 |                 x = r_ids + [0] + dis
222 |                 y = r_ids + [0] + dis
223 | 
224 |                 t = dst_size // 2.0
225 |                 dx = np.arange(-t, t + 0.1, 1.0)
226 |                 dy = np.arange(-t, t + 0.1, 1.0)
227 | 
228 |                 print("Original positions = %s" % str(x))
229 |                 print("Target positions = %s" % str(dx))
230 | 
231 |                 all_rel_pos_bias = []
232 | 
233 |                 for i in range(num_attn_heads):
234 |                     z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
235 |                     f = interpolate.interp2d(x, y, z, kind='cubic')
236 |                     all_rel_pos_bias.append(
237 |                         torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
238 | 
239 |                 rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
240 | 
241 |                 new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
242 |                 checkpoint_model[key] = new_rel_pos_bias
243 | 
244 |     return checkpoint_model
--------------------------------------------------------------------------------
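The bisection inside `remap_pretrained_keys_vit` searches for the geometric ratio q such that `src_size // 2` geometrically spaced coordinates reach out to the target half-width `dst_size // 2`; the bias table is then resampled at those coordinates. An isolated sketch of the q-search (the helper and the sizes are hypothetical, not from the repository):

```python
def solve_ratio(src_half, dst_half, left=1.01, right=1.5, tol=1e-6):
    # sum of the geometric progression 1 + q + ... + q**(n - 1)
    def gp_sum(q, n):
        return (1.0 - q ** n) / (1.0 - q)
    while right - left > tol:
        q = (left + right) / 2.0
        if gp_sum(q, src_half) > dst_half:
            right = q  # spacing overshoots the target half-width; shrink q
        else:
            left = q   # spacing falls short; grow q
    return (left + right) / 2.0

# e.g. remapping a 7x7 bias grid (half-width 3) onto 9x9 (half-width 4):
print(round(solve_ratio(src_half=3, dst_half=4), 4))  # ~1.3028, since 1 + q + q**2 = 4
```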
/utils/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from tqdm import tqdm as tqdm
4 | 
5 | import os
6 | import json
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | 
11 | from .tools import accuracy
12 | from models import create_model
13 | 
14 | # from .context import ctx_noparamgrad_and_eval
15 | from .tools import seed_torch
16 | 
17 | # from .mart import mart_loss
18 | from .rst import CosineLR
19 | # from .trades import trades_loss
20 | 
21 | 
22 | 
23 | 
24 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25 | 
26 | SCHEDULERS = ['cyclic', 'step', 'cosine', 'cosinew']
27 | 
28 | 
29 | class Trainer(object):
30 |     """
31 |     Helper class for training a deep neural network.
32 |     Arguments:
33 |         info (dict): dataset information.
34 |         args (argparse.Namespace): input arguments.
35 |     """
36 |     def __init__(self, info, args):
37 |         super(Trainer, self).__init__()
38 | 
39 |         seed_torch(args.seed)
40 |         self.model = create_model(args.model, args.normalize, info, device)
41 | 
42 |         self.params = args
43 |         self.criterion = nn.CrossEntropyLoss()
44 |         self.init_optimizer(self.params.num_epochs)
45 | 
46 |         if self.params.pretrained_file is not None:
47 |             self.load_model(os.path.join(self.params.log_dir, self.params.pretrained_file, 'weights-best.pt'))
48 | 
49 |         # self.attack, self.eval_attack = self.init_attack(self.model, self.criterion, self.params.attack, self.params.attack_eps,
50 |         #                                                  self.params.attack_iter, self.params.attack_step)
51 | 
52 | 
53 |     # @staticmethod
54 |     # def init_attack(model, criterion, attack_type, attack_eps, attack_iter, attack_step):
55 |     #     """
56 |     #     Initialize adversary.
57 |     #     """
58 |     #     attack = create_attack(model, criterion, attack_type, attack_eps, attack_iter, attack_step, rand_init_type='uniform')
59 |     #     if attack_type in ['linf-pgd', 'l2-pgd']:
60 |     #         eval_attack = create_attack(model, criterion, attack_type, attack_eps, 2*attack_iter, attack_step)
61 |     #     elif attack_type in ['fgsm', 'linf-df']:
62 |     #         eval_attack = create_attack(model, criterion, 'linf-pgd', 8/255, 20, 2/255)
63 |     #     elif attack_type in ['fgm', 'l2-df']:
64 |     #         eval_attack = create_attack(model, criterion, 'l2-pgd', 128/255, 20, 15/255)
65 |     #     return attack, eval_attack
66 | 
67 | 
68 |     def init_optimizer(self, num_epochs):
69 |         """
70 |         Initialize optimizer and scheduler.
71 |         """
72 |         self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay,
73 |                                          momentum=0.9, nesterov=self.params.nesterov)
74 |         if num_epochs <= 0:
75 |             return
76 |         self.init_scheduler(num_epochs)
77 | 
78 | 
79 |     def init_scheduler(self, num_epochs):
80 |         """
81 |         Initialize scheduler.
82 | """ 83 | if self.params.scheduler == 'cyclic': 84 | num_samples = 50000 if 'cifar10' in self.params.data else 73257 85 | num_samples = 100000 if 'tiny-imagenet' in self.params.data else num_samples 86 | update_steps = int(np.floor(num_samples/self.params.batch_size) + 1) 87 | self.scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optimizer, max_lr=self.params.lr, pct_start=0.25, 88 | steps_per_epoch=update_steps, epochs=int(num_epochs)) 89 | elif self.params.scheduler == 'step': 90 | self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, gamma=0.1, milestones=[100, 105]) 91 | elif self.params.scheduler == 'cosine': 92 | self.scheduler = CosineLR(self.optimizer, max_lr=self.params.lr, epochs=int(num_epochs)) 93 | elif self.params.scheduler == 'cosinew': 94 | self.scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optimizer, max_lr=self.params.lr, pct_start=0.025, 95 | total_steps=int(num_epochs)) 96 | else: 97 | self.scheduler = None 98 | 99 | 100 | def train(self, dataloader, epoch=0, adversarial=False, verbose=False): 101 | """ 102 | Run one epoch of training. 103 | """ 104 | metrics = pd.DataFrame() 105 | self.model.train() 106 | 107 | for data in tqdm(dataloader, desc='Epoch {}: '.format(epoch), disable=not verbose): 108 | x, y = data 109 | x, y = x.to(device), y.to(device) 110 | 111 | if adversarial: 112 | if self.params.beta is not None and self.params.mart: 113 | pass 114 | # loss, batch_metrics = self.mart_loss(x, y, beta=self.params.beta) 115 | elif self.params.beta is not None: 116 | pass 117 | # loss, batch_metrics = self.trades_loss(x, y, beta=self.params.beta) 118 | else: 119 | loss, batch_metrics = self.adversarial_loss(x, y) 120 | else: 121 | loss, batch_metrics = self.standard_loss(x, y) 122 | 123 | loss.backward() 124 | if self.params.clip_grad: 125 | nn.utils.clip_grad_norm_(self.model.parameters(), self.params.clip_grad) 126 | self.optimizer.step() 127 | if self.params.scheduler in ['cyclic']: 128 | self.scheduler.step() 129 | 130 | metrics = metrics.append(pd.DataFrame(batch_metrics, index=[0]), ignore_index=True) 131 | 132 | if self.params.scheduler in ['step', 'converge', 'cosine', 'cosinew']: 133 | self.scheduler.step() 134 | return dict(metrics.mean()) 135 | 136 | 137 | def standard_loss(self, x, y): 138 | """ 139 | Standard training. 140 | """ 141 | self.optimizer.zero_grad() 142 | out = self.model(x) 143 | loss = self.criterion(out, y) 144 | 145 | preds = out.detach() 146 | batch_metrics = {'loss': loss.item(), 'clean_acc': accuracy(y, preds)} 147 | return loss, batch_metrics 148 | 149 | 150 | # def adversarial_loss(self, x, y): 151 | # """ 152 | # Adversarial training (Madry et al, 2017). 
153 | # """ 154 | # with ctx_noparamgrad_and_eval(self.model): 155 | # x_adv, _ = self.attack.perturb(x, y) 156 | 157 | # self.optimizer.zero_grad() 158 | # if self.params.keep_clean: 159 | # x_adv = torch.cat((x, x_adv), dim=0) 160 | # y_adv = torch.cat((y, y), dim=0) 161 | # else: 162 | # y_adv = y 163 | # out = self.model(x_adv) 164 | # loss = self.criterion(out, y_adv) 165 | 166 | # preds = out.detach() 167 | # batch_metrics = {'loss': loss.item()} 168 | # if self.params.keep_clean: 169 | # preds_clean, preds_adv = preds[:len(x)], preds[len(x):] 170 | # batch_metrics.update({'clean_acc': accuracy(y, preds_clean), 'adversarial_acc': accuracy(y, preds_adv)}) 171 | # else: 172 | # batch_metrics.update({'adversarial_acc': accuracy(y, preds)}) 173 | # return loss, batch_metrics 174 | 175 | 176 | # def trades_loss(self, x, y, beta): 177 | # """ 178 | # TRADES training. 179 | # """ 180 | # loss, batch_metrics = trades_loss(self.model, x, y, self.optimizer, step_size=self.params.attack_step, 181 | # epsilon=self.params.attack_eps, perturb_steps=self.params.attack_iter, 182 | # beta=beta, attack=self.params.attack) 183 | # return loss, batch_metrics 184 | 185 | 186 | # def mart_loss(self, x, y, beta): 187 | # """ 188 | # MART training. 189 | # """ 190 | # loss, batch_metrics = mart_loss(self.model, x, y, self.optimizer, step_size=self.params.attack_step, 191 | # epsilon=self.params.attack_eps, perturb_steps=self.params.attack_iter, 192 | # beta=beta, attack=self.params.attack) 193 | # return loss, batch_metrics 194 | 195 | 196 | def eval(self, dataloader, adversarial=False): 197 | """ 198 | Evaluate performance of the model. 199 | """ 200 | acc = 0.0 201 | loss = 0.0 202 | self.model.eval() 203 | 204 | with torch.no_grad(): 205 | for x, y in dataloader: 206 | x, y = x.to(device), y.to(device) 207 | if adversarial: 208 | # with ctx_noparamgrad_and_eval(self.model): 209 | # x_adv, _ = self.eval_attack.perturb(x, y) 210 | # out = self.model(x_adv) 211 | pass 212 | else: 213 | out = self.model(x) 214 | 215 | loss += self.criterion(out, y).item() 216 | acc += accuracy(y, out) 217 | 218 | loss /= len(dataloader) 219 | acc /= len(dataloader) 220 | return loss, acc 221 | 222 | 223 | def save_model(self, path): 224 | """ 225 | Save model weights. 226 | """ 227 | torch.save({'model_state_dict': self.model.state_dict()}, path) 228 | 229 | 230 | def load_model(self, path, load_opt=True): 231 | """ 232 | Load model weights. 233 | """ 234 | checkpoint = torch.load(path) 235 | if 'model_state_dict' not in checkpoint: 236 | raise RuntimeError('Model weights not found at {}.'.format(path)) 237 | self.model.load_state_dict(checkpoint['model_state_dict']) --------------------------------------------------------------------------------