├── README.md ├── best_model ├── cifar100_wrn40_2.pth └── cifar10_wrn40_2.pth ├── datafree ├── __init__.py ├── criterions.py ├── datasets │ ├── __init__.py │ ├── nyu.py │ ├── tiny_imagenet.py │ └── utils.py ├── evaluators.py ├── hooks.py ├── metrics │ ├── __init__.py │ ├── accuracy.py │ ├── confusion_matrix.py │ ├── running_average.py │ └── stream_metrics.py ├── models │ ├── __init__.py │ ├── classifiers │ │ ├── __init__.py │ │ ├── lenet.py │ │ ├── mobilenetv2.py │ │ ├── resnet.py │ │ ├── resnet_in.py │ │ ├── resnet_tiny.py │ │ ├── shufflenetv2.py │ │ ├── vgg.py │ │ └── wresnet.py │ ├── deeplab │ │ ├── __init__.py │ │ ├── _deeplab.py │ │ ├── backbone │ │ │ ├── __init__.py │ │ │ ├── mobilenetv2.py │ │ │ └── resnet.py │ │ ├── modeling.py │ │ └── utils.py │ ├── generator.py │ └── stylegan_generator.py ├── rep_transfer.py ├── synthesis │ ├── __init__.py │ ├── base.py │ ├── contrastive.py │ └── triplet.py └── utils │ ├── __init__.py │ ├── _utils.py │ ├── fmix.py │ ├── inception.py │ ├── logger.py │ ├── pair.py │ ├── sync_transforms │ ├── __init__.py │ ├── functional.py │ └── transforms.py │ └── vis.py ├── losses.py ├── main.py ├── misc └── framework.png ├── registry.py └── train_scratch.py /README.md: -------------------------------------------------------------------------------- 1 | ## News 2 | * `2024/12/20` We release the code for the *data-free knowledge distillation* tasks. 3 | 4 | # RGAL 5 | 6 | This is a PyTorch implementation of the following paper: 7 | 8 | **Relation-Guided Adversarial Learning for Data-Free Knowledge Transfer**, IJCV 2024. 9 | 10 | Yingping Liang and Ying Fu 11 | 12 | [Paper](https://link.springer.com/article/10.1007/s11263-024-02303-4) 13 | 14 | 15 | 16 | **Abstract**: *Data-free knowledge distillation transfers knowledge by recovering training data from a pre-trained model. Despite the recent success of seeking global data diversity, the diversity within each class and the similarity among different classes are largely overlooked, resulting in data homogeneity and limited performance. In this paper, we introduce a novel Relation-Guided Adversarial Learning method with triplet losses, which solves the homogeneity problem from two aspects. To be specific, our method aims to promote both intra-class diversity and inter-class confusion of the generated samples. To this end, we design two phases, an image synthesis phase and a student training phase. In the image synthesis phase, we construct an optimization process to push away samples with the same labels and pull close samples with different labels, leading to intra-class diversity and inter-class confusion, respectively. Then, in the student training phase, we perform an opposite optimization, which adversarially attempts to reduce the distance of samples of the same classes and enlarge the distance of samples of different classes. To mitigate the conflict of seeking high global diversity and keeping inter-class confusing, we propose a focal weighted sampling strategy by selecting the negative in the triplets unevenly within a finite range of distance. RGAL shows significant improvement over previous state-of-the-art methods in accuracy and data efficiency. Besides, RGAL can be inserted into state-of-the-art methods on various data-free knowledge transfer applications. 
Experiments on various benchmarks demonstrate the effectiveness and generalizability of our proposed method on various tasks, especially data-free knowledge distillation, data-free quantization, and non-exemplar incremental learning.* 17 | 18 | 19 | 20 | 21 | https://github.com/user-attachments/assets/eb78306f-1fbe-465a-9996-7315716f0b55 22 | 23 | 24 | 25 | 26 | 27 | ## Installation 28 | 29 | ``` 30 | conda create -n rgal python=3.9 31 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113 32 | pip install scipy tqdm pillow kornia 33 | ``` 34 | 35 | ## Run 36 | 37 | The datasets (CIFAR-10/-100) will be downloaded automatically at runtime. 38 | 39 | We provide an example training script: 40 | ``` 41 | python main.py \ 42 | --epochs 200 \ 43 | --dataset cifar10 \ 44 | --batch_size 128 \ 45 | --synthesis_batch_size 256 \ 46 | --teacher wrn40_2 \ 47 | --student wrn16_1 \ 48 | --lr 0.1 \ 49 | --kd_steps 400 \ 50 | --ep_steps 400 \ 51 | --g_steps 400 \ 52 | --lr_g 1e-3 \ 53 | --adv 1.0 \ 54 | --bn 1.0 \ 55 | --oh 1.0 \ 56 | --act 0.001 \ 57 | --gpu 0 \ 58 | --seed 0 \ 59 | --T 20 \ 60 | --save_dir run/scratch1 \ 61 | --log_tag scratch1 \ 62 | --cd_loss 0.1 \ 63 | --gram_loss 0 \ 64 | --teacher_weights best_model/cifar10_wrn40_2.pth \ 65 | --custom_steps 1.0 \ 66 | --print_freq 50 \ 67 | --triplet_target student \ 68 | --pair_sample \ 69 | --striplet_feature global \ 70 | --start_layer 2 \ 71 | --triplet 0.1 \ 72 | --striplet 0.1 \ 73 | --balanced_sampling \ 74 | --balance 0.1 75 | ``` 76 | 77 | where "--triplet" and "--striplet" indicate the weights of our proposed triplet losses in the data-generation stage and the distillation stage, respectively (an illustrative sketch of such a triplet objective is given at the end of this README). 78 | 79 | To run our method with different teacher and student models, modify "--teacher" and "--student". 80 | 81 | "--balanced_sampling" enables the paired sampling strategy described in our paper. 82 | 83 | Pretrained example checkpoints are available at [best_model](https://github.com/Sharpiless/RGAL/tree/main/best_model). 84 | 85 | ![image](https://github.com/user-attachments/assets/3c8b7698-7f11-430c-ac6d-d7d0b4a22a7f) 86 | 87 | 88 | ## Visualization 89 | 90 | Please refer to [ZSKT](https://github.com/polo5/ZeroShotKnowledgeTransfer). 91 | 92 | ## License and Citation 93 | This repository can only be used for personal/research/non-commercial purposes. 94 | Please cite the following paper if this model helps your research: 95 | 96 | ``` 97 | @article{liang2024relation, 98 | title={Relation-Guided Adversarial Learning for Data-Free Knowledge Transfer}, 99 | author={Liang, Yingping and Fu, Ying}, 100 | journal={International Journal of Computer Vision}, 101 | pages={1--18}, 102 | year={2024}, 103 | publisher={Springer} 104 | } 105 | ``` 106 | 107 | ## Acknowledgments 108 | * The code for inference and training is heavily borrowed from [CMI](https://github.com/zju-vipa/CMI); we thank the authors for their great effort. 
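For orientation only, here is a minimal sketch of a triplet-style relation objective of the kind described in the abstract and weighted by `--triplet` / `--striplet`: it pushes same-class samples apart and pulls different-class samples together during image synthesis, and does the opposite during student training. All names here are hypothetical; the repository's actual implementation presumably lives in `datafree/synthesis/triplet.py`, `datafree/synthesis/contrastive.py`, and `losses.py`, and additionally uses the focal weighted negative sampling that this sketch omits.

```
import torch
import torch.nn.functional as F

def relation_triplet_loss(features, labels, margin=1.0, synthesis_phase=True):
    """Hypothetical, simplified relation loss for illustration only.

    synthesis_phase=True : push same-class samples apart and pull
                           different-class samples together (intra-class
                           diversity + inter-class confusion).
    synthesis_phase=False: the opposite, adversarial objective used when
                           training the student.
    Assumes every class in the batch appears at least twice.
    """
    features = F.normalize(features, dim=1)             # (B, D) unit features
    dist = torch.cdist(features, features, p=2)         # (B, B) pairwise distances
    same = labels.unsqueeze(0) == labels.unsqueeze(1)    # (B, B) same-label mask
    eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)

    pos = dist[same & ~eye]    # same-class pairs (excluding self-pairs)
    neg = dist[~same]          # different-class pairs

    if synthesis_phase:
        return F.relu(margin - pos).mean() + neg.mean()
    return pos.mean() + F.relu(margin - neg).mean()

# toy usage with random features and labels
feats = torch.randn(8, 64, requires_grad=True)
labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])
relation_triplet_loss(feats, labels, synthesis_phase=True).backward()
```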
109 | -------------------------------------------------------------------------------- /best_model/cifar100_wrn40_2.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/RGAL/e94cf7b19bff1c1a517592a9d9bcaf521c768e43/best_model/cifar100_wrn40_2.pth -------------------------------------------------------------------------------- /best_model/cifar10_wrn40_2.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/RGAL/e94cf7b19bff1c1a517592a9d9bcaf521c768e43/best_model/cifar10_wrn40_2.pth -------------------------------------------------------------------------------- /datafree/__init__.py: -------------------------------------------------------------------------------- 1 | from . import criterions, utils, metrics, hooks, rep_transfer, evaluators, synthesis, datasets 2 | -------------------------------------------------------------------------------- /datafree/criterions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | def kldiv(logits, targets, T=1.0, reduction='batchmean'): 6 | q = F.log_softmax(logits/T, dim=1) 7 | p = F.softmax(targets/T, dim=1) 8 | return F.kl_div(q, p, reduction=reduction) * (T*T) 9 | 10 | 11 | class KLDiv(nn.Module): 12 | def __init__(self, T=1.0, reduction='batchmean'): 13 | super().__init__() 14 | self.T = T 15 | self.reduction = reduction 16 | 17 | def forward(self, logits, targets): 18 | return kldiv(logits, targets, T=self.T, reduction=self.reduction) 19 | 20 | def jsdiv( logits, targets, T=1.0, reduction='batchmean' ): 21 | P = F.softmax(logits / T, dim=1) 22 | Q = F.softmax(targets / T, dim=1) 23 | M = 0.5 * (P + Q) 24 | P = torch.clamp(P, 0.01, 0.99) 25 | Q = torch.clamp(Q, 0.01, 0.99) 26 | M = torch.clamp(M, 0.01, 0.99) 27 | return 0.5 * F.kl_div(torch.log(P), M, reduction=reduction) + 0.5 * F.kl_div(torch.log(Q), M, reduction=reduction) 28 | 29 | def cross_entropy(logits, targets, reduction='mean'): 30 | return F.cross_entropy(logits, targets, reduction=reduction) 31 | 32 | def class_balance_loss(logits): 33 | prob = torch.softmax(logits, dim=1) 34 | avg_prob = prob.mean(dim=0) 35 | return (avg_prob * torch.log(avg_prob)).sum() 36 | 37 | def onehot_loss(logits, targets=None): 38 | if targets is None: 39 | targets = logits.max(1)[1] 40 | return cross_entropy(logits, targets) 41 | 42 | def get_image_prior_losses(inputs_jit): 43 | # COMPUTE total variation regularization loss 44 | diff1 = inputs_jit[:, :, :, :-1] - inputs_jit[:, :, :, 1:] 45 | diff2 = inputs_jit[:, :, :-1, :] - inputs_jit[:, :, 1:, :] 46 | diff3 = inputs_jit[:, :, 1:, :-1] - inputs_jit[:, :, :-1, 1:] 47 | diff4 = inputs_jit[:, :, :-1, :-1] - inputs_jit[:, :, 1:, 1:] 48 | #loss_var_l2 = torch.norm(diff1) + torch.norm(diff2) + torch.norm(diff3) + torch.norm(diff4) 49 | loss_var_l1 = (diff1.abs() / 255.0).mean() + (diff2.abs() / 255.0).mean() + ( 50 | diff3.abs() / 255.0).mean() + (diff4.abs() / 255.0).mean() 51 | loss_var_l1 = loss_var_l1 * 255.0 52 | return loss_var_l1 -------------------------------------------------------------------------------- /datafree/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .nyu import NYUv2 2 | from .tiny_imagenet import TinyImageNet -------------------------------------------------------------------------------- 
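As a quick orientation before the dataset files, here is a minimal usage sketch of the distillation criterions defined in `datafree/criterions.py` above. The logits tensors are placeholders standing in for a student/teacher forward pass; only the imports and call signatures come from the code shown.

```
import torch
from datafree.criterions import KLDiv, kldiv

# placeholder logits standing in for a student/teacher forward pass
student_logits = torch.randn(128, 10, requires_grad=True)
teacher_logits = torch.randn(128, 10)

# functional form: temperature-softened KL divergence; the T*T factor is the
# usual Hinton-style rescaling that keeps gradient magnitudes comparable
loss = kldiv(student_logits, teacher_logits, T=20.0)

# module form, matching the --T 20 setting used in the README script
criterion = KLDiv(T=20.0)
loss = criterion(student_logits, teacher_logits.detach())
loss.backward()
```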
/datafree/datasets/nyu.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/VainF/nyuv2-python-toolkit 2 | import os 3 | import torch 4 | import torch.utils.data as data 5 | from PIL import Image 6 | from scipy.io import loadmat 7 | import numpy as np 8 | import glob 9 | from torchvision import transforms 10 | from torchvision.datasets import VisionDataset 11 | import random 12 | 13 | from .utils import colormap 14 | 15 | class NYUv2(VisionDataset): 16 | """NYUv2 dataset 17 | See https://github.com/VainF/nyuv2-python-toolkit for more details. 18 | 19 | Args: 20 | root (string): Root directory path. 21 | split (string, optional): 'train' for training set, and 'test' for test set. Default: 'train'. 22 | target_type (string, optional): Type of target to use, ``semantic``, ``depth`` or ``normal``. 23 | num_classes (int, optional): The number of classes, must be 40 or 13. Default:13. 24 | transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. 25 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 26 | transforms (callable, optional): A function/transform that takes input sample and its target as entry and returns a transformed version. 27 | """ 28 | cmap = colormap() 29 | def __init__(self, 30 | root, 31 | split='train', 32 | target_type='semantic', 33 | num_classes=13, 34 | transforms=None, 35 | transform=None, 36 | target_transform=None): 37 | super( NYUv2, self ).__init__(root, transforms=transforms, transform=transform, target_transform=target_transform) 38 | assert(split in ('train', 'test')) 39 | 40 | self.root = root 41 | self.split = split 42 | self.target_type = target_type 43 | self.num_classes = num_classes 44 | 45 | split_mat = loadmat(os.path.join(self.root, 'splits.mat')) 46 | idxs = split_mat[self.split+'Ndxs'].reshape(-1) - 1 47 | 48 | img_names = os.listdir( os.path.join(self.root, 'image', self.split) ) 49 | img_names.sort() 50 | images_dir = os.path.join(self.root, 'image', self.split) 51 | self.images = [os.path.join(images_dir, name) for name in img_names] 52 | 53 | self._is_depth = False 54 | if self.target_type=='semantic': 55 | semantic_dir = os.path.join(self.root, 'seg%d'%self.num_classes, self.split) 56 | self.labels = [os.path.join(semantic_dir, name) for name in img_names] 57 | self.targets = self.labels 58 | 59 | if self.target_type=='depth': 60 | depth_dir = os.path.join(self.root, 'depth', self.split) 61 | self.depths = [os.path.join(depth_dir, name) for name in img_names] 62 | self.targets = self.depths 63 | self._is_depth = True 64 | 65 | if self.target_type=='normal': 66 | normal_dir = os.path.join(self.root, 'normal', self.split) 67 | self.normals = [os.path.join(normal_dir, name) for name in img_names] 68 | self.targets = self.normals 69 | 70 | def __getitem__(self, idx): 71 | image = Image.open(self.images[idx]) 72 | target = Image.open(self.targets[idx]) 73 | if self.transforms is not None: 74 | image, target = self.transforms( image, target ) 75 | return image, target 76 | 77 | def __len__(self): 78 | return len(self.images) 79 | 80 | @classmethod 81 | def decode_fn(cls, mask: np.ndarray): 82 | """decode semantic mask to RGB image""" 83 | mask = mask.astype('uint8') + 1 # 255 => 0 84 | return cls.cmap[mask] 85 | -------------------------------------------------------------------------------- /datafree/datasets/tiny_imagenet.py: 
-------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import glob 3 | import numpy as np 4 | import os 5 | from torchvision.datasets.folder import pil_loader 6 | from torchvision.datasets.utils import download_and_extract_archive 7 | 8 | class TinyImageNet(Dataset): 9 | def __init__(self, root, split, transform, download=True): 10 | 11 | self.url = "http://cs231n.stanford.edu/tiny-imagenet-200" 12 | self.root = root 13 | if download: 14 | if os.path.exists(f'{self.root}/tiny-imagenet-200/'): 15 | print(f'{self.root}/tiny-imagenet-200/, File already downloaded') 16 | else: 17 | print(f'{self.root}/tiny-imagenet-200/, File isn\'t downloaded') 18 | download_and_extract_archive(self.url, root, filename="tiny-imagenet-200.zip") 19 | 20 | self.root = os.path.join(self.root, "tiny-imagenet-200") 21 | self.train = split == "train" 22 | self.transform = transform 23 | self.ids_string = np.sort(np.loadtxt(f"{self.root}/wnids.txt", "str")) 24 | self.ids = {class_string: i for i, class_string in enumerate(self.ids_string)} 25 | if self.train: 26 | self.paths = glob.glob(f"{self.root}/train/*/images/*") 27 | self.targets = [self.ids[path.split("/")[-3]] for path in self.paths] 28 | else: 29 | self.paths = glob.glob(f"{self.root}/val/*/images/*") 30 | self.targets = [self.ids[path.split("/")[-3]] for path in self.paths] 31 | 32 | def __len__(self): 33 | return len(self.paths) 34 | 35 | def __getitem__(self, idx): 36 | image = pil_loader(self.paths[idx]) 37 | 38 | if self.transform is not None: 39 | image = self.transform(image) 40 | 41 | return image, self.targets[idx] -------------------------------------------------------------------------------- /datafree/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | def colormap(N=256, normalized=False): 5 | def bitget(byteval, idx): 6 | return ((byteval & (1 << idx)) != 0) 7 | 8 | dtype = 'float32' if normalized else 'uint8' 9 | cmap = np.zeros((N, 3), dtype=dtype) 10 | for i in range(N): 11 | r = g = b = 0 12 | c = i 13 | for j in range(8): 14 | r = r | (bitget(c, 0) << 7-j) 15 | g = g | (bitget(c, 1) << 7-j) 16 | b = b | (bitget(c, 2) << 7-j) 17 | c = c >> 3 18 | 19 | cmap[i] = np.array([r, g, b]) 20 | 21 | cmap = cmap/255 if normalized else cmap 22 | return cmap -------------------------------------------------------------------------------- /datafree/evaluators.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch.nn.functional as F 3 | import torch 4 | from . 
import metrics 5 | 6 | class Evaluator(object): 7 | def __init__(self, metric, dataloader): 8 | self.dataloader = dataloader 9 | self.metric = metric 10 | 11 | def eval(self, model, device=None, progress=False): 12 | self.metric.reset() 13 | with torch.no_grad(): 14 | for i, (inputs, targets) in enumerate( tqdm(self.dataloader, disable=not progress) ): 15 | inputs, targets = inputs.to(device), targets.to(device) 16 | outputs = model( inputs ) 17 | self.metric.update(outputs, targets) 18 | return self.metric.get_results() 19 | 20 | def __call__(self, *args, **kwargs): 21 | return self.eval(*args, **kwargs) 22 | 23 | class AdvEvaluator(object): 24 | def __init__(self, metric, dataloader, adversary): 25 | self.dataloader = dataloader 26 | self.metric = metric 27 | self.adversary = adversary 28 | 29 | def eval(self, model, device=None, progress=False): 30 | self.metric.reset() 31 | for i, (inputs, targets) in enumerate( tqdm(self.dataloader, disable=not progress) ): 32 | inputs, targets = inputs.to(device), targets.to(device) 33 | inputs = self.adversary.perturb(inputs, targets) 34 | with torch.no_grad(): 35 | outputs = model( inputs ) 36 | self.metric.update(outputs, targets) 37 | return self.metric.get_results() 38 | 39 | def __call__(self, *args, **kwargs): 40 | return self.eval(*args, **kwargs) 41 | 42 | def classification_evaluator(dataloader): 43 | metric = metrics.MetricCompose({ 44 | 'Acc': metrics.TopkAccuracy(), 45 | 'Loss': metrics.RunningLoss(torch.nn.CrossEntropyLoss(reduction='sum')) 46 | }) 47 | return Evaluator( metric, dataloader=dataloader) 48 | 49 | def advarsarial_classification_evaluator(dataloader, adversary): 50 | metric = metrics.MetricCompose({ 51 | 'Acc': metrics.TopkAccuracy(), 52 | 'Loss': metrics.RunningLoss(torch.nn.CrossEntropyLoss(reduction='sum')) 53 | }) 54 | return AdvEvaluator( metric, dataloader=dataloader, adversary=adversary) 55 | 56 | 57 | def segmentation_evaluator(dataloader, num_classes, ignore_idx=255): 58 | cm = metrics.ConfusionMatrix(num_classes, ignore_idx=ignore_idx) 59 | metric = metrics.MetricCompose({ 60 | 'mIoU': metrics.mIoU(cm), 61 | 'Acc': metrics.Accuracy(), 62 | 'Loss': metrics.RunningLoss(torch.nn.CrossEntropyLoss(reduction='sum')) 63 | }) 64 | return Evaluator( metric, dataloader=dataloader) -------------------------------------------------------------------------------- /datafree/hooks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def register_hooks(modules): 6 | hooks = [] 7 | for m in modules: 8 | hooks.append( FeatureHook(m) ) 9 | return hooks 10 | 11 | class InstanceMeanHook(object): 12 | def __init__(self, module): 13 | self.hook = module.register_forward_hook(self.hook_fn) 14 | self.module = module 15 | 16 | def hook_fn(self, module, input, output): 17 | self.instance_mean = torch.mean(input[0], dim=[2, 3]) 18 | 19 | def remove(self): 20 | self.hook.remove() 21 | 22 | def __repr__(self): 23 | return ": %s"%(self.module) 24 | 25 | class FeatureHook(object): 26 | def __init__(self, module): 27 | self.hook = module.register_forward_hook(self.hook_fn) 28 | self.module = module 29 | 30 | def hook_fn(self, module, input, output): 31 | self.output = output 32 | self.input = input[0] 33 | 34 | def remove(self): 35 | self.hook.remove() 36 | 37 | def __repr__(self): 38 | return ": %s"%(self.module) 39 | 40 | 41 | class FeatureMeanHook(object): 42 | def __init__(self, module): 43 | self.hook = 
module.register_forward_hook(self.hook_fn) 44 | self.module = module 45 | 46 | def hook_fn(self, module, input, output): 47 | self.instance_mean = torch.mean(input[0], dim=[2, 3]) 48 | 49 | def remove(self): 50 | self.hook.remove() 51 | 52 | def __repr__(self): 53 | return ": %s"%(self.module) 54 | 55 | 56 | class FeatureMeanVarHook(): 57 | def __init__(self, module, on_input=True, dim=[0,2,3]): 58 | self.hook = module.register_forward_hook(self.hook_fn) 59 | self.on_input = on_input 60 | self.module = module 61 | self.dim = dim 62 | 63 | def hook_fn(self, module, input, output): 64 | # To avoid inplace modification 65 | if self.on_input: 66 | feature = input[0].clone() 67 | else: 68 | feature = output.clone() 69 | self.var, self.mean = torch.var_mean( feature, dim=self.dim, unbiased=True ) 70 | 71 | def remove(self): 72 | self.hook.remove() 73 | self.output=None 74 | 75 | 76 | class DeepInversionHook(): 77 | ''' 78 | Implementation of the forward hook to track feature statistics and compute a loss on them. 79 | Will compute mean and variance, and will use l2 as a loss 80 | ''' 81 | def __init__(self, module): 82 | self.hook = module.register_forward_hook(self.hook_fn) 83 | self.module = module 84 | 85 | def hook_fn(self, module, input, output): 86 | # hook co compute deepinversion's feature distribution regularization 87 | nch = input[0].shape[1] 88 | mean = input[0].mean([0, 2, 3]) 89 | var = input[0].permute(1, 0, 2, 3).contiguous().view([nch, -1]).var(1, unbiased=False) 90 | #forcing mean and variance to match between two distributions 91 | #other ways might work better, i.g. KL divergence 92 | r_feature = torch.norm(module.running_var.data - var, 2) + torch.norm( 93 | module.running_mean.data - mean, 2) 94 | self.r_feature = r_feature 95 | 96 | def remove(self): 97 | self.hook.remove() -------------------------------------------------------------------------------- /datafree/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .stream_metrics import Metric, MetricCompose 2 | from .accuracy import Accuracy, TopkAccuracy 3 | from .confusion_matrix import ConfusionMatrix, IoU, mIoU 4 | from .running_average import RunningLoss 5 | 6 | -------------------------------------------------------------------------------- /datafree/metrics/accuracy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .stream_metrics import Metric 4 | from typing import Callable 5 | 6 | __all__=['Accuracy', 'TopkAccuracy'] 7 | 8 | class Accuracy(Metric): 9 | def __init__(self): 10 | self.reset() 11 | 12 | @torch.no_grad() 13 | def update(self, outputs, targets): 14 | outputs = outputs.max(1)[1] 15 | self._correct += ( outputs.view(-1)==targets.view(-1) ).sum() 16 | self._cnt += torch.numel( targets ) 17 | 18 | def get_results(self): 19 | return (self._correct / self._cnt * 100.).detach().cpu() 20 | 21 | def reset(self): 22 | self._correct = self._cnt = 0.0 23 | 24 | 25 | class TopkAccuracy(Metric): 26 | def __init__(self, topk=(1, 5)): 27 | self._topk = topk 28 | self.reset() 29 | 30 | @torch.no_grad() 31 | def update(self, outputs, targets): 32 | for k in self._topk: 33 | _, topk_outputs = outputs.topk(k, dim=1, largest=True, sorted=True) 34 | correct = topk_outputs.eq( targets.view(-1, 1).expand_as(topk_outputs) ) 35 | self._correct[k] += correct[:, :k].view(-1).float().sum(0).item() 36 | self._cnt += len(targets) 37 | 38 | def get_results(self): 39 | return tuple( 
self._correct[k] / self._cnt * 100. for k in self._topk ) 40 | 41 | def reset(self): 42 | self._correct = {k: 0 for k in self._topk} 43 | self._cnt = 0.0 -------------------------------------------------------------------------------- /datafree/metrics/confusion_matrix.py: -------------------------------------------------------------------------------- 1 | from .stream_metrics import Metric 2 | import torch 3 | from typing import Callable 4 | 5 | class ConfusionMatrix(Metric): 6 | def __init__(self, num_classes, ignore_idx=None): 7 | super(ConfusionMatrix, self).__init__() 8 | self._num_classes = num_classes 9 | self._ignore_idx = ignore_idx 10 | self.reset() 11 | 12 | @torch.no_grad() 13 | def update(self, outputs, targets): 14 | if self.confusion_matrix.device != outputs.device: 15 | self.confusion_matrix = self.confusion_matrix.to(device=outputs.device) 16 | preds = outputs.max(1)[1].flatten() 17 | targets = targets.flatten() 18 | mask = (preds>=0) 19 | if self._ignore_idx: 20 | mask = mask & (targets!=self._ignore_idx) 21 | preds, targets = preds[mask], targets[mask] 22 | hist = torch.bincount( self._num_classes * targets + preds, 23 | minlength=self._num_classes ** 2 ).view(self._num_classes, self._num_classes) 24 | self.confusion_matrix += hist 25 | 26 | def get_results(self): 27 | return self.confusion_matrix.detach().cpu() 28 | 29 | def reset(self): 30 | self._cnt = 0 31 | self.confusion_matrix = torch.zeros(self._num_classes, self._num_classes, dtype=torch.int64, requires_grad=False) 32 | 33 | class IoU(Metric): 34 | def __init__(self, confusion_matrix: ConfusionMatrix): 35 | self._confusion_matrix = confusion_matrix 36 | 37 | def update(self, outputs, targets): 38 | self._confusion_matrix.update(outputs, targets) 39 | 40 | def reset(self): 41 | self._confusion_matrix.reset() 42 | 43 | def get_results(self): 44 | cm = self._confusion_matrix.get_results() 45 | iou = cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) - cm.diag() + 1e-9) 46 | return iou 47 | 48 | class mIoU(IoU): 49 | def get_results(self): 50 | return super(mIoU, self).get_results().mean() 51 | -------------------------------------------------------------------------------- /datafree/metrics/running_average.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .stream_metrics import Metric 4 | 5 | __all__=['Accuracy', 'TopkAccuracy'] 6 | 7 | class RunningLoss(Metric): 8 | def __init__(self, loss_fn, is_batch_average=False): 9 | self.reset() 10 | self.loss_fn = loss_fn 11 | self.is_batch_average = is_batch_average 12 | 13 | @torch.no_grad() 14 | def update(self, outputs, targets): 15 | self._accum_loss += self.loss_fn(outputs, targets) 16 | if self.is_batch_average: 17 | self._cnt += 1 18 | else: 19 | self._cnt += len(outputs) 20 | 21 | def get_results(self): 22 | return (self._accum_loss / self._cnt).detach().cpu() 23 | 24 | def reset(self): 25 | self._accum_loss = self._cnt = 0.0 26 | -------------------------------------------------------------------------------- /datafree/metrics/stream_metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | from abc import ABC, abstractmethod 4 | from typing import Callable, Union, Any, Mapping, Sequence 5 | import numbers 6 | import numpy as np 7 | 8 | class Metric(ABC): 9 | @abstractmethod 10 | def update(self, pred, target): 11 | """ Overridden by subclasses """ 12 | raise NotImplementedError() 13 | 14 | 
@abstractmethod 15 | def get_results(self): 16 | """ Overridden by subclasses """ 17 | raise NotImplementedError() 18 | 19 | @abstractmethod 20 | def reset(self): 21 | """ Overridden by subclasses """ 22 | raise NotImplementedError() 23 | 24 | 25 | class MetricCompose(dict): 26 | def __init__(self, metric_dict: Mapping): 27 | self._metric_dict = metric_dict 28 | 29 | @property 30 | def metrics(self): 31 | return self._metric_dict 32 | 33 | @torch.no_grad() 34 | def update(self, outputs, targets): 35 | for key, metric in self._metric_dict.items(): 36 | if isinstance(metric, Metric): 37 | metric.update(outputs, targets) 38 | 39 | def get_results(self): 40 | results = {} 41 | for key, metric in self._metric_dict.items(): 42 | if isinstance(metric, Metric): 43 | results[key] = metric.get_results() 44 | return results 45 | 46 | def reset(self): 47 | for key, metric in self._metric_dict.items(): 48 | if isinstance(metric, Metric): 49 | metric.reset() 50 | 51 | def __getitem__(self, name): 52 | return self._metric_dict[name] 53 | 54 | 55 | -------------------------------------------------------------------------------- /datafree/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import classifiers 2 | from . import generator 3 | from . import deeplab -------------------------------------------------------------------------------- /datafree/models/classifiers/__init__.py: -------------------------------------------------------------------------------- 1 | from . import lenet, wresnet, vgg, resnet, mobilenetv2, shufflenetv2, resnet_tiny, resnet_in -------------------------------------------------------------------------------- /datafree/models/classifiers/lenet.py: -------------------------------------------------------------------------------- 1 | # https://github.com/huawei-noah/Data-Efficient-Model-Compression 2 | import torch.nn as nn 3 | 4 | class LeNet5(nn.Module): 5 | 6 | def __init__(self, nc=1, num_classes=10): 7 | super(LeNet5, self).__init__() 8 | self.features = nn.Sequential( 9 | nn.Conv2d(1, 6, kernel_size=(5, 5)), 10 | nn.ReLU(inplace=True), 11 | nn.MaxPool2d(kernel_size=(2, 2), stride=2), 12 | nn.Conv2d(6, 16, kernel_size=(5, 5)), 13 | nn.ReLU(inplace=True), 14 | nn.MaxPool2d(kernel_size=(2, 2), stride=2), 15 | nn.Conv2d(16, 120, kernel_size=(5, 5)), 16 | nn.ReLU(inplace=True), 17 | ) 18 | self.fc = nn.Sequential( 19 | nn.Linear(120, 84), 20 | nn.ReLU(inplace=True), 21 | nn.Linear(84, num_classes) 22 | ) 23 | 24 | def forward(self, img, return_features=False): 25 | features = self.features( img ).view(-1, 120) 26 | output = self.fc( features ) 27 | if return_features: 28 | return output, features 29 | return output 30 | 31 | 32 | class LeNet5Half(nn.Module): 33 | 34 | def __init__(self, nc=1, num_classes=10): 35 | super(LeNet5Half, self).__init__() 36 | self.features = nn.Sequential( 37 | nn.Conv2d(1, 3, kernel_size=(5, 5)), 38 | nn.ReLU(inplace=True), 39 | nn.MaxPool2d(kernel_size=(2, 2), stride=2), 40 | nn.Conv2d(3, 8, kernel_size=(5, 5)), 41 | nn.ReLU(inplace=True), 42 | nn.MaxPool2d(kernel_size=(2, 2), stride=2), 43 | nn.Conv2d(8, 60, kernel_size=(5, 5)), 44 | nn.ReLU(inplace=True), 45 | ) 46 | self.fc = nn.Sequential( 47 | nn.Linear(60, 42), 48 | nn.ReLU(inplace=True), 49 | nn.Linear(42, num_classes) 50 | ) 51 | 52 | def forward(self, img, return_features=False): 53 | features = self.features( img ).view(-1, 60) 54 | output = self.fc( features ) 55 | if return_features: 56 | return output, features 57 | return output 
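A short usage sketch of the LeNet classifiers above; the batch size is arbitrary, and the input resolution follows from the layer sizes (the final 5x5 convolution only collapses the feature map to 1x1 for single-channel 32x32 inputs, e.g. MNIST padded from 28x28).

```
import torch
from datafree.models.classifiers.lenet import LeNet5, LeNet5Half

teacher = LeNet5(num_classes=10)
student = LeNet5Half(num_classes=10)

# single-channel 32x32 inputs (e.g. MNIST padded from 28x28)
x = torch.randn(4, 1, 32, 32)

logits, feats = teacher(x, return_features=True)
print(logits.shape, feats.shape)    # torch.Size([4, 10]) torch.Size([4, 120])

# the half-width student exposes the same interface with a 60-d feature
s_logits, s_feats = student(x, return_features=True)
print(s_feats.shape)                # torch.Size([4, 60])
```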
-------------------------------------------------------------------------------- /datafree/models/classifiers/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch import Tensor 3 | from torchvision.models.utils import load_state_dict_from_url 4 | from typing import Callable, Any, Optional, List 5 | 6 | 7 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 8 | 9 | 10 | model_urls = { 11 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 12 | } 13 | 14 | 15 | def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int: 16 | """ 17 | This function is taken from the original tf repo. 18 | It ensures that all layers have a channel number that is divisible by 8 19 | It can be seen here: 20 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 21 | :param v: 22 | :param divisor: 23 | :param min_value: 24 | :return: 25 | """ 26 | if min_value is None: 27 | min_value = divisor 28 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 29 | # Make sure that round down does not go down by more than 10%. 30 | if new_v < 0.9 * v: 31 | new_v += divisor 32 | return new_v 33 | 34 | 35 | class ConvBNActivation(nn.Sequential): 36 | def __init__( 37 | self, 38 | in_planes: int, 39 | out_planes: int, 40 | kernel_size: int = 3, 41 | stride: int = 1, 42 | groups: int = 1, 43 | norm_layer: Optional[Callable[..., nn.Module]] = None, 44 | activation_layer: Optional[Callable[..., nn.Module]] = None, 45 | ) -> None: 46 | padding = (kernel_size - 1) // 2 47 | if norm_layer is None: 48 | norm_layer = nn.BatchNorm2d 49 | if activation_layer is None: 50 | activation_layer = nn.ReLU6 51 | super(ConvBNReLU, self).__init__( 52 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 53 | norm_layer(out_planes), 54 | activation_layer(inplace=True) 55 | ) 56 | 57 | 58 | # necessary for backwards compatibility 59 | ConvBNReLU = ConvBNActivation 60 | 61 | 62 | class InvertedResidual(nn.Module): 63 | def __init__( 64 | self, 65 | inp: int, 66 | oup: int, 67 | stride: int, 68 | expand_ratio: int, 69 | norm_layer: Optional[Callable[..., nn.Module]] = None 70 | ) -> None: 71 | super(InvertedResidual, self).__init__() 72 | self.stride = stride 73 | assert stride in [1, 2] 74 | 75 | if norm_layer is None: 76 | norm_layer = nn.BatchNorm2d 77 | 78 | hidden_dim = int(round(inp * expand_ratio)) 79 | self.use_res_connect = self.stride == 1 and inp == oup 80 | 81 | layers: List[nn.Module] = [] 82 | if expand_ratio != 1: 83 | # pw 84 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer)) 85 | layers.extend([ 86 | # dw 87 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer), 88 | # pw-linear 89 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 90 | norm_layer(oup), 91 | ]) 92 | self.conv = nn.Sequential(*layers) 93 | 94 | def forward(self, x: Tensor) -> Tensor: 95 | if self.use_res_connect: 96 | return x + self.conv(x) 97 | else: 98 | return self.conv(x) 99 | 100 | 101 | class MobileNetV2(nn.Module): 102 | def __init__( 103 | self, 104 | num_classes: int = 1000, 105 | width_mult: float = 1.0, 106 | inverted_residual_setting: Optional[List[List[int]]] = None, 107 | round_nearest: int = 8, 108 | block: Optional[Callable[..., nn.Module]] = None, 109 | norm_layer: Optional[Callable[..., nn.Module]] = None 110 | ) -> None: 111 | """ 112 | MobileNet V2 main class 
113 | Args: 114 | num_classes (int): Number of classes 115 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 116 | inverted_residual_setting: Network structure 117 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 118 | Set to 1 to turn off rounding 119 | block: Module specifying inverted residual building block for mobilenet 120 | norm_layer: Module specifying the normalization layer to use 121 | """ 122 | super(MobileNetV2, self).__init__() 123 | 124 | if block is None: 125 | block = InvertedResidual 126 | 127 | if norm_layer is None: 128 | norm_layer = nn.BatchNorm2d 129 | 130 | input_channel = 32 131 | last_channel = 1280 132 | 133 | if inverted_residual_setting is None: 134 | inverted_residual_setting = [ 135 | # t, c, n, s 136 | [1, 16, 1, 1], 137 | [6, 24, 2, 2], 138 | [6, 32, 3, 2], 139 | [6, 64, 4, 2], 140 | [6, 96, 3, 1], 141 | [6, 160, 3, 2], 142 | [6, 320, 1, 1], 143 | ] 144 | 145 | # only check the first element, assuming user knows t,c,n,s are required 146 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 147 | raise ValueError("inverted_residual_setting should be non-empty " 148 | "or a 4-element list, got {}".format(inverted_residual_setting)) 149 | 150 | # building first layer 151 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 152 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 153 | features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)] 154 | # building inverted residual blocks 155 | for t, c, n, s in inverted_residual_setting: 156 | output_channel = _make_divisible(c * width_mult, round_nearest) 157 | for i in range(n): 158 | stride = s if i == 0 else 1 159 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) 160 | input_channel = output_channel 161 | # building last several layers 162 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer)) 163 | # make it nn.Sequential 164 | self.features = nn.Sequential(*features) 165 | 166 | # building classifier 167 | self.classifier = nn.Sequential( 168 | nn.Dropout(0.2), 169 | nn.Linear(self.last_channel, num_classes), 170 | ) 171 | 172 | # weight initialization 173 | for m in self.modules(): 174 | if isinstance(m, nn.Conv2d): 175 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 176 | if m.bias is not None: 177 | nn.init.zeros_(m.bias) 178 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 179 | nn.init.ones_(m.weight) 180 | nn.init.zeros_(m.bias) 181 | elif isinstance(m, nn.Linear): 182 | nn.init.normal_(m.weight, 0, 0.01) 183 | nn.init.zeros_(m.bias) 184 | 185 | def _forward_impl(self, x: Tensor) -> Tensor: 186 | # This exists since TorchScript doesn't support inheritance, so the superclass method 187 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 188 | x = self.features(x) 189 | # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0] 190 | x = nn.functional.adaptive_avg_pool2d(x, (1, 1)).reshape(x.shape[0], -1) 191 | x = self.classifier(x) 192 | return x 193 | 194 | def forward(self, x: Tensor) -> Tensor: 195 | return self._forward_impl(x) 196 | 197 | 198 | def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV2: 199 | """ 200 | Constructs a MobileNetV2 architecture from 201 | 
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 202 | Args: 203 | pretrained (bool): If True, returns a model pre-trained on ImageNet 204 | progress (bool): If True, displays a progress bar of the download to stderr 205 | """ 206 | model = MobileNetV2(**kwargs) 207 | if pretrained: 208 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 209 | progress=progress) 210 | model.load_state_dict(state_dict) 211 | return model -------------------------------------------------------------------------------- /datafree/models/classifiers/resnet.py: -------------------------------------------------------------------------------- 1 | # ResNet for CIFAR (32x32) 2 | # 2019.07.24-Changed output of forward function 3 | # Huawei Technologies Co., Ltd. 4 | # taken from https://github.com/huawei-noah/Data-Efficient-Model-Compression/blob/master/DAFL/resnet.py 5 | # for comparison with DAFL 6 | 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, in_planes, planes, stride=1): 16 | super(BasicBlock, self).__init__() 17 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | 22 | self.shortcut = nn.Sequential() 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 26 | nn.BatchNorm2d(self.expansion*planes) 27 | ) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = self.bn2(self.conv2(out)) 32 | out += self.shortcut(x) 33 | out = F.relu(out) 34 | return out 35 | 36 | 37 | class Bottleneck(nn.Module): 38 | expansion = 4 39 | 40 | def __init__(self, in_planes, planes, stride=1): 41 | super(Bottleneck, self).__init__() 42 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 43 | self.bn1 = nn.BatchNorm2d(planes) 44 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 48 | 49 | self.shortcut = nn.Sequential() 50 | if stride != 1 or in_planes != self.expansion*planes: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 53 | nn.BatchNorm2d(self.expansion*planes) 54 | ) 55 | 56 | def forward(self, x): 57 | out = F.relu(self.bn1(self.conv1(x))) 58 | out = F.relu(self.bn2(self.conv2(out))) 59 | out = self.bn3(self.conv3(out)) 60 | out += self.shortcut(x) 61 | out = F.relu(out) 62 | return out 63 | 64 | 65 | class ResNet(nn.Module): 66 | def __init__(self, block, num_blocks, num_classes=10): 67 | super(ResNet, self).__init__() 68 | self.in_planes = 64 69 | 70 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.bn1 = nn.BatchNorm2d(64) 72 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 73 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 74 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 75 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 76 | self.linear = 
nn.Linear(512*block.expansion, num_classes) 77 | 78 | def _make_layer(self, block, planes, num_blocks, stride): 79 | strides = [stride] + [1]*(num_blocks-1) 80 | layers = [] 81 | for stride in strides: 82 | layers.append(block(self.in_planes, planes, stride)) 83 | self.in_planes = planes * block.expansion 84 | return nn.Sequential(*layers) 85 | 86 | def forward(self, x, return_features=False): 87 | x = self.conv1(x) 88 | x = self.bn1(x) 89 | x = F.relu(x) 90 | x1 = self.layer1(x) 91 | x2 = self.layer2(x1) 92 | x3 = self.layer3(x2) 93 | x4 = self.layer4(x3) 94 | out = F.adaptive_avg_pool2d(x4, (1,1)) 95 | feature = out.view(out.size(0), -1) 96 | out = self.linear(feature) 97 | 98 | if return_features: 99 | return out, feature, [x1, x2, x3, x4] 100 | return out 101 | 102 | def resnet18(num_classes=10): 103 | return ResNet(BasicBlock, [2,2,2,2], num_classes) 104 | 105 | def resnet34(num_classes=10): 106 | return ResNet(BasicBlock, [3,4,6,3], num_classes) 107 | 108 | def resnet50(num_classes=10): 109 | return ResNet(Bottleneck, [3,4,6,3], num_classes) 110 | 111 | def resnet101(num_classes=10): 112 | return ResNet(Bottleneck, [3,4,23,3], num_classes) 113 | 114 | def resnet152(num_classes=10): 115 | return ResNet(Bottleneck, [3,8,36,3], num_classes) -------------------------------------------------------------------------------- /datafree/models/classifiers/resnet_in.py: -------------------------------------------------------------------------------- 1 | # ResNet for ImageNet (224x224) 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision.models.utils import load_state_dict_from_url 6 | 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 10 | 'wide_resnet50_2', 'wide_resnet101_2'] 11 | 12 | 13 | model_urls = { 14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 17 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 18 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 19 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 20 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 21 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 22 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 23 | } 24 | 25 | 26 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 27 | """3x3 convolution with padding""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=dilation, groups=groups, bias=False, dilation=dilation) 30 | 31 | 32 | def conv1x1(in_planes, out_planes, stride=1): 33 | """1x1 convolution""" 34 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 35 | 36 | 37 | class BasicBlock(nn.Module): 38 | expansion = 1 39 | 40 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 41 | base_width=64, dilation=1, norm_layer=None): 42 | super(BasicBlock, self).__init__() 43 | if norm_layer is None: 44 | norm_layer = nn.BatchNorm2d 45 | if groups != 1 or base_width != 64: 46 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 47 | if dilation > 1: 48 | raise NotImplementedError("Dilation > 1 not supported in 
BasicBlock") 49 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 50 | self.conv1 = conv3x3(inplanes, planes, stride) 51 | self.bn1 = norm_layer(planes) 52 | self.relu = nn.ReLU(inplace=True) 53 | self.conv2 = conv3x3(planes, planes) 54 | self.bn2 = norm_layer(planes) 55 | self.downsample = downsample 56 | self.stride = stride 57 | 58 | def forward(self, x): 59 | identity = x 60 | 61 | out = self.conv1(x) 62 | out = self.bn1(out) 63 | out = self.relu(out) 64 | 65 | out = self.conv2(out) 66 | out = self.bn2(out) 67 | 68 | if self.downsample is not None: 69 | identity = self.downsample(x) 70 | 71 | out += identity 72 | out = self.relu(out) 73 | 74 | return out 75 | 76 | 77 | class Bottleneck(nn.Module): 78 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 79 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 80 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 81 | # This variant is also known as ResNet V1.5 and improves accuracy according to 82 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 83 | 84 | expansion = 4 85 | 86 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 87 | base_width=64, dilation=1, norm_layer=None): 88 | super(Bottleneck, self).__init__() 89 | if norm_layer is None: 90 | norm_layer = nn.BatchNorm2d 91 | width = int(planes * (base_width / 64.)) * groups 92 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 93 | self.conv1 = conv1x1(inplanes, width) 94 | self.bn1 = norm_layer(width) 95 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 96 | self.bn2 = norm_layer(width) 97 | self.conv3 = conv1x1(width, planes * self.expansion) 98 | self.bn3 = norm_layer(planes * self.expansion) 99 | self.relu = nn.ReLU(inplace=True) 100 | self.downsample = downsample 101 | self.stride = stride 102 | 103 | def forward(self, x): 104 | identity = x 105 | 106 | out = self.conv1(x) 107 | out = self.bn1(out) 108 | out = self.relu(out) 109 | 110 | out = self.conv2(out) 111 | out = self.bn2(out) 112 | out = self.relu(out) 113 | 114 | out = self.conv3(out) 115 | out = self.bn3(out) 116 | 117 | if self.downsample is not None: 118 | identity = self.downsample(x) 119 | 120 | out += identity 121 | out = self.relu(out) 122 | 123 | return out 124 | 125 | 126 | class ResNet(nn.Module): 127 | 128 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 129 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 130 | norm_layer=None): 131 | super(ResNet, self).__init__() 132 | if norm_layer is None: 133 | norm_layer = nn.BatchNorm2d 134 | self._norm_layer = norm_layer 135 | 136 | self.inplanes = 64 137 | self.dilation = 1 138 | if replace_stride_with_dilation is None: 139 | # each element in the tuple indicates if we should replace 140 | # the 2x2 stride with a dilated convolution instead 141 | replace_stride_with_dilation = [False, False, False] 142 | if len(replace_stride_with_dilation) != 3: 143 | raise ValueError("replace_stride_with_dilation should be None " 144 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 145 | self.groups = groups 146 | self.base_width = width_per_group 147 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 148 | bias=False) 149 | self.bn1 = norm_layer(self.inplanes) 150 | self.relu = nn.ReLU(inplace=True) 151 | 
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 152 | self.layer1 = self._make_layer(block, 64, layers[0]) 153 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 154 | dilate=replace_stride_with_dilation[0]) 155 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 156 | dilate=replace_stride_with_dilation[1]) 157 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 158 | dilate=replace_stride_with_dilation[2]) 159 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 160 | self.fc = nn.Linear(512 * block.expansion, num_classes) 161 | 162 | for m in self.modules(): 163 | if isinstance(m, nn.Conv2d): 164 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 165 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 166 | nn.init.constant_(m.weight, 1) 167 | nn.init.constant_(m.bias, 0) 168 | 169 | # Zero-initialize the last BN in each residual branch, 170 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 171 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 172 | if zero_init_residual: 173 | for m in self.modules(): 174 | if isinstance(m, Bottleneck): 175 | nn.init.constant_(m.bn3.weight, 0) 176 | elif isinstance(m, BasicBlock): 177 | nn.init.constant_(m.bn2.weight, 0) 178 | 179 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 180 | norm_layer = self._norm_layer 181 | downsample = None 182 | previous_dilation = self.dilation 183 | if dilate: 184 | self.dilation *= stride 185 | stride = 1 186 | if stride != 1 or self.inplanes != planes * block.expansion: 187 | downsample = nn.Sequential( 188 | conv1x1(self.inplanes, planes * block.expansion, stride), 189 | norm_layer(planes * block.expansion), 190 | ) 191 | 192 | layers = [] 193 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 194 | self.base_width, previous_dilation, norm_layer)) 195 | self.inplanes = planes * block.expansion 196 | for _ in range(1, blocks): 197 | layers.append(block(self.inplanes, planes, groups=self.groups, 198 | base_width=self.base_width, dilation=self.dilation, 199 | norm_layer=norm_layer)) 200 | 201 | return nn.Sequential(*layers) 202 | 203 | def _forward_impl(self, x, return_features): 204 | # See note [TorchScript super()] 205 | x = self.conv1(x) 206 | x = self.bn1(x) 207 | x = self.relu(x) 208 | x = self.maxpool(x) 209 | 210 | x1 = self.layer1(x) 211 | x2 = self.layer2(x1) 212 | x3 = self.layer3(x2) 213 | x4 = self.layer4(x3) 214 | 215 | x = self.avgpool(x4) 216 | feat = torch.flatten(x, 1) 217 | x = self.fc(feat) 218 | if return_features: 219 | return x, feat, [x1, x2, x3, x4] 220 | return x 221 | 222 | def forward(self, x, return_features=False): 223 | return self._forward_impl(x, return_features=return_features) 224 | 225 | 226 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 227 | model = ResNet(block, layers, **kwargs) 228 | if pretrained: 229 | state_dict = load_state_dict_from_url(model_urls[arch], 230 | progress=progress) 231 | print('load from', model_urls[arch]) 232 | model.load_state_dict(state_dict) 233 | return model 234 | 235 | 236 | def resnet18(pretrained=False, progress=True, **kwargs): 237 | r"""ResNet-18 model from 238 | `"Deep Residual Learning for Image Recognition" `_ 239 | Args: 240 | pretrained (bool): If True, returns a model pre-trained on ImageNet 241 | progress (bool): If True, displays a progress bar of the download to stderr 242 | """ 243 | return 
_resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 244 | **kwargs) 245 | 246 | 247 | def resnet34(pretrained=False, progress=True, **kwargs): 248 | r"""ResNet-34 model from 249 | `"Deep Residual Learning for Image Recognition" `_ 250 | Args: 251 | pretrained (bool): If True, returns a model pre-trained on ImageNet 252 | progress (bool): If True, displays a progress bar of the download to stderr 253 | """ 254 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 255 | **kwargs) 256 | 257 | 258 | def resnet50(pretrained=False, progress=True, **kwargs): 259 | r"""ResNet-50 model from 260 | `"Deep Residual Learning for Image Recognition" `_ 261 | Args: 262 | pretrained (bool): If True, returns a model pre-trained on ImageNet 263 | progress (bool): If True, displays a progress bar of the download to stderr 264 | """ 265 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 266 | **kwargs) 267 | 268 | 269 | def resnet101(pretrained=False, progress=True, **kwargs): 270 | r"""ResNet-101 model from 271 | `"Deep Residual Learning for Image Recognition" `_ 272 | Args: 273 | pretrained (bool): If True, returns a model pre-trained on ImageNet 274 | progress (bool): If True, displays a progress bar of the download to stderr 275 | """ 276 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 277 | **kwargs) 278 | 279 | 280 | def resnet152(pretrained=False, progress=True, **kwargs): 281 | r"""ResNet-152 model from 282 | `"Deep Residual Learning for Image Recognition" `_ 283 | Args: 284 | pretrained (bool): If True, returns a model pre-trained on ImageNet 285 | progress (bool): If True, displays a progress bar of the download to stderr 286 | """ 287 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 288 | **kwargs) 289 | 290 | 291 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 292 | r"""ResNeXt-50 32x4d model from 293 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 294 | Args: 295 | pretrained (bool): If True, returns a model pre-trained on ImageNet 296 | progress (bool): If True, displays a progress bar of the download to stderr 297 | """ 298 | kwargs['groups'] = 32 299 | kwargs['width_per_group'] = 4 300 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 301 | pretrained, progress, **kwargs) 302 | 303 | 304 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 305 | r"""ResNeXt-101 32x8d model from 306 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 307 | Args: 308 | pretrained (bool): If True, returns a model pre-trained on ImageNet 309 | progress (bool): If True, displays a progress bar of the download to stderr 310 | """ 311 | kwargs['groups'] = 32 312 | kwargs['width_per_group'] = 8 313 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 314 | pretrained, progress, **kwargs) 315 | 316 | 317 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 318 | r"""Wide ResNet-50-2 model from 319 | `"Wide Residual Networks" `_ 320 | The model is the same as ResNet except for the bottleneck number of channels 321 | which is twice larger in every block. The number of channels in outer 1x1 322 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 323 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
324 | Args: 325 | pretrained (bool): If True, returns a model pre-trained on ImageNet 326 | progress (bool): If True, displays a progress bar of the download to stderr 327 | """ 328 | kwargs['width_per_group'] = 64 * 2 329 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 330 | pretrained, progress, **kwargs) 331 | 332 | 333 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 334 | r"""Wide ResNet-101-2 model from 335 | `"Wide Residual Networks" `_ 336 | The model is the same as ResNet except for the bottleneck number of channels 337 | which is twice larger in every block. The number of channels in outer 1x1 338 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 339 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 340 | Args: 341 | pretrained (bool): If True, returns a model pre-trained on ImageNet 342 | progress (bool): If True, displays a progress bar of the download to stderr 343 | """ 344 | kwargs['width_per_group'] = 64 * 2 345 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 346 | pretrained, progress, **kwargs) -------------------------------------------------------------------------------- /datafree/models/classifiers/resnet_tiny.py: -------------------------------------------------------------------------------- 1 | # Tiny ResNet for CIFAR (32x32) 2 | 3 | from __future__ import absolute_import 4 | 5 | '''Resnet for cifar dataset. 6 | https://github.com/HobbitLong/RepDistiller/blob/master/models/resnet.py 7 | Ported form 8 | https://github.com/facebook/fb.resnet.torch 9 | and 10 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 11 | (c) YANG, Wei 12 | ''' 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import math 16 | 17 | 18 | __all__ = ['resnet'] 19 | 20 | 21 | def conv3x3(in_planes, out_planes, stride=1): 22 | """3x3 convolution with padding""" 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 24 | padding=1, bias=False) 25 | 26 | 27 | class BasicBlock(nn.Module): 28 | expansion = 1 29 | 30 | def __init__(self, inplanes, planes, stride=1, downsample=None): 31 | super(BasicBlock, self).__init__() 32 | self.conv1 = conv3x3(inplanes, planes, stride) 33 | self.bn1 = nn.BatchNorm2d(planes) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.conv2 = conv3x3(planes, planes) 36 | self.bn2 = nn.BatchNorm2d(planes) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | residual = x 42 | 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(x) 52 | 53 | out += residual 54 | preact = out 55 | out = F.relu(out) 56 | return out 57 | 58 | 59 | class Bottleneck(nn.Module): 60 | expansion = 4 61 | 62 | def __init__(self, inplanes, planes, stride=1, downsample=None): 63 | super(Bottleneck, self).__init__() 64 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 65 | self.bn1 = nn.BatchNorm2d(planes) 66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 67 | padding=1, bias=False) 68 | self.bn2 = nn.BatchNorm2d(planes) 69 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 70 | self.bn3 = nn.BatchNorm2d(planes * 4) 71 | self.relu = nn.ReLU(inplace=True) 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | def forward(self, x): 76 | residual = x 77 | 78 | out = self.conv1(x) 79 | out = 
self.bn1(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv3(out) 87 | out = self.bn3(out) 88 | 89 | if self.downsample is not None: 90 | residual = self.downsample(x) 91 | 92 | out += residual 93 | preact = out 94 | out = F.relu(out) 95 | return out 96 | 97 | 98 | class ResNet(nn.Module): 99 | 100 | def __init__(self, depth, num_filters, block_name='BasicBlock', num_classes=10): 101 | super(ResNet, self).__init__() 102 | if block_name.lower() == 'basicblock': 103 | assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202' 104 | n = (depth - 2) // 6 105 | block = BasicBlock 106 | elif block_name.lower() == 'bottleneck': 107 | assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199' 108 | n = (depth - 2) // 9 109 | block = Bottleneck 110 | else: 111 | raise ValueError('block_name shoule be Basicblock or Bottleneck') 112 | 113 | self.inplanes = num_filters[0] 114 | self.conv1 = nn.Conv2d(3, num_filters[0], kernel_size=3, padding=1, 115 | bias=False) 116 | self.bn1 = nn.BatchNorm2d(num_filters[0]) 117 | self.relu = nn.ReLU(inplace=True) 118 | self.layer1 = self._make_layer(block, num_filters[1], n) 119 | self.layer2 = self._make_layer(block, num_filters[2], n, stride=2) 120 | self.layer3 = self._make_layer(block, num_filters[3], n, stride=2) 121 | self.avgpool = nn.AvgPool2d(8) 122 | self.fc = nn.Linear(num_filters[3] * block.expansion, num_classes) 123 | 124 | for m in self.modules(): 125 | if isinstance(m, nn.Conv2d): 126 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 127 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 128 | nn.init.constant_(m.weight, 1) 129 | nn.init.constant_(m.bias, 0) 130 | 131 | def _make_layer(self, block, planes, blocks, stride=1): 132 | downsample = None 133 | if stride != 1 or self.inplanes != planes * block.expansion: 134 | downsample = nn.Sequential( 135 | nn.Conv2d(self.inplanes, planes * block.expansion, 136 | kernel_size=1, stride=stride, bias=False), 137 | nn.BatchNorm2d(planes * block.expansion), 138 | ) 139 | 140 | layers = list([]) 141 | layers.append(block(self.inplanes, planes, stride, downsample)) 142 | self.inplanes = planes * block.expansion 143 | for i in range(1, blocks): 144 | layers.append(block(self.inplanes, planes)) 145 | 146 | return nn.Sequential(*layers) 147 | 148 | def forward(self, x, return_features=False): 149 | x = self.conv1(x) 150 | x = self.bn1(x) 151 | x = self.relu(x) # 32x32 152 | x1 = self.layer1(x) 153 | x2 = self.layer2(x1) 154 | x3 = self.layer3(x2) 155 | x = self.avgpool(x3) 156 | features = x.view(x.size(0), -1) 157 | x = self.fc(features) 158 | 159 | if return_features: 160 | return x, features, [x1, x2, x3] 161 | return x 162 | 163 | 164 | def resnet8(num_classes): 165 | return ResNet(8, [16, 16, 32, 64], 'basicblock', num_classes=num_classes) 166 | 167 | 168 | def resnet14(num_classes): 169 | return ResNet(14, [16, 16, 32, 64], 'basicblock', num_classes=num_classes) 170 | 171 | 172 | def resnet20(num_classes): 173 | return ResNet(20, [16, 16, 32, 64], 'basicblock', num_classes=num_classes) 174 | 175 | 176 | def resnet32(num_classes): 177 | return ResNet(32, [16, 16, 32, 64], 'basicblock', num_classes=num_classes) 178 | 179 | 180 | def resnet44(num_classes): 181 | return ResNet(44, [16, 16, 32, 64], 'basicblock', num_classes=num_classes) 182 | 183 | 184 | def resnet56(num_classes): 185 | return ResNet(56, [16, 
16, 32, 64], 'basicblock', num_classes=num_classes) 186 | 187 | 188 | def resnet110(num_classes): 189 | return ResNet(110, [16, 16, 32, 64], 'basicblock', num_classes=num_classes) 190 | 191 | 192 | def resnet8x4(num_classes): 193 | return ResNet(8, [32, 64, 128, 256], 'basicblock', num_classes=num_classes) 194 | 195 | 196 | def resnet32x4(num_classes): 197 | return ResNet(32, [32, 64, 128, 256], 'basicblock', num_classes=num_classes) -------------------------------------------------------------------------------- /datafree/models/classifiers/shufflenetv2.py: -------------------------------------------------------------------------------- 1 | '''ShuffleNetV2 in PyTorch. 2 | https://github.com/HobbitLong/RepDistiller/blob/34557d2728/models/ShuffleNetv2.py 3 | See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ShuffleBlock(nn.Module): 11 | def __init__(self, groups=2): 12 | super(ShuffleBlock, self).__init__() 13 | self.groups = groups 14 | 15 | def forward(self, x): 16 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 17 | N, C, H, W = x.size() 18 | g = self.groups 19 | return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W) 20 | 21 | 22 | class SplitBlock(nn.Module): 23 | def __init__(self, ratio): 24 | super(SplitBlock, self).__init__() 25 | self.ratio = ratio 26 | 27 | def forward(self, x): 28 | c = int(x.size(1) * self.ratio) 29 | return x[:, :c, :, :], x[:, c:, :, :] 30 | 31 | 32 | class BasicBlock(nn.Module): 33 | def __init__(self, in_channels, split_ratio=0.5, is_last=False): 34 | super(BasicBlock, self).__init__() 35 | self.is_last = is_last 36 | self.split = SplitBlock(split_ratio) 37 | in_channels = int(in_channels * split_ratio) 38 | self.conv1 = nn.Conv2d(in_channels, in_channels, 39 | kernel_size=1, bias=False) 40 | self.bn1 = nn.BatchNorm2d(in_channels) 41 | self.conv2 = nn.Conv2d(in_channels, in_channels, 42 | kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False) 43 | self.bn2 = nn.BatchNorm2d(in_channels) 44 | self.conv3 = nn.Conv2d(in_channels, in_channels, 45 | kernel_size=1, bias=False) 46 | self.bn3 = nn.BatchNorm2d(in_channels) 47 | self.shuffle = ShuffleBlock() 48 | 49 | def forward(self, x): 50 | x1, x2 = self.split(x) 51 | out = F.relu(self.bn1(self.conv1(x2))) 52 | out = self.bn2(self.conv2(out)) 53 | preact = self.bn3(self.conv3(out)) 54 | out = F.relu(preact) 55 | # out = F.relu(self.bn3(self.conv3(out))) 56 | preact = torch.cat([x1, preact], 1) 57 | out = torch.cat([x1, out], 1) 58 | out = self.shuffle(out) 59 | if self.is_last: 60 | return out, preact 61 | else: 62 | return out 63 | 64 | 65 | class DownBlock(nn.Module): 66 | def __init__(self, in_channels, out_channels): 67 | super(DownBlock, self).__init__() 68 | mid_channels = out_channels // 2 69 | # left 70 | self.conv1 = nn.Conv2d(in_channels, in_channels, 71 | kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False) 72 | self.bn1 = nn.BatchNorm2d(in_channels) 73 | self.conv2 = nn.Conv2d(in_channels, mid_channels, 74 | kernel_size=1, bias=False) 75 | self.bn2 = nn.BatchNorm2d(mid_channels) 76 | # right 77 | self.conv3 = nn.Conv2d(in_channels, mid_channels, 78 | kernel_size=1, bias=False) 79 | self.bn3 = nn.BatchNorm2d(mid_channels) 80 | self.conv4 = nn.Conv2d(mid_channels, mid_channels, 81 | kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False) 82 | self.bn4 
= nn.BatchNorm2d(mid_channels) 83 | self.conv5 = nn.Conv2d(mid_channels, mid_channels, 84 | kernel_size=1, bias=False) 85 | self.bn5 = nn.BatchNorm2d(mid_channels) 86 | 87 | self.shuffle = ShuffleBlock() 88 | 89 | def forward(self, x): 90 | # left 91 | out1 = self.bn1(self.conv1(x)) 92 | out1 = F.relu(self.bn2(self.conv2(out1))) 93 | # right 94 | out2 = F.relu(self.bn3(self.conv3(x))) 95 | out2 = self.bn4(self.conv4(out2)) 96 | out2 = F.relu(self.bn5(self.conv5(out2))) 97 | # concat 98 | out = torch.cat([out1, out2], 1) 99 | out = self.shuffle(out) 100 | return out 101 | 102 | 103 | class ShuffleNetV2(nn.Module): 104 | def __init__(self, net_size, num_classes=10): 105 | super(ShuffleNetV2, self).__init__() 106 | out_channels = configs[net_size]['out_channels'] 107 | num_blocks = configs[net_size]['num_blocks'] 108 | 109 | # self.conv1 = nn.Conv2d(3, 24, kernel_size=3, 110 | # stride=1, padding=1, bias=False) 111 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 112 | self.bn1 = nn.BatchNorm2d(24) 113 | self.in_channels = 24 114 | self.layer1 = self._make_layer(out_channels[0], num_blocks[0]) 115 | self.layer2 = self._make_layer(out_channels[1], num_blocks[1]) 116 | self.layer3 = self._make_layer(out_channels[2], num_blocks[2]) 117 | self.conv2 = nn.Conv2d(out_channels[2], out_channels[3], 118 | kernel_size=1, stride=1, padding=0, bias=False) 119 | self.bn2 = nn.BatchNorm2d(out_channels[3]) 120 | self.linear = nn.Linear(out_channels[3], num_classes) 121 | 122 | def _make_layer(self, out_channels, num_blocks): 123 | layers = [DownBlock(self.in_channels, out_channels)] 124 | for i in range(num_blocks): 125 | layers.append(BasicBlock(out_channels, is_last=(i == num_blocks - 1))) 126 | self.in_channels = out_channels 127 | return nn.Sequential(*layers) 128 | 129 | def get_feat_modules(self): 130 | feat_m = nn.ModuleList([]) 131 | feat_m.append(self.conv1) 132 | feat_m.append(self.bn1) 133 | feat_m.append(self.layer1) 134 | feat_m.append(self.layer2) 135 | feat_m.append(self.layer3) 136 | return feat_m 137 | 138 | def get_bn_before_relu(self): 139 | raise NotImplementedError('ShuffleNetV2 currently is not supported for "Overhaul" teacher') 140 | 141 | def forward(self, x, return_features=False): 142 | out = F.relu(self.bn1(self.conv1(x))) 143 | out, f1_pre = self.layer1(out) 144 | out, f2_pre = self.layer2(out) 145 | out, f3_pre = self.layer3(out) 146 | out = F.relu(self.bn2(self.conv2(out))) 147 | out = F.avg_pool2d(out, 4) 148 | features = out.view(out.size(0), -1) 149 | out = self.linear(features) 150 | if return_features: 151 | return out, features 152 | else: 153 | return out 154 | 155 | configs = { 156 | 0.2: { 157 | 'out_channels': (40, 80, 160, 512), 158 | 'num_blocks': (3, 3, 3) 159 | }, 160 | 161 | 0.3: { 162 | 'out_channels': (40, 80, 160, 512), 163 | 'num_blocks': (3, 7, 3) 164 | }, 165 | 166 | 0.5: { 167 | 'out_channels': (48, 96, 192, 1024), 168 | 'num_blocks': (3, 7, 3) 169 | }, 170 | 171 | 1: { 172 | 'out_channels': (116, 232, 464, 1024), 173 | 'num_blocks': (3, 7, 3) 174 | }, 175 | 1.5: { 176 | 'out_channels': (176, 352, 704, 1024), 177 | 'num_blocks': (3, 7, 3) 178 | }, 179 | 2: { 180 | 'out_channels': (224, 488, 976, 2048), 181 | 'num_blocks': (3, 7, 3) 182 | } 183 | } 184 | 185 | 186 | def shuffle_v2(num_classes): 187 | model = ShuffleNetV2(net_size=1, num_classes=num_classes) 188 | return model -------------------------------------------------------------------------------- /datafree/models/classifiers/vgg.py: 
-------------------------------------------------------------------------------- 1 | """https://github.com/HobbitLong/RepDistiller/blob/master/models/vgg.py 2 | """ 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | 7 | 8 | __all__ = [ 9 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 10 | 'vgg19_bn', 'vgg19', 11 | ] 12 | 13 | 14 | model_urls = { 15 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 16 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 17 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 18 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 19 | } 20 | 21 | 22 | class VGG(nn.Module): 23 | 24 | def __init__(self, cfg, batch_norm=False, num_classes=1000): 25 | super(VGG, self).__init__() 26 | self.block0 = self._make_layers(cfg[0], batch_norm, 3) 27 | self.block1 = self._make_layers(cfg[1], batch_norm, cfg[0][-1]) 28 | self.block2 = self._make_layers(cfg[2], batch_norm, cfg[1][-1]) 29 | self.block3 = self._make_layers(cfg[3], batch_norm, cfg[2][-1]) 30 | self.block4 = self._make_layers(cfg[4], batch_norm, cfg[3][-1]) 31 | 32 | self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2) 33 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 34 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 35 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 36 | self.pool4 = nn.AdaptiveAvgPool2d((1, 1)) 37 | # self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 38 | 39 | self.classifier = nn.Linear(512, num_classes) 40 | self._initialize_weights() 41 | 42 | def get_feat_modules(self): 43 | feat_m = nn.ModuleList([]) 44 | feat_m.append(self.block0) 45 | feat_m.append(self.pool0) 46 | feat_m.append(self.block1) 47 | feat_m.append(self.pool1) 48 | feat_m.append(self.block2) 49 | feat_m.append(self.pool2) 50 | feat_m.append(self.block3) 51 | feat_m.append(self.pool3) 52 | feat_m.append(self.block4) 53 | feat_m.append(self.pool4) 54 | return feat_m 55 | 56 | def get_bn_before_relu(self): 57 | bn1 = self.block1[-1] 58 | bn2 = self.block2[-1] 59 | bn3 = self.block3[-1] 60 | bn4 = self.block4[-1] 61 | return [bn1, bn2, bn3, bn4] 62 | 63 | def forward(self, x, return_features=False): 64 | h = x.shape[2] 65 | x = F.relu(self.block0(x)) 66 | x = self.pool0(x) 67 | x = self.block1(x) 68 | x = F.relu(x) 69 | x = self.pool1(x) 70 | x = self.block2(x) 71 | x = F.relu(x) 72 | x = self.pool2(x) 73 | x = self.block3(x) 74 | x = F.relu(x) 75 | if h == 64: 76 | x = self.pool3(x) 77 | x = self.block4(x) 78 | x = F.relu(x) 79 | x = self.pool4(x) 80 | features = x.view(x.size(0), -1) 81 | x = self.classifier(features) 82 | if return_features: 83 | return x, features 84 | else: 85 | return x 86 | 87 | @staticmethod 88 | def _make_layers(cfg, batch_norm=False, in_channels=3): 89 | layers = [] 90 | for v in cfg: 91 | if v == 'M': 92 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 93 | else: 94 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 95 | if batch_norm: 96 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 97 | else: 98 | layers += [conv2d, nn.ReLU(inplace=True)] 99 | in_channels = v 100 | layers = layers[:-1] 101 | return nn.Sequential(*layers) 102 | 103 | def _initialize_weights(self): 104 | for m in self.modules(): 105 | if isinstance(m, nn.Conv2d): 106 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 107 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 108 | if m.bias is not None: 109 | m.bias.data.zero_() 110 | elif isinstance(m, nn.BatchNorm2d): 111 | m.weight.data.fill_(1) 112 | m.bias.data.zero_() 113 | elif isinstance(m, nn.Linear): 114 | n = m.weight.size(1) 115 | m.weight.data.normal_(0, 0.01) 116 | m.bias.data.zero_() 117 | 118 | 119 | cfg = { 120 | 'A': [[64], [128], [256, 256], [512, 512], [512, 512]], 121 | 'B': [[64, 64], [128, 128], [256, 256], [512, 512], [512, 512]], 122 | 'D': [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]], 123 | 'E': [[64, 64], [128, 128], [256, 256, 256, 256], [512, 512, 512, 512], [512, 512, 512, 512]], 124 | 'S': [[64], [128], [256], [512], [512]], 125 | } 126 | 127 | 128 | def vgg8(**kwargs): 129 | """VGG 8-layer model (configuration "S") 130 | Args: 131 | pretrained (bool): If True, returns a model pre-trained on ImageNet 132 | """ 133 | model = VGG(cfg['S'], **kwargs) 134 | return model 135 | 136 | 137 | def vgg8_bn(**kwargs): 138 | """VGG 8-layer model (configuration "S") 139 | Args: 140 | pretrained (bool): If True, returns a model pre-trained on ImageNet 141 | """ 142 | model = VGG(cfg['S'], batch_norm=True, **kwargs) 143 | return model 144 | 145 | 146 | def vgg11(**kwargs): 147 | """VGG 11-layer model (configuration "A") 148 | Args: 149 | pretrained (bool): If True, returns a model pre-trained on ImageNet 150 | """ 151 | model = VGG(cfg['A'], **kwargs) 152 | return model 153 | 154 | 155 | def vgg11_bn(**kwargs): 156 | """VGG 11-layer model (configuration "A") with batch normalization""" 157 | model = VGG(cfg['A'], batch_norm=True, **kwargs) 158 | return model 159 | 160 | 161 | def vgg13(**kwargs): 162 | """VGG 13-layer model (configuration "B") 163 | Args: 164 | pretrained (bool): If True, returns a model pre-trained on ImageNet 165 | """ 166 | model = VGG(cfg['B'], **kwargs) 167 | return model 168 | 169 | 170 | def vgg13_bn(**kwargs): 171 | """VGG 13-layer model (configuration "B") with batch normalization""" 172 | model = VGG(cfg['B'], batch_norm=True, **kwargs) 173 | return model 174 | 175 | 176 | def vgg16(**kwargs): 177 | """VGG 16-layer model (configuration "D") 178 | Args: 179 | pretrained (bool): If True, returns a model pre-trained on ImageNet 180 | """ 181 | model = VGG(cfg['D'], **kwargs) 182 | return model 183 | 184 | 185 | def vgg16_bn(**kwargs): 186 | """VGG 16-layer model (configuration "D") with batch normalization""" 187 | model = VGG(cfg['D'], batch_norm=True, **kwargs) 188 | return model 189 | 190 | 191 | def vgg19(**kwargs): 192 | """VGG 19-layer model (configuration "E") 193 | Args: 194 | pretrained (bool): If True, returns a model pre-trained on ImageNet 195 | """ 196 | model = VGG(cfg['E'], **kwargs) 197 | return model 198 | 199 | 200 | def vgg19_bn(**kwargs): 201 | """VGG 19-layer model (configuration 'E') with batch normalization""" 202 | model = VGG(cfg['E'], batch_norm=True, **kwargs) 203 | return model 204 | 205 | 206 | if __name__ == '__main__': 207 | import torch 208 | 209 | x = torch.randn(2, 3, 32, 32) 210 | net = vgg19_bn(num_classes=100) 211 | feats, logit = net(x, is_feat=True, preact=True) 212 | 213 | for f in feats: 214 | print(f.shape, f.min().item()) 215 | print(logit.shape) 216 | 217 | for m in net.get_bn_before_relu(): 218 | if isinstance(m, nn.BatchNorm2d): 219 | print('pass') 220 | else: 221 | print('warning') -------------------------------------------------------------------------------- /datafree/models/classifiers/wresnet.py: -------------------------------------------------------------------------------- 1 | 
'''https://github.com/polo5/ZeroShotKnowledgeTransfer/blob/master/models/wresnet.py 2 | ''' 3 | 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | """ 10 | Original Author: Wei Yang 11 | """ 12 | 13 | __all__ = ['wrn'] 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | def __init__(self, in_planes, out_planes, stride, dropout_rate=0.0): 18 | super(BasicBlock, self).__init__() 19 | self.bn1 = nn.BatchNorm2d(in_planes) 20 | self.relu1 = nn.ReLU(inplace=True) 21 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=1, bias=False) 23 | self.bn2 = nn.BatchNorm2d(out_planes) 24 | self.relu2 = nn.ReLU(inplace=True) 25 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 26 | padding=1, bias=False) 27 | self.dropout = nn.Dropout( dropout_rate ) 28 | self.equalInOut = (in_planes == out_planes) 29 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 30 | padding=0, bias=False) or None 31 | 32 | def forward(self, x): 33 | if not self.equalInOut: 34 | x = self.relu1(self.bn1(x)) 35 | else: 36 | out = self.relu1(self.bn1(x)) 37 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 38 | out = self.dropout(out) 39 | out = self.conv2(out) 40 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 41 | 42 | 43 | class NetworkBlock(nn.Module): 44 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropout_rate=0.0): 45 | super(NetworkBlock, self).__init__() 46 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropout_rate) 47 | 48 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropout_rate): 49 | layers = [] 50 | for i in range(nb_layers): 51 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropout_rate)) 52 | return nn.Sequential(*layers) 53 | 54 | def forward(self, x): 55 | return self.layer(x) 56 | 57 | 58 | class WideResNet(nn.Module): 59 | def __init__(self, depth, num_classes, widen_factor=1, dropout_rate=0.0): 60 | super(WideResNet, self).__init__() 61 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 62 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4' 63 | n = (depth - 4) // 6 64 | block = BasicBlock 65 | # 1st conv before any network block 66 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 67 | padding=1, bias=False) 68 | # 1st block 69 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropout_rate) 70 | # 2nd block 71 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropout_rate) 72 | # 3rd block 73 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropout_rate) 74 | # global average pooling and classifier 75 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 76 | self.relu = nn.ReLU(inplace=True) 77 | self.fc = nn.Linear(nChannels[3], num_classes) 78 | self.nChannels = nChannels[3] 79 | 80 | for m in self.modules(): 81 | if isinstance(m, nn.Conv2d): 82 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 83 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 84 | elif isinstance(m, nn.BatchNorm2d): 85 | m.weight.data.fill_(1) 86 | m.bias.data.zero_() 87 | elif isinstance(m, nn.Linear): 88 | m.bias.data.zero_() 89 | 90 | def forward(self, x, return_features=False): 91 | out = self.conv1(x) 92 | x1 = self.block1(out) 93 | x2 = self.block2(x1) 94 | x3 = self.block3(x2) 95 | out = self.relu(self.bn1(x3)) 96 | out = F.adaptive_avg_pool2d(out, (1,1)) 97 | features = out.view(-1, self.nChannels) 98 | out = self.fc(features) 99 | 100 | if return_features: 101 | return out, features, [x1, x2, x3] 102 | else: 103 | return out 104 | 105 | def wrn_16_1(num_classes, dropout_rate=0): 106 | return WideResNet(depth=16, num_classes=num_classes, widen_factor=1, dropout_rate=dropout_rate) 107 | 108 | def wrn_16_2(num_classes, dropout_rate=0): 109 | return WideResNet(depth=16, num_classes=num_classes, widen_factor=2, dropout_rate=dropout_rate) 110 | 111 | def wrn_40_1(num_classes, dropout_rate=0): 112 | return WideResNet(depth=40, num_classes=num_classes, widen_factor=1, dropout_rate=dropout_rate) 113 | 114 | def wrn_40_2(num_classes, dropout_rate=0): 115 | return WideResNet(depth=40, num_classes=num_classes, widen_factor=2, dropout_rate=dropout_rate) -------------------------------------------------------------------------------- /datafree/models/deeplab/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling import * 2 | from ._deeplab import convert_to_separable_conv -------------------------------------------------------------------------------- /datafree/models/deeplab/_deeplab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from .utils import _SimpleSegmentationModel 6 | 7 | 8 | __all__ = ["DeepLabV3"] 9 | 10 | 11 | class DeepLabV3(_SimpleSegmentationModel): 12 | """ 13 | Implements DeepLabV3 model from 14 | `"Rethinking Atrous Convolution for Semantic Image Segmentation" 15 | `_. 16 | 17 | Arguments: 18 | backbone (nn.Module): the network used to compute the features for the model. 19 | The backbone should return an OrderedDict[Tensor], with the key being 20 | "out" for the last feature map used, and "aux" if an auxiliary classifier 21 | is used. 22 | classifier (nn.Module): module that takes the "out" element returned from 23 | the backbone and returns a dense prediction. 
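
To make the backbone/classifier contract above concrete, here is a minimal self-contained sketch; `ToyBackbone` and `ToyHead` are illustrative stand-ins rather than classes from this repository, and the final bilinear upsampling mirrors what `_SimpleSegmentationModel` in `deeplab/utils.py` does:

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyBackbone(nn.Module):
    # Returns a dict with an "out" feature map, as the docstring requires.
    def forward(self, x):
        return {"out": F.avg_pool2d(x, 8)}  # e.g. stride-8 features

class ToyHead(nn.Module):
    # Takes the "out" element and produces per-class logits.
    def __init__(self, in_ch=3, num_classes=21):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, num_classes, 1)
    def forward(self, feats):
        return self.conv(feats["out"])

x = torch.randn(1, 3, 64, 64)
logits = ToyHead()(ToyBackbone()(x))
# The segmentation wrapper then upsamples logits back to the input resolution:
logits = F.interpolate(logits, size=x.shape[-2:], mode="bilinear", align_corners=False)
print(logits.shape)  # torch.Size([1, 21, 64, 64])
```
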
24 | aux_classifier (nn.Module, optional): auxiliary classifier used during training 25 | """ 26 | pass 27 | 28 | class DeepLabHeadV3Plus(nn.Module): 29 | def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]): 30 | super(DeepLabHeadV3Plus, self).__init__() 31 | self.project = nn.Sequential( 32 | nn.Conv2d(low_level_channels, 48, 1, bias=False), 33 | nn.BatchNorm2d(48), 34 | nn.ReLU(inplace=True), 35 | ) 36 | 37 | self.aspp = ASPP(in_channels, aspp_dilate) 38 | 39 | self.classifier = nn.Sequential( 40 | nn.Conv2d(304, 256, 3, padding=1, bias=False), 41 | nn.BatchNorm2d(256), 42 | nn.ReLU(inplace=True), 43 | nn.Conv2d(256, num_classes, 1) 44 | ) 45 | self._init_weight() 46 | 47 | def forward(self, feature): 48 | low_level_feature = self.project( feature['low_level'] ) 49 | output_feature = self.aspp(feature['out']) 50 | output_feature = F.interpolate(output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False) 51 | return self.classifier( torch.cat( [ low_level_feature, output_feature ], dim=1 ) ) 52 | 53 | def _init_weight(self): 54 | for m in self.modules(): 55 | if isinstance(m, nn.Conv2d): 56 | nn.init.kaiming_normal_(m.weight) 57 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 58 | nn.init.constant_(m.weight, 1) 59 | nn.init.constant_(m.bias, 0) 60 | 61 | class DeepLabHead(nn.Module): 62 | def __init__(self, in_channels, num_classes, aspp_dilate=[12, 24, 36]): 63 | super(DeepLabHead, self).__init__() 64 | 65 | self.classifier = nn.Sequential( 66 | ASPP(in_channels, aspp_dilate), 67 | nn.Conv2d(256, 256, 3, padding=1, bias=False), 68 | nn.BatchNorm2d(256), 69 | nn.ReLU(inplace=True), 70 | nn.Conv2d(256, num_classes, 1) 71 | ) 72 | self._init_weight() 73 | 74 | def forward(self, feature): 75 | return self.classifier( feature['out'] ) 76 | 77 | def _init_weight(self): 78 | for m in self.modules(): 79 | if isinstance(m, nn.Conv2d): 80 | nn.init.kaiming_normal_(m.weight) 81 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 82 | nn.init.constant_(m.weight, 1) 83 | nn.init.constant_(m.bias, 0) 84 | 85 | class AtrousSeparableConvolution(nn.Module): 86 | """ Atrous Separable Convolution 87 | """ 88 | def __init__(self, in_channels, out_channels, kernel_size, 89 | stride=1, padding=0, dilation=1, bias=True): 90 | super(AtrousSeparableConvolution, self).__init__() 91 | self.body = nn.Sequential( 92 | # Separable Conv 93 | nn.Conv2d( in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, groups=in_channels ), 94 | # PointWise Conv 95 | nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias), 96 | ) 97 | 98 | self._init_weight() 99 | 100 | def forward(self, x): 101 | return self.body(x) 102 | 103 | def _init_weight(self): 104 | for m in self.modules(): 105 | if isinstance(m, nn.Conv2d): 106 | nn.init.kaiming_normal_(m.weight) 107 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 108 | nn.init.constant_(m.weight, 1) 109 | nn.init.constant_(m.bias, 0) 110 | 111 | class ASPPConv(nn.Sequential): 112 | def __init__(self, in_channels, out_channels, dilation): 113 | modules = [ 114 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), 115 | nn.BatchNorm2d(out_channels), 116 | nn.ReLU(inplace=True) 117 | ] 118 | super(ASPPConv, self).__init__(*modules) 119 | 120 | class ASPPPooling(nn.Sequential): 121 | def __init__(self, in_channels, out_channels): 122 | super(ASPPPooling, self).__init__( 123 | 
nn.AdaptiveAvgPool2d(1), 124 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 125 | nn.BatchNorm2d(out_channels), 126 | nn.ReLU(inplace=True)) 127 | 128 | def forward(self, x): 129 | size = x.shape[-2:] 130 | x = super(ASPPPooling, self).forward(x) 131 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False) 132 | 133 | class ASPP(nn.Module): 134 | def __init__(self, in_channels, atrous_rates): 135 | super(ASPP, self).__init__() 136 | out_channels = 256 137 | modules = [] 138 | modules.append(nn.Sequential( 139 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 140 | nn.BatchNorm2d(out_channels), 141 | nn.ReLU(inplace=True))) 142 | 143 | rate1, rate2, rate3 = tuple(atrous_rates) 144 | modules.append(ASPPConv(in_channels, out_channels, rate1)) 145 | modules.append(ASPPConv(in_channels, out_channels, rate2)) 146 | modules.append(ASPPConv(in_channels, out_channels, rate3)) 147 | modules.append(ASPPPooling(in_channels, out_channels)) 148 | 149 | self.convs = nn.ModuleList(modules) 150 | 151 | self.project = nn.Sequential( 152 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 153 | nn.BatchNorm2d(out_channels), 154 | nn.ReLU(inplace=True), 155 | nn.Dropout(0.1),) 156 | 157 | def forward(self, x): 158 | res = [] 159 | for conv in self.convs: 160 | res.append(conv(x)) 161 | res = torch.cat(res, dim=1) 162 | return self.project(res) 163 | 164 | 165 | 166 | def convert_to_separable_conv(module): 167 | new_module = module 168 | if isinstance(module, nn.Conv2d) and module.kernel_size[0]>1: 169 | new_module = AtrousSeparableConvolution(module.in_channels, 170 | module.out_channels, 171 | module.kernel_size, 172 | module.stride, 173 | module.padding, 174 | module.dilation, 175 | module.bias) 176 | for name, child in module.named_children(): 177 | new_module.add_module(name, convert_to_separable_conv(child)) 178 | return new_module -------------------------------------------------------------------------------- /datafree/models/deeplab/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from . import resnet 2 | from . import mobilenetv2 3 | -------------------------------------------------------------------------------- /datafree/models/deeplab/backbone/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torchvision.models.utils import load_state_dict_from_url 3 | import torch.nn.functional as F 4 | 5 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 6 | 7 | 8 | model_urls = { 9 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 10 | } 11 | 12 | 13 | def _make_divisible(v, divisor, min_value=None): 14 | """ 15 | This function is taken from the original tf repo. 16 | It ensures that all layers have a channel number that is divisible by 8 17 | It can be seen here: 18 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 19 | :param v: 20 | :param divisor: 21 | :param min_value: 22 | :return: 23 | """ 24 | if min_value is None: 25 | min_value = divisor 26 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 27 | # Make sure that round down does not go down by more than 10%. 
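
For reference, a small standalone worked example of this rounding rule (the function is restated here only so the numbers can be checked in isolation):

```
def make_divisible(v, divisor=8, min_value=None):
    # Round to the nearest multiple of `divisor`...
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # ...but bump up one step if rounding removed more than 10% of the value.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

print(make_divisible(32 * 1.4))  # 44.8 -> 48 (48 >= 0.9 * 44.8, so no bump needed)
print(make_divisible(11))        # nearest multiple is 8, but 8 < 0.9 * 11, so bump to 16
```
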
28 | if new_v < 0.9 * v: 29 | new_v += divisor 30 | return new_v 31 | 32 | 33 | class ConvBNReLU(nn.Sequential): 34 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, dilation=1, groups=1): 35 | #padding = (kernel_size - 1) // 2 36 | super(ConvBNReLU, self).__init__( 37 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, 0, dilation=dilation, groups=groups, bias=False), 38 | nn.BatchNorm2d(out_planes), 39 | nn.ReLU6(inplace=True) 40 | ) 41 | 42 | def fixed_padding(kernel_size, dilation): 43 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 44 | pad_total = kernel_size_effective - 1 45 | pad_beg = pad_total // 2 46 | pad_end = pad_total - pad_beg 47 | return (pad_beg, pad_end, pad_beg, pad_end) 48 | 49 | class InvertedResidual(nn.Module): 50 | def __init__(self, inp, oup, stride, dilation, expand_ratio): 51 | super(InvertedResidual, self).__init__() 52 | self.stride = stride 53 | assert stride in [1, 2] 54 | 55 | hidden_dim = int(round(inp * expand_ratio)) 56 | self.use_res_connect = self.stride == 1 and inp == oup 57 | 58 | layers = [] 59 | if expand_ratio != 1: 60 | # pw 61 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 62 | 63 | layers.extend([ 64 | # dw 65 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, dilation=dilation, groups=hidden_dim), 66 | # pw-linear 67 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 68 | nn.BatchNorm2d(oup), 69 | ]) 70 | self.conv = nn.Sequential(*layers) 71 | 72 | self.input_padding = fixed_padding( 3, dilation ) 73 | 74 | def forward(self, x): 75 | x_pad = F.pad(x, self.input_padding) 76 | if self.use_res_connect: 77 | return x + self.conv(x_pad) 78 | else: 79 | return self.conv(x_pad) 80 | 81 | class MobileNetV2(nn.Module): 82 | def __init__(self, num_classes=1000, output_stride=8, width_mult=1.0, inverted_residual_setting=None, round_nearest=8): 83 | """ 84 | MobileNet V2 main class 85 | 86 | Args: 87 | num_classes (int): Number of classes 88 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 89 | inverted_residual_setting: Network structure 90 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 91 | Set to 1 to turn off rounding 92 | """ 93 | super(MobileNetV2, self).__init__() 94 | block = InvertedResidual 95 | input_channel = 32 96 | last_channel = 1280 97 | self.output_stride = output_stride 98 | current_stride = 1 99 | if inverted_residual_setting is None: 100 | inverted_residual_setting = [ 101 | # t, c, n, s 102 | [1, 16, 1, 1], 103 | [6, 24, 2, 2], 104 | [6, 32, 3, 2], 105 | [6, 64, 4, 2], 106 | [6, 96, 3, 1], 107 | [6, 160, 3, 2], 108 | [6, 320, 1, 1], 109 | ] 110 | 111 | # only check the first element, assuming user knows t,c,n,s are required 112 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 113 | raise ValueError("inverted_residual_setting should be non-empty " 114 | "or a 4-element list, got {}".format(inverted_residual_setting)) 115 | 116 | # building first layer 117 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 118 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 119 | features = [ConvBNReLU(3, input_channel, stride=2)] 120 | current_stride *= 2 121 | dilation=1 122 | previous_dilation = 1 123 | 124 | # building inverted residual blocks 125 | for t, c, n, s in inverted_residual_setting: 126 | output_channel = _make_divisible(c * width_mult, round_nearest) 127 | previous_dilation 
= dilation 128 | if current_stride == output_stride: 129 | stride = 1 130 | dilation *= s 131 | else: 132 | stride = s 133 | current_stride *= s 134 | output_channel = int(c * width_mult) 135 | 136 | for i in range(n): 137 | if i==0: 138 | features.append(block(input_channel, output_channel, stride, previous_dilation, expand_ratio=t)) 139 | else: 140 | features.append(block(input_channel, output_channel, 1, dilation, expand_ratio=t)) 141 | input_channel = output_channel 142 | # building last several layers 143 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 144 | # make it nn.Sequential 145 | self.features = nn.Sequential(*features) 146 | 147 | # building classifier 148 | self.classifier = nn.Sequential( 149 | nn.Dropout(0.2), 150 | nn.Linear(self.last_channel, num_classes), 151 | ) 152 | 153 | # weight initialization 154 | for m in self.modules(): 155 | if isinstance(m, nn.Conv2d): 156 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 157 | if m.bias is not None: 158 | nn.init.zeros_(m.bias) 159 | elif isinstance(m, nn.BatchNorm2d): 160 | nn.init.ones_(m.weight) 161 | nn.init.zeros_(m.bias) 162 | elif isinstance(m, nn.Linear): 163 | nn.init.normal_(m.weight, 0, 0.01) 164 | nn.init.zeros_(m.bias) 165 | 166 | def forward(self, x): 167 | x = self.features(x) 168 | x = x.mean([2, 3]) 169 | x = self.classifier(x) 170 | return x 171 | 172 | 173 | def mobilenet_v2(pretrained=False, progress=True, **kwargs): 174 | """ 175 | Constructs a MobileNetV2 architecture from 176 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 177 | 178 | Args: 179 | pretrained (bool): If True, returns a model pre-trained on ImageNet 180 | progress (bool): If True, displays a progress bar of the download to stderr 181 | """ 182 | model = MobileNetV2(**kwargs) 183 | if pretrained: 184 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 185 | progress=progress) 186 | model.load_state_dict(state_dict) 187 | return model 188 | -------------------------------------------------------------------------------- /datafree/models/deeplab/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.models.utils import load_state_dict_from_url 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 8 | 'wide_resnet50_2', 'wide_resnet101_2'] 9 | 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 18 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 19 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 20 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 25 | """3x3 convolution with padding""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=dilation, groups=groups, bias=False, dilation=dilation) 
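
A brief check of why `conv3x3` uses `padding=dilation`: at stride 1 it keeps the spatial size unchanged for any dilation rate, which is what allows strides to be swapped for dilation later without changing the feature-map geometry. A standalone sketch:

```
import torch
import torch.nn as nn

x = torch.randn(1, 8, 33, 33)
for dilation in (1, 2, 4):
    # Same construction as conv3x3 above, with stride=1.
    conv = nn.Conv2d(8, 8, kernel_size=3, stride=1,
                     padding=dilation, dilation=dilation, bias=False)
    print(dilation, conv(x).shape)  # spatial size stays 33x33 for every rate
```
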
28 | 29 | 30 | def conv1x1(in_planes, out_planes, stride=1): 31 | """1x1 convolution""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 33 | 34 | 35 | class BasicBlock(nn.Module): 36 | expansion = 1 37 | 38 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 39 | base_width=64, dilation=1, norm_layer=None): 40 | super(BasicBlock, self).__init__() 41 | if norm_layer is None: 42 | norm_layer = nn.BatchNorm2d 43 | if groups != 1 or base_width != 64: 44 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 45 | if dilation > 1: 46 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 47 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 48 | self.conv1 = conv3x3(inplanes, planes, stride) 49 | self.bn1 = norm_layer(planes) 50 | self.relu = nn.ReLU(inplace=True) 51 | self.conv2 = conv3x3(planes, planes) 52 | self.bn2 = norm_layer(planes) 53 | self.downsample = downsample 54 | self.stride = stride 55 | 56 | def forward(self, x): 57 | identity = x 58 | 59 | out = self.conv1(x) 60 | out = self.bn1(out) 61 | out = self.relu(out) 62 | 63 | out = self.conv2(out) 64 | out = self.bn2(out) 65 | 66 | if self.downsample is not None: 67 | identity = self.downsample(x) 68 | 69 | out += identity 70 | out = self.relu(out) 71 | 72 | return out 73 | 74 | 75 | class Bottleneck(nn.Module): 76 | expansion = 4 77 | 78 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 79 | base_width=64, dilation=1, norm_layer=None): 80 | super(Bottleneck, self).__init__() 81 | if norm_layer is None: 82 | norm_layer = nn.BatchNorm2d 83 | width = int(planes * (base_width / 64.)) * groups 84 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 85 | self.conv1 = conv1x1(inplanes, width) 86 | self.bn1 = norm_layer(width) 87 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 88 | self.bn2 = norm_layer(width) 89 | self.conv3 = conv1x1(width, planes * self.expansion) 90 | self.bn3 = norm_layer(planes * self.expansion) 91 | self.relu = nn.ReLU(inplace=True) 92 | self.downsample = downsample 93 | self.stride = stride 94 | 95 | def forward(self, x): 96 | identity = x 97 | 98 | out = self.conv1(x) 99 | out = self.bn1(out) 100 | out = self.relu(out) 101 | 102 | out = self.conv2(out) 103 | out = self.bn2(out) 104 | out = self.relu(out) 105 | 106 | out = self.conv3(out) 107 | out = self.bn3(out) 108 | 109 | if self.downsample is not None: 110 | identity = self.downsample(x) 111 | 112 | out += identity 113 | out = self.relu(out) 114 | 115 | return out 116 | 117 | 118 | class ResNet(nn.Module): 119 | 120 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 121 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 122 | norm_layer=None): 123 | super(ResNet, self).__init__() 124 | if norm_layer is None: 125 | norm_layer = nn.BatchNorm2d 126 | self._norm_layer = norm_layer 127 | 128 | self.inplanes = 64 129 | self.dilation = 1 130 | if replace_stride_with_dilation is None: 131 | # each element in the tuple indicates if we should replace 132 | # the 2x2 stride with a dilated convolution instead 133 | replace_stride_with_dilation = [False, False, False] 134 | if len(replace_stride_with_dilation) != 3: 135 | raise ValueError("replace_stride_with_dilation should be None " 136 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 137 | self.groups = groups 138 | self.base_width = width_per_group 139 
| self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 140 | bias=False) 141 | self.bn1 = norm_layer(self.inplanes) 142 | self.relu = nn.ReLU(inplace=True) 143 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 144 | self.layer1 = self._make_layer(block, 64, layers[0]) 145 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 146 | dilate=replace_stride_with_dilation[0]) 147 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 148 | dilate=replace_stride_with_dilation[1]) 149 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 150 | dilate=replace_stride_with_dilation[2]) 151 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 152 | self.fc = nn.Linear(512 * block.expansion, num_classes) 153 | 154 | for m in self.modules(): 155 | if isinstance(m, nn.Conv2d): 156 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 157 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 158 | nn.init.constant_(m.weight, 1) 159 | nn.init.constant_(m.bias, 0) 160 | 161 | # Zero-initialize the last BN in each residual branch, 162 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 163 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 164 | if zero_init_residual: 165 | for m in self.modules(): 166 | if isinstance(m, Bottleneck): 167 | nn.init.constant_(m.bn3.weight, 0) 168 | elif isinstance(m, BasicBlock): 169 | nn.init.constant_(m.bn2.weight, 0) 170 | 171 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 172 | norm_layer = self._norm_layer 173 | downsample = None 174 | previous_dilation = self.dilation 175 | if dilate: 176 | self.dilation *= stride 177 | stride = 1 178 | if stride != 1 or self.inplanes != planes * block.expansion: 179 | downsample = nn.Sequential( 180 | conv1x1(self.inplanes, planes * block.expansion, stride), 181 | norm_layer(planes * block.expansion), 182 | ) 183 | 184 | layers = [] 185 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 186 | self.base_width, previous_dilation, norm_layer)) 187 | self.inplanes = planes * block.expansion 188 | for _ in range(1, blocks): 189 | layers.append(block(self.inplanes, planes, groups=self.groups, 190 | base_width=self.base_width, dilation=self.dilation, 191 | norm_layer=norm_layer)) 192 | 193 | return nn.Sequential(*layers) 194 | 195 | def forward(self, x): 196 | x = self.conv1(x) 197 | x = self.bn1(x) 198 | x = self.relu(x) 199 | x = self.maxpool(x) 200 | 201 | x = self.layer1(x) 202 | x = self.layer2(x) 203 | x = self.layer3(x) 204 | x = self.layer4(x) 205 | 206 | x = self.avgpool(x) 207 | x = torch.flatten(x, 1) 208 | x = self.fc(x) 209 | 210 | return x 211 | 212 | 213 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 214 | model = ResNet(block, layers, **kwargs) 215 | if pretrained: 216 | state_dict = load_state_dict_from_url(model_urls[arch], 217 | progress=progress) 218 | model.load_state_dict(state_dict) 219 | return model 220 | 221 | 222 | def resnet18(pretrained=False, progress=True, **kwargs): 223 | r"""ResNet-18 model from 224 | `"Deep Residual Learning for Image Recognition" `_ 225 | 226 | Args: 227 | pretrained (bool): If True, returns a model pre-trained on ImageNet 228 | progress (bool): If True, displays a progress bar of the download to stderr 229 | """ 230 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 231 | **kwargs) 232 | 233 | 234 | def 
resnet34(pretrained=False, progress=True, **kwargs): 235 | r"""ResNet-34 model from 236 | `"Deep Residual Learning for Image Recognition" `_ 237 | 238 | Args: 239 | pretrained (bool): If True, returns a model pre-trained on ImageNet 240 | progress (bool): If True, displays a progress bar of the download to stderr 241 | """ 242 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 243 | **kwargs) 244 | 245 | 246 | def resnet50(pretrained=False, progress=True, **kwargs): 247 | r"""ResNet-50 model from 248 | `"Deep Residual Learning for Image Recognition" `_ 249 | 250 | Args: 251 | pretrained (bool): If True, returns a model pre-trained on ImageNet 252 | progress (bool): If True, displays a progress bar of the download to stderr 253 | """ 254 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 255 | **kwargs) 256 | 257 | 258 | def resnet101(pretrained=False, progress=True, **kwargs): 259 | r"""ResNet-101 model from 260 | `"Deep Residual Learning for Image Recognition" `_ 261 | 262 | Args: 263 | pretrained (bool): If True, returns a model pre-trained on ImageNet 264 | progress (bool): If True, displays a progress bar of the download to stderr 265 | """ 266 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 267 | **kwargs) 268 | 269 | 270 | def resnet152(pretrained=False, progress=True, **kwargs): 271 | r"""ResNet-152 model from 272 | `"Deep Residual Learning for Image Recognition" `_ 273 | 274 | Args: 275 | pretrained (bool): If True, returns a model pre-trained on ImageNet 276 | progress (bool): If True, displays a progress bar of the download to stderr 277 | """ 278 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 279 | **kwargs) 280 | 281 | 282 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 283 | r"""ResNeXt-50 32x4d model from 284 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 285 | 286 | Args: 287 | pretrained (bool): If True, returns a model pre-trained on ImageNet 288 | progress (bool): If True, displays a progress bar of the download to stderr 289 | """ 290 | kwargs['groups'] = 32 291 | kwargs['width_per_group'] = 4 292 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 293 | pretrained, progress, **kwargs) 294 | 295 | 296 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 297 | r"""ResNeXt-101 32x8d model from 298 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 299 | 300 | Args: 301 | pretrained (bool): If True, returns a model pre-trained on ImageNet 302 | progress (bool): If True, displays a progress bar of the download to stderr 303 | """ 304 | kwargs['groups'] = 32 305 | kwargs['width_per_group'] = 8 306 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 307 | pretrained, progress, **kwargs) 308 | 309 | 310 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 311 | r"""Wide ResNet-50-2 model from 312 | `"Wide Residual Networks" `_ 313 | 314 | The model is the same as ResNet except for the bottleneck number of channels 315 | which is twice larger in every block. The number of channels in outer 1x1 316 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 317 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
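
These constructors forward their keyword arguments to `ResNet`; in particular, `replace_stride_with_dilation` is what `modeling.py` (further below) relies on to obtain stride-8 or stride-16 features for DeepLab. A hedged usage sketch, assuming the package path shown in this listing is importable and the pinned torchvision 0.12 (newer torchvision removed `torchvision.models.utils`):

```
import torch
from datafree.models.deeplab.backbone import resnet

# Keep the stride in layer2, dilate layer3/layer4 instead -> overall output stride 8.
backbone = resnet.resnet50(pretrained=False,
                           replace_stride_with_dilation=[False, True, True])
x = torch.randn(1, 3, 224, 224)
feat = x
for m in (backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool,
          backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4):
    feat = m(feat)
print(feat.shape)  # torch.Size([1, 2048, 28, 28]), i.e. output stride 224 / 28 = 8
```
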
318 | 319 | Args: 320 | pretrained (bool): If True, returns a model pre-trained on ImageNet 321 | progress (bool): If True, displays a progress bar of the download to stderr 322 | """ 323 | kwargs['width_per_group'] = 64 * 2 324 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 325 | pretrained, progress, **kwargs) 326 | 327 | 328 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 329 | r"""Wide ResNet-101-2 model from 330 | `"Wide Residual Networks" `_ 331 | 332 | The model is the same as ResNet except for the bottleneck number of channels 333 | which is twice larger in every block. The number of channels in outer 1x1 334 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 335 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 336 | 337 | Args: 338 | pretrained (bool): If True, returns a model pre-trained on ImageNet 339 | progress (bool): If True, displays a progress bar of the download to stderr 340 | """ 341 | kwargs['width_per_group'] = 64 * 2 342 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 343 | pretrained, progress, **kwargs) 344 | -------------------------------------------------------------------------------- /datafree/models/deeplab/modeling.py: -------------------------------------------------------------------------------- 1 | from .utils import IntermediateLayerGetter 2 | from ._deeplab import DeepLabHead, DeepLabHeadV3Plus, DeepLabV3 3 | from .backbone import resnet 4 | from .backbone import mobilenetv2 5 | 6 | def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_backbone): 7 | 8 | if output_stride==8: 9 | replace_stride_with_dilation=[False, True, True] 10 | aspp_dilate = [12, 24, 36] 11 | else: 12 | replace_stride_with_dilation=[False, False, True] 13 | aspp_dilate = [6, 12, 18] 14 | 15 | backbone = resnet.__dict__[backbone_name]( 16 | pretrained=pretrained_backbone, 17 | replace_stride_with_dilation=replace_stride_with_dilation) 18 | 19 | inplanes = 2048 20 | low_level_planes = 256 21 | 22 | if name=='deeplabv3plus': 23 | return_layers = {'layer4': 'out', 'layer1': 'low_level'} 24 | classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate) 25 | elif name=='deeplabv3': 26 | return_layers = {'layer4': 'out'} 27 | classifier = DeepLabHead(inplanes , num_classes, aspp_dilate) 28 | backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) 29 | 30 | model = DeepLabV3(backbone, classifier) 31 | return model 32 | 33 | def _segm_mobilenet(name, backbone_name, num_classes, output_stride, pretrained_backbone): 34 | if output_stride==8: 35 | aspp_dilate = [12, 24, 36] 36 | else: 37 | aspp_dilate = [6, 12, 18] 38 | 39 | backbone = mobilenetv2.mobilenet_v2(pretrained=pretrained_backbone, output_stride=output_stride) 40 | 41 | # rename layers 42 | backbone.low_level_features = backbone.features[0:4] 43 | backbone.high_level_features = backbone.features[4:-1] 44 | backbone.features = None 45 | backbone.classifier = None 46 | 47 | inplanes = 320 48 | low_level_planes = 24 49 | 50 | if name=='deeplabv3plus': 51 | return_layers = {'high_level_features': 'out', 'low_level_features': 'low_level'} 52 | classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate) 53 | elif name=='deeplabv3': 54 | return_layers = {'high_level_features': 'out'} 55 | classifier = DeepLabHead(inplanes , num_classes, aspp_dilate) 56 | backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) 57 | 58 | model = DeepLabV3(backbone, 
classifier) 59 | return model 60 | 61 | def _load_model(arch_type, backbone, num_classes, output_stride, pretrained_backbone): 62 | 63 | if backbone=='mobilenetv2': 64 | model = _segm_mobilenet(arch_type, backbone, num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone) 65 | elif backbone.startswith('resnet'): 66 | model = _segm_resnet(arch_type, backbone, num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone) 67 | else: 68 | raise NotImplementedError 69 | return model 70 | 71 | 72 | # Deeplab v3 73 | 74 | def deeplabv3_resnet50(num_classes=21, output_stride=8, pretrained_backbone=True): 75 | """Constructs a DeepLabV3 model with a ResNet-50 backbone. 76 | 77 | Args: 78 | num_classes (int): number of classes. 79 | output_stride (int): output stride for deeplab. 80 | pretrained_backbone (bool): If True, use the pretrained backbone. 81 | """ 82 | return _load_model('deeplabv3', 'resnet50', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone) 83 | 84 | def deeplabv3_resnet101(num_classes=21, output_stride=8, pretrained_backbone=True): 85 | """Constructs a DeepLabV3 model with a ResNet-101 backbone. 86 | 87 | Args: 88 | num_classes (int): number of classes. 89 | output_stride (int): output stride for deeplab. 90 | pretrained_backbone (bool): If True, use the pretrained backbone. 91 | """ 92 | return _load_model('deeplabv3', 'resnet101', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone) 93 | 94 | def deeplabv3_mobilenet(num_classes=21, output_stride=8, pretrained_backbone=True, **kwargs): 95 | """Constructs a DeepLabV3 model with a MobileNetv2 backbone. 96 | 97 | Args: 98 | num_classes (int): number of classes. 99 | output_stride (int): output stride for deeplab. 100 | pretrained_backbone (bool): If True, use the pretrained backbone. 101 | """ 102 | return _load_model('deeplabv3', 'mobilenetv2', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone) 103 | 104 | 105 | # Deeplab v3+ 106 | 107 | def deeplabv3plus_resnet50(num_classes=21, output_stride=8, pretrained_backbone=True): 108 | """Constructs a DeepLabV3 model with a ResNet-50 backbone. 109 | 110 | Args: 111 | num_classes (int): number of classes. 112 | output_stride (int): output stride for deeplab. 113 | pretrained_backbone (bool): If True, use the pretrained backbone. 114 | """ 115 | return _load_model('deeplabv3plus', 'resnet50', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone) 116 | 117 | 118 | def deeplabv3plus_resnet101(num_classes=21, output_stride=8, pretrained_backbone=True): 119 | """Constructs a DeepLabV3+ model with a ResNet-101 backbone. 120 | 121 | Args: 122 | num_classes (int): number of classes. 123 | output_stride (int): output stride for deeplab. 124 | pretrained_backbone (bool): If True, use the pretrained backbone. 125 | """ 126 | return _load_model('deeplabv3plus', 'resnet101', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone) 127 | 128 | 129 | def deeplabv3plus_mobilenet(num_classes=21, output_stride=8, pretrained_backbone=True): 130 | """Constructs a DeepLabV3+ model with a MobileNetv2 backbone. 131 | 132 | Args: 133 | num_classes (int): number of classes. 134 | output_stride (int): output stride for deeplab. 135 | pretrained_backbone (bool): If True, use the pretrained backbone. 
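
A usage sketch for these factory functions (assuming the `datafree.models.deeplab` package from this listing is importable and the pinned torchvision 0.12; with `pretrained_backbone=False` nothing is downloaded):

```
import torch
from datafree.models import deeplab

model = deeplab.deeplabv3plus_mobilenet(num_classes=21, output_stride=16,
                                        pretrained_backbone=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 21, 256, 256]) - per-pixel logits at input resolution
```
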
136 | """ 137 | return _load_model('deeplabv3plus', 'mobilenetv2', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone) -------------------------------------------------------------------------------- /datafree/models/deeplab/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from collections import OrderedDict 6 | 7 | class _SimpleSegmentationModel(nn.Module): 8 | def __init__(self, backbone, classifier): 9 | super(_SimpleSegmentationModel, self).__init__() 10 | self.backbone = backbone 11 | self.classifier = classifier 12 | 13 | def forward(self, x): 14 | input_shape = x.shape[-2:] 15 | features = self.backbone(x) 16 | x = self.classifier(features) 17 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) 18 | return x 19 | 20 | 21 | class IntermediateLayerGetter(nn.ModuleDict): 22 | """ 23 | Module wrapper that returns intermediate layers from a model 24 | 25 | It has a strong assumption that the modules have been registered 26 | into the model in the same order as they are used. 27 | This means that one should **not** reuse the same nn.Module 28 | twice in the forward if you want this to work. 29 | 30 | Additionally, it is only able to query submodules that are directly 31 | assigned to the model. So if `model` is passed, `model.feature1` can 32 | be returned, but not `model.feature1.layer2`. 33 | 34 | Arguments: 35 | model (nn.Module): model on which we will extract the features 36 | return_layers (Dict[name, new_name]): a dict containing the names 37 | of the modules for which the activations will be returned as 38 | the key of the dict, and the value of the dict is the name 39 | of the returned activation (which the user can specify). 
40 | 41 | Examples:: 42 | 43 | >>> m = torchvision.models.resnet18(pretrained=True) 44 | >>> # extract layer1 and layer3, giving as names `feat1` and feat2` 45 | >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m, 46 | >>> {'layer1': 'feat1', 'layer3': 'feat2'}) 47 | >>> out = new_m(torch.rand(1, 3, 224, 224)) 48 | >>> print([(k, v.shape) for k, v in out.items()]) 49 | >>> [('feat1', torch.Size([1, 64, 56, 56])), 50 | >>> ('feat2', torch.Size([1, 256, 14, 14]))] 51 | """ 52 | def __init__(self, model, return_layers): 53 | if not set(return_layers).issubset([name for name, _ in model.named_children()]): 54 | raise ValueError("return_layers are not present in model") 55 | 56 | orig_return_layers = return_layers 57 | return_layers = {k: v for k, v in return_layers.items()} 58 | layers = OrderedDict() 59 | for name, module in model.named_children(): 60 | layers[name] = module 61 | if name in return_layers: 62 | del return_layers[name] 63 | if not return_layers: 64 | break 65 | 66 | super(IntermediateLayerGetter, self).__init__(layers) 67 | self.return_layers = orig_return_layers 68 | 69 | def forward(self, x): 70 | out = OrderedDict() 71 | for name, module in self.named_children(): 72 | x = module(x) 73 | if name in self.return_layers: 74 | out_name = self.return_layers[name] 75 | out[out_name] = x 76 | return out 77 | -------------------------------------------------------------------------------- /datafree/models/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Flatten(nn.Module): 6 | def __init__(self): 7 | super(Flatten, self).__init__() 8 | 9 | def forward(self, x): 10 | return torch.flatten(x, 1) 11 | 12 | class Generator(nn.Module): 13 | def __init__(self, nz=100, ngf=64, img_size=32, nc=3): 14 | super(Generator, self).__init__() 15 | 16 | self.init_size = img_size // 4 17 | self.l1 = nn.Sequential(nn.Linear(nz, ngf * 2 * self.init_size ** 2)) 18 | 19 | self.conv_blocks = nn.Sequential( 20 | nn.BatchNorm2d(ngf * 2), 21 | nn.Upsample(scale_factor=2), 22 | 23 | nn.Conv2d(ngf*2, ngf*2, 3, stride=1, padding=1, bias=False), 24 | nn.BatchNorm2d(ngf*2), 25 | nn.LeakyReLU(0.2, inplace=True), 26 | nn.Upsample(scale_factor=2), 27 | 28 | nn.Conv2d(ngf*2, ngf, 3, stride=1, padding=1, bias=False), 29 | nn.BatchNorm2d(ngf), 30 | nn.LeakyReLU(0.2, inplace=True), 31 | nn.Conv2d(ngf, nc, 3, stride=1, padding=1), 32 | nn.Sigmoid(), 33 | ) 34 | 35 | def forward(self, z): 36 | out = self.l1(z) 37 | out = out.view(out.shape[0], -1, self.init_size, self.init_size) 38 | img = self.conv_blocks(out) 39 | return img 40 | 41 | 42 | class LargeGenerator(nn.Module): 43 | def __init__(self, nz=100, ngf=64, img_size=32, nc=3): 44 | super(LargeGenerator, self).__init__() 45 | 46 | self.init_size = img_size // 4 47 | self.l1 = nn.Sequential(nn.Linear(nz, ngf * 4 * self.init_size ** 2)) 48 | 49 | self.conv_blocks = nn.Sequential( 50 | nn.BatchNorm2d(ngf * 4), 51 | nn.Upsample(scale_factor=2), 52 | 53 | nn.Conv2d(ngf*4, ngf*2, 3, stride=1, padding=1, bias=False), 54 | nn.BatchNorm2d(ngf*2), 55 | nn.LeakyReLU(0.2, inplace=True), 56 | nn.Upsample(scale_factor=2), 57 | 58 | nn.Conv2d(ngf*2, ngf, 3, stride=1, padding=1, bias=False), 59 | nn.BatchNorm2d(ngf), 60 | nn.LeakyReLU(0.2, inplace=True), 61 | nn.Conv2d(ngf, nc, 3, stride=1, padding=1), 62 | nn.Sigmoid(), 63 | ) 64 | 65 | def forward(self, z): 66 | out = self.l1(z) 67 | out = out.view(out.shape[0], -1, self.init_size, 
self.init_size) 68 | img = self.conv_blocks(out) 69 | return img 70 | 71 | 72 | class DCGAN_Generator(nn.Module): 73 | """ Generator from DCGAN: https://arxiv.org/abs/1511.06434 74 | """ 75 | def __init__(self, nz=100, ngf=64, nc=3, img_size=64, slope=0.2): 76 | super(DCGAN_Generator, self).__init__() 77 | self.nz = nz 78 | if isinstance(img_size, (list, tuple)): 79 | self.init_size = ( img_size[0]//16, img_size[1]//16 ) 80 | else: 81 | self.init_size = ( img_size // 16, img_size // 16) 82 | 83 | self.project = nn.Sequential( 84 | Flatten(), 85 | nn.Linear(nz, ngf*8*self.init_size[0]*self.init_size[1]), 86 | ) 87 | 88 | self.main = nn.Sequential( 89 | nn.BatchNorm2d(ngf*8), 90 | 91 | nn.ConvTranspose2d(ngf*8, ngf*4, 4, 2, 1, bias=False), 92 | nn.BatchNorm2d(ngf*4), 93 | nn.LeakyReLU(slope, inplace=True), 94 | # 2x 95 | 96 | nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False), 97 | nn.BatchNorm2d(ngf*2), 98 | nn.LeakyReLU(slope, inplace=True), 99 | # 4x 100 | 101 | nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False), 102 | nn.BatchNorm2d(ngf), 103 | nn.LeakyReLU(slope, inplace=True), 104 | # 8x 105 | 106 | nn.ConvTranspose2d(ngf, ngf, 4, 2, 1, bias=False), 107 | nn.BatchNorm2d(ngf), 108 | nn.LeakyReLU(slope, inplace=True), 109 | # 16x 110 | 111 | nn.Conv2d(ngf, nc, 3, 1,1), 112 | nn.Sigmoid(), 113 | #nn.Sigmoid() 114 | ) 115 | 116 | def forward(self, z): 117 | proj = self.project(z) 118 | proj = proj.view(proj.shape[0], -1, self.init_size[0], self.init_size[1]) 119 | output = self.main(proj) 120 | return output 121 | 122 | class DCGAN_CondGenerator(nn.Module): 123 | """ Generator from DCGAN: https://arxiv.org/abs/1511.06434 124 | """ 125 | def __init__(self, num_classes, nz=100, n_emb=50, ngf=64, nc=3, img_size=64, slope=0.2): 126 | super(DCGAN_CondGenerator, self).__init__() 127 | self.nz = nz 128 | self.emb = nn.Embedding(num_classes, n_emb) 129 | if isinstance(img_size, (list, tuple)): 130 | self.init_size = ( img_size[0]//16, img_size[1]//16 ) 131 | else: 132 | self.init_size = ( img_size // 16, img_size // 16) 133 | 134 | self.project = nn.Sequential( 135 | Flatten(), 136 | nn.Linear(nz+n_emb, ngf*8*self.init_size[0]*self.init_size[1]), 137 | ) 138 | 139 | self.main = nn.Sequential( 140 | nn.BatchNorm2d(ngf*8), 141 | 142 | nn.ConvTranspose2d(ngf*8, ngf*4, 4, 2, 1, bias=False), 143 | nn.BatchNorm2d(ngf*4), 144 | nn.LeakyReLU(slope, inplace=True), 145 | # 2x 146 | 147 | nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False), 148 | nn.BatchNorm2d(ngf*2), 149 | nn.LeakyReLU(slope, inplace=True), 150 | # 4x 151 | 152 | nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False), 153 | nn.BatchNorm2d(ngf), 154 | nn.LeakyReLU(slope, inplace=True), 155 | # 8x 156 | 157 | nn.ConvTranspose2d(ngf, ngf, 4, 2, 1, bias=False), 158 | nn.BatchNorm2d(ngf), 159 | nn.LeakyReLU(slope, inplace=True), 160 | # 16x 161 | 162 | nn.Conv2d(ngf, nc, 3, 1,1), 163 | #nn.Tanh(), 164 | nn.Sigmoid() 165 | ) 166 | 167 | def forward(self, z, y): 168 | y = self.emb(y) 169 | z = torch.cat([z, y], dim=1) 170 | proj = self.project(z) 171 | proj = proj.view(proj.shape[0], -1, self.init_size[0], self.init_size[1]) 172 | output = self.main(proj) 173 | return output 174 | 175 | class Discriminator(nn.Module): 176 | def __init__(self, nc=3, img_size=32): 177 | super(Discriminator, self).__init__() 178 | 179 | def discriminator_block(in_filters, out_filters, bn=True): 180 | block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)] 181 | if bn: 182 | 
block.append(nn.BatchNorm2d(out_filters, 0.8)) 183 | return block 184 | 185 | self.model = nn.Sequential( 186 | *discriminator_block(nc, 16, bn=False), 187 | *discriminator_block(16, 32), 188 | *discriminator_block(32, 64), 189 | *discriminator_block(64, 128), 190 | ) 191 | 192 | # The height and width of downsampled image 193 | ds_size = img_size // 2 ** 4 194 | self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid()) 195 | 196 | def forward(self, img): 197 | out = self.model(img) 198 | out = out.view(out.shape[0], -1) 199 | validity = self.adv_layer(out) 200 | return validity 201 | 202 | class DCGAN_Discriminator(nn.Module): 203 | def __init__(self, nc=3, ndf=64): 204 | super(DCGAN_Discriminator, self).__init__() 205 | self.main = nn.Sequential( 206 | # input is (nc) x 64 x 64 207 | nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), 208 | nn.LeakyReLU(0.2, inplace=True), 209 | # state size. (ndf) x 32 x 32 210 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), 211 | nn.BatchNorm2d(ndf * 2), 212 | nn.LeakyReLU(0.2, inplace=True), 213 | # state size. (ndf*2) x 16 x 16 214 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), 215 | nn.BatchNorm2d(ndf * 4), 216 | nn.LeakyReLU(0.2, inplace=True), 217 | # state size. (ndf*4) x 8 x 8 218 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), 219 | nn.BatchNorm2d(ndf * 8), 220 | nn.LeakyReLU(0.2, inplace=True), 221 | # state size. (ndf*8) x 4 x 4 222 | nn.Conv2d(ndf * 8, 1, 2, 1, 0, bias=False), 223 | nn.Sigmoid() 224 | ) 225 | 226 | def forward(self, input): 227 | return self.main(input) -------------------------------------------------------------------------------- /datafree/rep_transfer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class HintLoss(nn.Module): 7 | """Convolutional regression for FitNet""" 8 | def __init__(self, s_shapes, t_shapes, use_relu=False, loss_fn=F.mse_loss): 9 | super(HintLoss, self).__init__() 10 | self.use_relu = use_relu 11 | self.loss_fn = loss_fn 12 | regs = [] 13 | for s_shape, t_shape in zip(s_shapes, t_shapes): 14 | s_N, s_C, s_H, s_W = s_shape 15 | t_N, t_C, t_H, t_W = t_shape 16 | if s_H == 2 * t_H: 17 | conv = nn.Conv2d(s_C, t_C, kernel_size=3, stride=2, padding=1) 18 | elif s_H * 2 == t_H: 19 | conv = nn.ConvTranspose2d(s_C, t_C, kernel_size=4, stride=2, padding=1) 20 | elif s_H >= t_H: 21 | conv = nn.Conv2d(s_C, t_C, kernel_size=(1+s_H-t_H, 1+s_W-t_W)) 22 | else: 23 | raise NotImplemented('student size {}, teacher size {}'.format(s_H, t_H)) 24 | reg = [conv, nn.BatchNorm2d(t_C)] 25 | if use_relu: 26 | reg.append( nn.ReLU(inplace=True) ) 27 | regs.append(nn.Sequential(*reg)) 28 | self.regs = nn.ModuleList(regs) 29 | 30 | def forward(self, s_features, t_features): 31 | loss = [] 32 | for reg, s_feat, t_feat in zip(self.regs, s_features, t_features): 33 | s_feat = reg(s_feat) 34 | loss.append( self.loss_fn( s_feat, t_feat ) ) 35 | return loss 36 | 37 | 38 | class ABLoss(nn.Module): 39 | """Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons 40 | code: https://github.com/bhheo/AB_distillation 41 | """ 42 | def __init__(self, s_shapes, t_shapes, margin=1.0, use_relu=False): 43 | super(ABLoss, self).__init__() 44 | 45 | regs = [] 46 | for s_shape, t_shape in zip(s_shapes, t_shapes): 47 | s_N, s_C, s_H, s_W = s_shape 48 | t_N, t_C, t_H, t_W = t_shape 49 | if s_H == 2 * t_H: 50 | conv = nn.Conv2d(s_C, t_C, kernel_size=3, 
stride=2, padding=1) 51 | elif s_H * 2 == t_H: 52 | conv = nn.ConvTranspose2d(s_C, t_C, kernel_size=4, stride=2, padding=1) 53 | elif s_H >= t_H: 54 | conv = nn.Conv2d(s_C, t_C, kernel_size=(1+s_H-t_H, 1+s_W-t_W)) 55 | else: 56 | raise NotImplemented('student size {}, teacher size {}'.format(s_H, t_H)) 57 | reg = [conv, nn.BatchNorm2d(t_C)] 58 | if use_relu: 59 | reg.append( nn.ReLU(inplace=True) ) 60 | regs.append(nn.Sequential(*reg)) 61 | self.regs = nn.ModuleList(regs) 62 | feat_num = len(self.regs) 63 | self.w = [2**(i-feat_num+1) for i in range(feat_num)] 64 | self.margin = margin 65 | 66 | def forward(self, s_features, t_features, reverse=False): 67 | s_features = [ reg(s_feat) for (reg, s_feat) in zip(self.regs, s_features) ] 68 | bsz = s_features[0].shape[0] 69 | losses = [self.criterion_alternative_l2(s, t, reverse=reverse) for s, t in zip(s_features, t_features)] 70 | losses = [w * l for w, l in zip(self.w, losses)] 71 | losses = [l / bsz for l in losses] 72 | losses = [l / 1000 * 3 for l in losses] 73 | return losses 74 | 75 | def criterion_alternative_l2(self, source, target, reverse): 76 | if reverse: 77 | loss = ((source - self.margin) ** 2 * ((source < self.margin) & (target <= 0)).float() + 78 | (source + self.margin) ** 2 * ((source > -self.margin) & (target > 0)).float() + 79 | (target - self.margin) ** 2 * ((target < self.margin) & (source <= 0)).float() + 80 | (target + self.margin) ** 2 * ((target > -self.margin) & (source > 0)).float()) 81 | else: 82 | loss = ((source + self.margin) ** 2 * ((source > -self.margin) & (target <= 0)).float() + 83 | (source - self.margin) ** 2 * ((source <= self.margin) & (target > 0)).float()) 84 | return torch.abs(loss).sum() 85 | 86 | 87 | class RKDLoss(nn.Module): 88 | """Relational Knowledge Disitllation, CVPR2019""" 89 | def __init__(self, w_d=25, w_a=50, angle=True): 90 | super(RKDLoss, self).__init__() 91 | self.w_d = w_d 92 | self.w_a = w_a 93 | self.angle = angle 94 | 95 | def forward(self, s_features, t_features): 96 | losses = [] 97 | for f_s, f_t in zip(s_features, t_features): 98 | student = f_s.view(f_s.shape[0], -1) 99 | teacher = f_t.view(f_t.shape[0], -1) 100 | 101 | # RKD distance loss 102 | with torch.no_grad(): 103 | t_d = self.pdist(teacher, squared=False) 104 | mean_td = t_d[t_d > 0].mean() 105 | t_d = t_d / mean_td 106 | 107 | d = self.pdist(student, squared=False) 108 | mean_d = d[d > 0].mean() 109 | d = d / mean_d 110 | 111 | loss_d = F.smooth_l1_loss(d, t_d) 112 | 113 | if self.angle: 114 | # RKD Angle loss 115 | with torch.no_grad(): 116 | td = (teacher.unsqueeze(0) - teacher.unsqueeze(1)) 117 | norm_td = F.normalize(td, p=2, dim=2) 118 | t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1) 119 | 120 | sd = (student.unsqueeze(0) - student.unsqueeze(1)) 121 | norm_sd = F.normalize(sd, p=2, dim=2) 122 | s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1) 123 | 124 | loss_a = F.smooth_l1_loss(s_angle, t_angle) 125 | else: 126 | loss_a = 0 127 | loss = self.w_d * loss_d + self.w_a * loss_a 128 | losses.append(loss) 129 | return losses 130 | 131 | @staticmethod 132 | def pdist(e, squared=False, eps=1e-12): 133 | e_square = e.pow(2).sum(dim=1) 134 | prod = e @ e.t() 135 | res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps) 136 | 137 | if not squared: 138 | res = res.sqrt() 139 | 140 | res = res.clone() 141 | res[range(len(e)), range(len(e))] = 0 142 | return res 143 | 144 | 145 | class FSP(nn.Module): 146 | """A Gift from Knowledge Distillation: 147 | Fast 
Optimization, Network Minimization and Transfer Learning""" 148 | def __init__(self, s_shapes, t_shapes): 149 | super(FSP, self).__init__() 150 | assert len(s_shapes) == len(t_shapes), 'unequal length of feat list' 151 | s_c = [s[1] for s in s_shapes] 152 | t_c = [t[1] for t in t_shapes] 153 | if np.any(np.asarray(s_c) != np.asarray(t_c)): 154 | raise ValueError('num of channels not equal (error in FSP)') 155 | 156 | def forward(self, g_s, g_t): 157 | s_fsp = self.compute_fsp(g_s) 158 | t_fsp = self.compute_fsp(g_t) 159 | loss_group = [self.compute_loss(s, t) for s, t in zip(s_fsp, t_fsp)] 160 | return loss_group 161 | 162 | @staticmethod 163 | def compute_loss(s, t): 164 | return (s - t).pow(2).mean() 165 | 166 | @staticmethod 167 | def compute_fsp(g): 168 | fsp_list = [] 169 | for i in range(len(g) - 1): 170 | bot, top = g[i], g[i + 1] 171 | b_H, t_H = bot.shape[2], top.shape[2] 172 | if b_H > t_H: 173 | bot = F.adaptive_avg_pool2d(bot, (t_H, t_H)) 174 | elif b_H < t_H: 175 | top = F.adaptive_avg_pool2d(top, (b_H, b_H)) 176 | else: 177 | pass 178 | bot = bot.unsqueeze(1) 179 | top = top.unsqueeze(2) 180 | bot = bot.view(bot.shape[0], bot.shape[1], bot.shape[2], -1) 181 | top = top.view(top.shape[0], top.shape[1], top.shape[2], -1) 182 | 183 | fsp = (bot * top).mean(-1) 184 | fsp_list.append(fsp) 185 | return fsp_list -------------------------------------------------------------------------------- /datafree/synthesis/__init__.py: -------------------------------------------------------------------------------- 1 | from .triplet import AdvTripletSynthesizer 2 | from .contrastive import CMISynthesizer 3 | from .base import BaseSynthesis -------------------------------------------------------------------------------- /datafree/synthesis/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from abc import ABC, abstractclassmethod 4 | from typing import Dict 5 | 6 | class BaseSynthesis(ABC): 7 | def __init__(self, teacher, student): 8 | super(BaseSynthesis, self).__init__() 9 | self.teacher = teacher 10 | self.student = student 11 | 12 | @abstractclassmethod 13 | def synthesize(self) -> Dict[str, torch.Tensor]: 14 | """ take several steps to synthesize new images and return an image dict for visualization. 15 | Returned images should be normalized to [0, 1]. 16 | """ 17 | pass 18 | 19 | @abstractclassmethod 20 | def sample(self, n): 21 | """ fetch a batch of training data. 
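        A sketch of the intended call pattern from a training loop; the concrete class
        named below exists in this repo, but the loop itself is illustrative:

        >>> # synthesizer = AdvTripletSynthesizer(teacher, student, generator, ...)
        >>> # vis_dict = synthesizer.synthesize()   # generate and cache a new image batch
        >>> # batch = synthesizer.sample()          # fetch a batch for distillation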
22 | """ 23 | pass -------------------------------------------------------------------------------- /datafree/synthesis/contrastive.py: -------------------------------------------------------------------------------- 1 | import datafree 2 | from typing import Generator 3 | import torch 4 | from torch import optim 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import random 8 | 9 | from .base import BaseSynthesis 10 | from datafree.hooks import DeepInversionHook, InstanceMeanHook 11 | from datafree.criterions import jsdiv, get_image_prior_losses, kldiv 12 | from datafree.utils import ImagePool, DataIter, clip_images 13 | import collections 14 | from torchvision import transforms 15 | from kornia import augmentation 16 | from tqdm import tqdm 17 | 18 | class MLPHead(nn.Module): 19 | def __init__(self, dim_in, dim_feat, dim_h=None): 20 | super(MLPHead, self).__init__() 21 | if dim_h is None: 22 | dim_h = dim_in 23 | 24 | self.head = nn.Sequential( 25 | nn.Linear(dim_in, dim_h), 26 | nn.ReLU(inplace=True), 27 | nn.Linear(dim_h, dim_feat), 28 | ) 29 | 30 | def forward(self, x): 31 | x = self.head(x) 32 | return F.normalize(x, dim=1, p=2) 33 | 34 | class MultiTransform: 35 | """Create two crops of the same image""" 36 | def __init__(self, transform): 37 | self.transform = transform 38 | 39 | def __call__(self, x): 40 | return [t(x) for t in self.transform] 41 | 42 | def __repr__(self): 43 | return str( self.transform ) 44 | 45 | 46 | class ContrastLoss(nn.Module): 47 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 48 | It also supports the unsupervised contrastive loss in SimCLR. 49 | Adapted from https://github.com/HobbitLong/SupContrast/blob/master/losses.py""" 50 | def __init__(self, temperature=0.07, contrast_mode='all', 51 | base_temperature=0.07): 52 | super(ContrastLoss, self).__init__() 53 | self.temperature = temperature 54 | self.contrast_mode = contrast_mode 55 | self.base_temperature = base_temperature 56 | 57 | def forward(self, features, labels=None, mask=None, return_logits=False): 58 | """Compute loss for model. If both `labels` and `mask` are None, 59 | it degenerates to SimCLR unsupervised loss: 60 | https://arxiv.org/pdf/2002.05709.pdf 61 | Args: 62 | features: hidden vector of shape [bsz, n_views, ...]. 63 | labels: ground truth of shape [bsz]. 64 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 65 | has the same class as sample i. Can be asymmetric. 66 | Returns: 67 | A loss scalar. 
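        Example (a minimal sketch; the shapes and labels below are assumptions, not
        values used elsewhere in this repo):
            >>> crit = ContrastLoss(temperature=0.07)
            >>> feats = F.normalize(torch.randn(8, 2, 128), dim=-1)  # [bsz, n_views, dim]
            >>> labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])      # two samples per class
            >>> loss = crit(feats, labels).mean()                    # scalar loss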
68 | """ 69 | device = (torch.device('cuda') 70 | if features.is_cuda 71 | else torch.device('cpu')) 72 | 73 | if len(features.shape) < 3: 74 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 75 | 'at least 3 dimensions are required') 76 | if len(features.shape) > 3: 77 | features = features.view(features.shape[0], features.shape[1], -1) 78 | 79 | batch_size = features.shape[0] 80 | if labels is not None and mask is not None: 81 | raise ValueError('Cannot define both `labels` and `mask`') 82 | elif labels is None and mask is None: 83 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 84 | elif labels is not None: 85 | labels = labels.contiguous().view(-1, 1) 86 | if labels.shape[0] != batch_size: 87 | raise ValueError('Num of labels does not match num of features') 88 | mask = torch.eq(labels, labels.T).float().to(device) 89 | else: 90 | mask = mask.float().to(device) 91 | 92 | contrast_count = features.shape[1] 93 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 94 | if self.contrast_mode == 'one': 95 | anchor_feature = features[:, 0] 96 | anchor_count = 1 97 | elif self.contrast_mode == 'all': 98 | anchor_feature = contrast_feature 99 | anchor_count = contrast_count 100 | else: 101 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 102 | 103 | # compute logits 104 | anchor_dot_contrast = torch.div( 105 | torch.matmul(anchor_feature, contrast_feature.T), 106 | self.temperature) 107 | # for numerical stability 108 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 109 | logits = anchor_dot_contrast - logits_max.detach() 110 | 111 | # tile mask 112 | mask = mask.repeat(anchor_count, contrast_count) 113 | # mask-out self-contrast cases 114 | logits_mask = torch.scatter( 115 | torch.ones_like(mask), 116 | 1, 117 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 118 | 0 119 | ) 120 | mask = mask * logits_mask 121 | 122 | # compute log_prob 123 | exp_logits = torch.exp(logits) * logits_mask 124 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 125 | 126 | # compute mean of log-likelihood over positive 127 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 128 | # loss 129 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 130 | loss = loss.view(anchor_count, batch_size) 131 | 132 | if return_logits: 133 | return loss, anchor_dot_contrast 134 | return loss 135 | 136 | 137 | class MemoryBank(object): 138 | def __init__(self, device, max_size=4096, dim_feat=512): 139 | self.device = device 140 | self.data = torch.randn( max_size, dim_feat ).to(device) 141 | self._ptr = 0 142 | self.n_updates = 0 143 | 144 | self.max_size = max_size 145 | self.dim_feat = dim_feat 146 | 147 | def add(self, feat): 148 | feat = feat.to(self.device) 149 | n, c = feat.shape 150 | assert self.dim_feat==c and self.max_size % n==0, "%d, %d"%(self.dim_feat, c, self.max_size, n) 151 | self.data[self._ptr:self._ptr+n] = feat.detach() 152 | self._ptr = (self._ptr+n) % (self.max_size) 153 | self.n_updates+=n 154 | 155 | def get_data(self, k=None, index=None): 156 | if k is None: 157 | k = self.max_size 158 | 159 | if self.n_updates>self.max_size: 160 | if index is None: 161 | index = random.sample(list(range(self.max_size)), k=k) 162 | return self.data[index], index 163 | else: 164 | #return self.data[:self._ptr] 165 | if index is None: 166 | index = random.sample(list(range(self._ptr)), k=min(k, self._ptr)) 167 | return self.data[index], index 168 | 169 | def reset_model(model): 170 | 
for m in model.modules(): 171 | if isinstance(m, (nn.ConvTranspose2d, nn.Linear, nn.Conv2d)): 172 | nn.init.normal_(m.weight, 0.0, 0.02) 173 | if m.bias is not None: 174 | nn.init.constant_(m.bias, 0) 175 | if isinstance(m, (nn.BatchNorm2d)): 176 | nn.init.normal_(m.weight, 1.0, 0.02) 177 | nn.init.constant_(m.bias, 0) 178 | 179 | class CMISynthesizer(BaseSynthesis): 180 | def __init__(self, teacher, student, generator, nz, num_classes, img_size, 181 | feature_layers=None, bank_size=40960, n_neg=4096, head_dim=128, init_dataset=None, 182 | iterations=100, lr_g=0.1, progressive_scale=False, 183 | synthesis_batch_size=128, sample_batch_size=128, 184 | adv=0.0, bn=1, oh=1, cr=0.8, cr_T=0.1, 185 | save_dir='run/cmi', transform=None, 186 | autocast=None, use_fp16=False, 187 | normalizer=None, device='cpu', distributed=False): 188 | super(CMISynthesizer, self).__init__(teacher, student) 189 | self.save_dir = save_dir 190 | self.img_size = img_size 191 | self.iterations = iterations 192 | self.lr_g = lr_g 193 | self.progressive_scale = progressive_scale 194 | self.nz = nz 195 | self.n_neg = n_neg 196 | self.adv = adv 197 | self.bn = bn 198 | self.oh = oh 199 | self.num_classes = num_classes 200 | self.distributed = distributed 201 | self.synthesis_batch_size = synthesis_batch_size 202 | self.sample_batch_size = sample_batch_size 203 | self.bank_size = bank_size 204 | self.init_dataset = init_dataset 205 | 206 | self.use_fp16 = use_fp16 207 | self.autocast = autocast # for FP16 208 | self.normalizer = normalizer 209 | self.data_pool = ImagePool(root=self.save_dir) 210 | self.transform = transform 211 | self.data_iter = None 212 | 213 | self.cr = cr 214 | self.cr_T = cr_T 215 | self.cmi_hooks = [] 216 | if feature_layers is not None: 217 | for layer in feature_layers: 218 | self.cmi_hooks.append( InstanceMeanHook(layer) ) 219 | else: 220 | for m in teacher.modules(): 221 | if isinstance(m, nn.BatchNorm2d): 222 | self.cmi_hooks.append( InstanceMeanHook(m) ) 223 | 224 | with torch.no_grad(): 225 | teacher.eval() 226 | fake_inputs = torch.randn(size=(1, *img_size), device=device) 227 | _ = teacher(fake_inputs) 228 | cmi_feature = torch.cat([ h.instance_mean for h in self.cmi_hooks ], dim=1) 229 | print("CMI dims: %d"%(cmi_feature.shape[1])) 230 | del fake_inputs 231 | 232 | self.generator = generator.to(device).train() 233 | # local and global bank 234 | self.mem_bank = MemoryBank('cpu', max_size=self.bank_size, dim_feat=2*cmi_feature.shape[1]) # local + global 235 | 236 | self.head = MLPHead(cmi_feature.shape[1], head_dim).to(device).train() 237 | self.optimizer_head = torch.optim.Adam(self.head.parameters(), lr=self.lr_g) 238 | 239 | self.device = device 240 | self.hooks = [] 241 | for m in teacher.modules(): 242 | if isinstance(m, nn.BatchNorm2d): 243 | self.hooks.append( DeepInversionHook(m) ) 244 | 245 | self.aug = MultiTransform([ 246 | # global view 247 | transforms.Compose([ 248 | augmentation.RandomCrop(size=[self.img_size[-2], self.img_size[-1]], padding=4), 249 | augmentation.RandomHorizontalFlip(), 250 | normalizer, 251 | ]), 252 | # local view 253 | transforms.Compose([ 254 | augmentation.RandomResizedCrop(size=[self.img_size[-2], self.img_size[-1]], scale=[0.25, 1.0]), 255 | augmentation.RandomHorizontalFlip(), 256 | normalizer, 257 | ]), 258 | ]) 259 | 260 | #self.contrast_loss = ContrastLoss(temperature=self.cr_T, contrast_mode='one') 261 | 262 | def synthesize(self, targets=None): 263 | self.student.eval() 264 | self.teacher.eval() 265 | best_cost = 1e6 266 | 267 | #inputs = 
torch.randn( size=(self.synthesis_batch_size, *self.img_size), device=self.device ).requires_grad_() 268 | best_inputs = None 269 | z = torch.randn(size=(self.synthesis_batch_size, self.nz), device=self.device).requires_grad_() 270 | if targets is None: 271 | targets = torch.randint(low=0, high=self.num_classes, size=(self.synthesis_batch_size,)) 272 | targets = targets.sort()[0] # sort for better visualization 273 | targets = targets.to(self.device) 274 | 275 | reset_model(self.generator) 276 | optimizer = torch.optim.Adam([{'params': self.generator.parameters()}, {'params': [z]}], self.lr_g, betas=[0.5, 0.999]) 277 | #scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.iterations, eta_min=0.1*self.lr) 278 | for it in tqdm(range(self.iterations)): 279 | inputs = self.generator(z) 280 | global_view, local_view = self.aug(inputs) # crop and normalize 281 | 282 | ############################################# 283 | # Inversion Loss 284 | ############################################# 285 | t_out = self.teacher(global_view) 286 | loss_bn = sum([h.r_feature for h in self.hooks]) 287 | loss_oh = F.cross_entropy( t_out, targets ) 288 | if self.adv>0: 289 | s_out = self.student(global_view) 290 | mask = (s_out.max(1)[1]==t_out.max(1)[1]).float() 291 | loss_adv = -(kldiv(s_out, t_out, reduction='none').sum(1) * mask).mean() # decision adversarial distillation 292 | else: 293 | loss_adv = loss_oh.new_zeros(1) 294 | loss_inv = self.bn * loss_bn + self.oh * loss_oh + self.adv * loss_adv 295 | 296 | ############################################# 297 | # Contrastive Loss 298 | ############################################# 299 | global_feature = torch.cat([ h.instance_mean for h in self.cmi_hooks ], dim=1) 300 | _ = self.teacher(local_view) 301 | local_feature = torch.cat([ h.instance_mean for h in self.cmi_hooks ], dim=1) 302 | cached_feature, _ = self.mem_bank.get_data(self.n_neg) 303 | cached_local_feature, cached_global_feature = torch.chunk(cached_feature.to(self.device), chunks=2, dim=1) 304 | 305 | proj_feature = self.head( torch.cat([local_feature, cached_local_feature, global_feature, cached_global_feature], dim=0) ) 306 | proj_local_feature, proj_global_feature = torch.chunk(proj_feature, chunks=2, dim=0) 307 | 308 | # https://github.com/HobbitLong/SupContrast/blob/master/losses.py 309 | #cr_feature = torch.cat( [proj_local_feature.unsqueeze(1), proj_global_feature.unsqueeze(1).detach()], dim=1 ) 310 | #loss_cr = self.contrast_loss(cr_feature) 311 | 312 | # Note that the cross entropy loss will be divided by the total batch size (current batch + cached batch) 313 | # we split the cross entropy loss to avoid too small gradients w.r.t the generator 314 | #if self.mem_bank.n_updates>0: 315 | # 1. gradient from current batch + 2. gradient from cached data 316 | # loss_cr = loss_cr[:, :self.synthesis_batch_size].mean() + loss_cr[:, self.synthesis_batch_size:].mean() 317 | #else: # 1. 
gradients only come from current batch 318 | # loss_cr = loss_cr.mean() 319 | 320 | # A naive implementation of contrastive loss 321 | cr_logits = torch.mm(proj_local_feature, proj_global_feature.detach().T) / self.cr_T # (N + N') x (N + N') 322 | cr_labels = torch.arange(start=0, end=len(cr_logits), device=self.device) 323 | loss_cr = F.cross_entropy( cr_logits, cr_labels, reduction='none') #(N + N') 324 | if self.mem_bank.n_updates>0: 325 | loss_cr = loss_cr[:self.synthesis_batch_size].mean() + loss_cr[self.synthesis_batch_size:].mean() 326 | else: 327 | loss_cr = loss_cr.mean() 328 | 329 | loss = self.cr * loss_cr + loss_inv 330 | with torch.no_grad(): 331 | if best_cost > loss.item() or best_inputs is None: 332 | best_cost = loss.item() 333 | best_inputs = inputs.data 334 | best_features = torch.cat([local_feature.data, global_feature.data], dim=1).data 335 | optimizer.zero_grad() 336 | self.optimizer_head.zero_grad() 337 | loss.backward() 338 | optimizer.step() 339 | self.optimizer_head.step() 340 | 341 | self.student.train() 342 | # save best inputs and reset data iter 343 | self.data_pool.add( best_inputs ) 344 | self.mem_bank.add( best_features ) 345 | 346 | dst = self.data_pool.get_dataset(transform=self.transform) 347 | if self.init_dataset is not None: 348 | init_dst = datafree.utils.UnlabeledImageDataset(self.init_dataset, transform=self.transform) 349 | dst = torch.utils.data.ConcatDataset([dst, init_dst]) 350 | if self.distributed: 351 | train_sampler = torch.utils.data.distributed.DistributedSampler(dst) 352 | else: 353 | train_sampler = None 354 | loader = torch.utils.data.DataLoader( 355 | dst, batch_size=self.sample_batch_size, shuffle=(train_sampler is None), 356 | num_workers=4, pin_memory=True, sampler=train_sampler) 357 | self.data_iter = DataIter(loader) 358 | return {"synthetic": best_inputs} 359 | 360 | def sample(self): 361 | return self.data_iter.next() -------------------------------------------------------------------------------- /datafree/synthesis/triplet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from losses import TriLoss 6 | 7 | from .base import BaseSynthesis 8 | from datafree.hooks import DeepInversionHook 9 | from datafree.criterions import kldiv 10 | from datafree.utils import ImagePool, DataIter 11 | from torchvision import transforms 12 | from kornia import augmentation 13 | from tqdm import tqdm 14 | 15 | 16 | def reset_model(model): 17 | for m in model.modules(): 18 | if isinstance(m, (nn.ConvTranspose2d, nn.Linear, nn.Conv2d)): 19 | nn.init.normal_(m.weight, 0.0, 0.02) 20 | if m.bias is not None: 21 | nn.init.constant_(m.bias, 0) 22 | if isinstance(m, (nn.BatchNorm2d)): 23 | nn.init.normal_(m.weight, 1.0, 0.02) 24 | nn.init.constant_(m.bias, 0) 25 | 26 | 27 | class AdvTripletSynthesizer(BaseSynthesis): 28 | def __init__(self, teacher, student, generator, pair_sample, nz, num_classes, img_size, 29 | start_layer, end_layer, iterations=100, lr_g=0.1, progressive_scale=False, 30 | synthesis_batch_size=128, sample_batch_size=128, 31 | adv=0.0, bn=1, oh=1, triplet=0.0, 32 | save_dir='run/cmi', transform=None, 33 | normalizer=None, device='cpu', distributed=False, 34 | triplet_target='teacher', balanced_sampling=False): 35 | super(AdvTripletSynthesizer, self).__init__(teacher, student) 36 | self.save_dir = save_dir 37 | self.img_size = img_size 38 | self.start_layer = start_layer 39 | self.end_layer = end_layer 40 | 
self.iterations = iterations 41 | self.lr_g = lr_g 42 | self.progressive_scale = progressive_scale 43 | self.nz = nz 44 | self.adv = adv 45 | self.bn = bn 46 | self.oh = oh 47 | self.triplet = triplet 48 | self.compute_triplet_loss = TriLoss(balanced_sampling=balanced_sampling) 49 | self.num_classes = num_classes 50 | self.distributed = distributed 51 | self.synthesis_batch_size = synthesis_batch_size 52 | self.sample_batch_size = sample_batch_size 53 | self.triplet_target = triplet_target 54 | 55 | self.normalizer = normalizer 56 | self.data_pool = ImagePool(root=self.save_dir) 57 | self.transform = transform 58 | self.data_iter = None 59 | self.generator = generator.to(device).train() 60 | # local and global bank 61 | 62 | self.device = device 63 | self.hooks = [] 64 | for m in teacher.modules(): 65 | if isinstance(m, nn.BatchNorm2d): 66 | self.hooks.append(DeepInversionHook(m)) 67 | 68 | self.aug = transforms.Compose([ 69 | augmentation.RandomCrop( 70 | size=[self.img_size[-2], self.img_size[-1]], padding=4), 71 | augmentation.RandomHorizontalFlip(), 72 | normalizer, 73 | ]) 74 | self.pair_sample = pair_sample 75 | 76 | def synthesize(self, targets=None): 77 | self.student.eval() 78 | self.teacher.eval() 79 | best_cost = 1e6 80 | 81 | #inputs = torch.randn( size=(self.synthesis_batch_size, *self.img_size), device=self.device ).requires_grad_() 82 | best_inputs = None 83 | z = torch.randn(size=(self.synthesis_batch_size, self.nz), 84 | device=self.device).requires_grad_() 85 | if targets is None: 86 | targets = torch.randint( 87 | low=0, high=self.num_classes, size=(self.synthesis_batch_size,)) 88 | targets = targets.sort()[0] # sort for better visualization 89 | targets = targets.to(self.device) 90 | 91 | reset_model(self.generator) 92 | optimizer = torch.optim.Adam([{'params': self.generator.parameters()}, { 93 | 'params': [z]}], self.lr_g, betas=[0.5, 0.999]) 94 | #scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.iterations, eta_min=0.1*self.lr) 95 | for _ in tqdm(range(self.iterations)): 96 | inputs = self.generator(z) 97 | global_view = self.aug(inputs) # crop and normalize 98 | 99 | ############################################# 100 | # Inversion Loss 101 | ############################################# 102 | t_out, _, t_layers = self.teacher( 103 | global_view, return_features=True) 104 | s_out, _, s_layers = self.student( 105 | global_view, return_features=True) 106 | loss_bn = sum([h.r_feature for h in self.hooks]) 107 | loss_oh = F.cross_entropy(t_out, targets) 108 | 109 | if self.adv > 0: 110 | mask = (s_out.max(1)[1] == t_out.max(1)[1]).float() 111 | # decision adversarial distillation 112 | loss_adv = - \ 113 | (kldiv(s_out, t_out, reduction='none').sum(1) * mask).mean() 114 | else: 115 | loss_adv = loss_oh.new_zeros(1) 116 | 117 | if self.triplet_target == 'teacher': 118 | triplet_layers = t_layers 119 | elif self.triplet_target == 'student': 120 | triplet_layers = s_layers 121 | else: 122 | raise NotImplementedError() 123 | 124 | if self.triplet > 0: 125 | loss_tri = self.compute_triplet_loss( 126 | triplet_layers[self.start_layer:self.end_layer], t_out, torch.argmax(t_out, dim=-1)) 127 | else: 128 | loss_tri = loss_oh.new_zeros(1) 129 | 130 | loss = self.bn * loss_bn + self.oh * loss_oh + \ 131 | self.adv * loss_adv + self.triplet * loss_tri 132 | 133 | with torch.no_grad(): 134 | if best_cost > loss.item() or best_inputs is None: 135 | best_cost = loss.item() 136 | best_inputs = inputs.data 137 | optimizer.zero_grad() 138 | loss.backward() 139 | 
optimizer.step() 140 | 141 | self.student.train() 142 | self.data_pool.add(best_inputs) 143 | 144 | dst = self.data_pool.get_dataset( 145 | transform=self.transform, pair_sample=self.pair_sample) 146 | if self.distributed: 147 | train_sampler = torch.utils.data.distributed.DistributedSampler( 148 | dst) 149 | else: 150 | train_sampler = None 151 | loader = torch.utils.data.DataLoader( 152 | dst, batch_size=self.sample_batch_size, shuffle=( 153 | train_sampler is None), 154 | num_workers=4, pin_memory=True, sampler=train_sampler) 155 | self.data_iter = DataIter(loader) 156 | print("sample_batch_size:", self.sample_batch_size) 157 | return {"synthetic": best_inputs, "targets": targets} 158 | 159 | def sample(self): 160 | return self.data_iter.next() 161 | -------------------------------------------------------------------------------- /datafree/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from ._utils import * 2 | from .logger import get_logger 3 | 4 | from . import sync_transforms, inception -------------------------------------------------------------------------------- /datafree/utils/_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import ConcatDataset, Dataset 3 | import numpy as np 4 | from PIL import Image 5 | import os, random, math 6 | from copy import deepcopy 7 | from contextlib import contextmanager 8 | 9 | def get_pseudo_label(n_or_label, num_classes, device, onehot=False): 10 | if isinstance(n_or_label, int): 11 | label = torch.randint(0, num_classes, size=(n_or_label,), device=device) 12 | else: 13 | label = n_or_label.to(device) 14 | if onehot: 15 | label = torch.zeros(len(label), num_classes, device=device).scatter_(1, label.unsqueeze(1), 1.) 16 | return label 17 | 18 | def pdist(sample_1, sample_2, norm=2, eps=1e-5): 19 | r"""Compute the matrix of all squared pairwise distances. 20 | Arguments 21 | --------- 22 | sample_1 : torch.Tensor or Variable 23 | The first sample, should be of shape ``(n_1, d)``. 24 | sample_2 : torch.Tensor or Variable 25 | The second sample, should be of shape ``(n_2, d)``. 26 | norm : float 27 | The l_p norm to be used. 28 | Returns 29 | ------- 30 | torch.Tensor or Variable 31 | Matrix of shape (n_1, n_2). The [i, j]-th entry is equal to 32 | ``|| sample_1[i, :] - sample_2[j, :] ||_p``.""" 33 | n_1, n_2 = sample_1.size(0), sample_2.size(0) 34 | norm = float(norm) 35 | if norm == 2.: 36 | norms_1 = torch.sum(sample_1**2, dim=1, keepdim=True) 37 | norms_2 = torch.sum(sample_2**2, dim=1, keepdim=True) 38 | norms = (norms_1.expand(n_1, n_2) + 39 | norms_2.transpose(0, 1).expand(n_1, n_2)) 40 | distances_squared = norms - 2 * sample_1.mm(sample_2.t()) 41 | return torch.sqrt(eps + torch.abs(distances_squared)) 42 | else: 43 | dim = sample_1.size(1) 44 | expanded_1 = sample_1.unsqueeze(1).expand(n_1, n_2, dim) 45 | expanded_2 = sample_2.unsqueeze(0).expand(n_1, n_2, dim) 46 | differences = torch.abs(expanded_1 - expanded_2) ** norm 47 | inner = torch.sum(differences, dim=2, keepdim=False) 48 | return (eps + inner) ** (1. 
/ norm) 49 | 50 | class MemoryBank(object): 51 | def __init__(self, device, max_size=4096, dim_feat=512): 52 | self.device = device 53 | self.data = torch.randn( max_size, dim_feat ).to(device) 54 | self._ptr = 0 55 | self.n_updates = 0 56 | 57 | self.max_size = max_size 58 | self.dim_feat = dim_feat 59 | 60 | def add(self, feat): 61 | n, c = feat.shape 62 | assert self.dim_feat==c and self.max_size % n==0, "%d, %d"%(dim_feat, c, max_size, n) 63 | self.data[self._ptr:self._ptr+n] = feat.detach() 64 | self._ptr = (self._ptr+n) % (self.max_size) 65 | self.n_updates+=n 66 | 67 | def get_data(self, k=None, index=None): 68 | if k is None: 69 | k = self.max_size 70 | assert k <= self.max_size 71 | 72 | if self.n_updates>self.max_size: 73 | if index is None: 74 | index = random.sample(list(range(self.max_size)), k=k) 75 | return self.data[index], index 76 | else: 77 | if index is None: 78 | index = random.sample(list(range(self._ptr)), k=min(k, self._ptr)) 79 | return self.data[index], index 80 | 81 | def clip_images(image_tensor, mean, std): 82 | mean = np.array(mean) 83 | std = np.array(std) 84 | for c in range(3): 85 | m, s = mean[c], std[c] 86 | image_tensor[:, c] = torch.clamp(image_tensor[:, c], -m / s, (1 - m) / s) 87 | return image_tensor 88 | 89 | def save_image_batch(imgs, output, col=None, size=None, pack=True): 90 | if isinstance(imgs, torch.Tensor): 91 | imgs = (imgs.detach().clamp(0, 1).cpu().numpy()*255).astype('uint8') 92 | base_dir = os.path.dirname(output) 93 | if base_dir!='': 94 | os.makedirs(base_dir, exist_ok=True) 95 | if pack: 96 | imgs = pack_images( imgs, col=col ).transpose( 1, 2, 0 ).squeeze() 97 | imgs = Image.fromarray( imgs ) 98 | if size is not None: 99 | if isinstance(size, (list,tuple)): 100 | imgs = imgs.resize(size) 101 | else: 102 | w, h = imgs.size 103 | max_side = max( h, w ) 104 | scale = float(size) / float(max_side) 105 | _w, _h = int(w*scale), int(h*scale) 106 | imgs = imgs.resize([_w, _h]) 107 | imgs.save(output) 108 | else: 109 | output_filename = output.strip('.png') 110 | for idx, img in enumerate(imgs): 111 | img = Image.fromarray( img.transpose(1, 2, 0) ) 112 | img.save(output_filename+'-%d.png'%(idx)) 113 | 114 | def pack_images(images, col=None, channel_last=False, padding=1): 115 | # N, C, H, W 116 | if isinstance(images, (list, tuple) ): 117 | images = np.stack(images, 0) 118 | if channel_last: 119 | images = images.transpose(0,3,1,2) # make it channel first 120 | assert len(images.shape)==4 121 | assert isinstance(images, np.ndarray) 122 | 123 | N,C,H,W = images.shape 124 | if col is None: 125 | col = int(math.ceil(math.sqrt(N))) 126 | row = int(math.ceil(N / col)) 127 | 128 | pack = np.zeros( (C, H*row+padding*(row-1), W*col+padding*(col-1)), dtype=images.dtype ) 129 | for idx, img in enumerate(images): 130 | h = (idx // col) * (H+padding) 131 | w = (idx % col) * (W+padding) 132 | pack[:, h:h+H, w:w+W] = img 133 | return pack 134 | 135 | def flatten_dict(dic): 136 | flattned = dict() 137 | def _flatten(prefix, d): 138 | for k, v in d.items(): 139 | if isinstance(v, dict): 140 | if prefix is None: 141 | _flatten( k, v ) 142 | else: 143 | _flatten( prefix+'/%s'%k, v ) 144 | else: 145 | if prefix is None: 146 | flattned[k] = v 147 | else: 148 | flattned[ prefix+'/%s'%k ] = v 149 | 150 | _flatten(None, dic) 151 | return flattned 152 | 153 | def normalize(tensor, mean, std, reverse=False): 154 | if reverse: 155 | _mean = [ -m / s for m, s in zip(mean, std) ] 156 | _std = [ 1/s for s in std ] 157 | else: 158 | _mean = mean 159 | _std = std 160 | 
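    # With reverse=True the affine map is inverted: the call computes x * std + mean
    # by normalizing with mean' = -mean / std and std' = 1 / std. A quick sanity check
    # (illustrative values, not taken from this repo):
    #   n = Normalizer(mean=[0.5, 0.5, 0.5], std=[0.25, 0.25, 0.25])
    #   x = torch.rand(2, 3, 32, 32)
    #   assert torch.allclose(n(n(x), reverse=True), x, atol=1e-6)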
161 | _mean = torch.as_tensor(_mean, dtype=tensor.dtype, device=tensor.device) 162 | _std = torch.as_tensor(_std, dtype=tensor.dtype, device=tensor.device) 163 | tensor = (tensor - _mean[None, :, None, None]) / (_std[None, :, None, None]) 164 | return tensor 165 | 166 | class Normalizer(object): 167 | def __init__(self, mean, std): 168 | self.mean = mean 169 | self.std = std 170 | 171 | def __call__(self, x, reverse=False): 172 | return normalize(x, self.mean, self.std, reverse=reverse) 173 | 174 | def load_yaml(filepath): 175 | yaml=YAML() 176 | with open(filepath, 'r') as f: 177 | return yaml.load(f) 178 | 179 | def _collect_all_images(root, postfix=['png', 'jpg', 'jpeg', 'JPEG']): 180 | images = [] 181 | if isinstance( postfix, str): 182 | postfix = [ postfix ] 183 | for dirpath, dirnames, files in os.walk(root): 184 | for pos in postfix: 185 | for f in files: 186 | if f.endswith( pos ): 187 | images.append( os.path.join( dirpath, f ) ) 188 | return images 189 | 190 | class UnlabeledImageDataset(torch.utils.data.Dataset): 191 | def __init__(self, root, transform=None, pair_sample=False): 192 | self.root = os.path.abspath(root) 193 | self.images = _collect_all_images(self.root) #[ os.path.join(self.root, f) for f in os.listdir( root ) ] 194 | self.transform = transform 195 | self.pair_sample = pair_sample 196 | 197 | def __getitem__(self, idx): 198 | img = Image.open( self.images[idx] ) 199 | if self.transform: 200 | img1 = self.transform(img) 201 | if not self.pair_sample: 202 | return img1 203 | img2 = self.transform(img) 204 | return img1, img2 205 | 206 | def __len__(self): 207 | return len(self.images) 208 | 209 | def __repr__(self): 210 | return 'Unlabeled data:\n\troot: %s\n\tdata mount: %d\n\ttransforms: %s'%(self.root, len(self), self.transform) 211 | 212 | class LabeledImageDataset(torch.utils.data.Dataset): 213 | def __init__(self, root, transform=None): 214 | self.root = os.path.abspath(root) 215 | self.categories = [ int(f) for f in os.listdir( root ) ] 216 | images = [] 217 | targets = [] 218 | for c in self.categories: 219 | category_dir = os.path.join( self.root, str(c)) 220 | _images = [ os.path.join( category_dir, f ) for f in os.listdir(category_dir) ] 221 | images.extend(_images) 222 | targets.extend([c for _ in range(len(_images))]) 223 | self.images = images 224 | self.targets = targets 225 | self.transform = transform 226 | def __getitem__(self, idx): 227 | img, target = Image.open( self.images[idx] ), self.targets[idx] 228 | if self.transform: 229 | img = self.transform(img) 230 | return img, target 231 | def __len__(self): 232 | return len(self.images) 233 | 234 | class ImagePool(object): 235 | def __init__(self, root): 236 | self.root = os.path.abspath(root) 237 | os.makedirs(self.root, exist_ok=True) 238 | self._idx = 0 239 | 240 | def add(self, imgs, targets=None): 241 | save_image_batch(imgs, os.path.join( self.root, "%d.png"%(self._idx) ), pack=False) 242 | self._idx+=1 243 | 244 | def get_dataset(self, transform=None, pair_sample=False): 245 | return UnlabeledImageDataset(self.root, transform=transform, pair_sample=pair_sample) 246 | 247 | class DataIter(object): 248 | def __init__(self, dataloader): 249 | self.dataloader = dataloader 250 | self._iter = iter(self.dataloader) 251 | 252 | def next(self): 253 | try: 254 | data = next( self._iter ) 255 | except StopIteration: 256 | self._iter = iter(self.dataloader) 257 | data = next( self._iter ) 258 | return data 259 | 260 | @contextmanager 261 | def dummy_ctx(*args, **kwds): 262 | try: 263 | yield None 264 | 
finally: 265 | pass -------------------------------------------------------------------------------- /datafree/utils/fmix.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import numpy as np 5 | from scipy.stats import beta 6 | 7 | 8 | def fftfreqnd(h, w=None, z=None): 9 | """ Get bin values for discrete fourier transform of size (h, w, z) 10 | :param h: Required, first dimension size 11 | :param w: Optional, second dimension size 12 | :param z: Optional, third dimension size 13 | """ 14 | fz = fx = 0 15 | fy = np.fft.fftfreq(h) 16 | 17 | if w is not None: 18 | fy = np.expand_dims(fy, -1) 19 | 20 | if w % 2 == 1: 21 | fx = np.fft.fftfreq(w)[: w // 2 + 2] 22 | else: 23 | fx = np.fft.fftfreq(w)[: w // 2 + 1] 24 | 25 | if z is not None: 26 | fy = np.expand_dims(fy, -1) 27 | if z % 2 == 1: 28 | fz = np.fft.fftfreq(z)[:, None] 29 | else: 30 | fz = np.fft.fftfreq(z)[:, None] 31 | 32 | return np.sqrt(fx * fx + fy * fy + fz * fz) 33 | 34 | 35 | def get_spectrum(freqs, decay_power, ch, h, w=0, z=0): 36 | """ Samples a fourier image with given size and frequencies decayed by decay power 37 | :param freqs: Bin values for the discrete fourier transform 38 | :param decay_power: Decay power for frequency decay prop 1/f**d 39 | :param ch: Number of channels for the resulting mask 40 | :param h: Required, first dimension size 41 | :param w: Optional, second dimension size 42 | :param z: Optional, third dimension size 43 | """ 44 | scale = np.ones(1) / (np.maximum(freqs, np.array([1. / max(w, h, z)])) ** decay_power) 45 | 46 | param_size = [ch] + list(freqs.shape) + [2] 47 | param = np.random.randn(*param_size) 48 | 49 | scale = np.expand_dims(scale, -1)[None, :] 50 | 51 | return scale * param 52 | 53 | 54 | def make_low_freq_image(decay, shape, ch=1): 55 | """ Sample a low frequency image from fourier space 56 | :param decay_power: Decay power for frequency decay prop 1/f**d 57 | :param shape: Shape of desired mask, list up to 3 dims 58 | :param ch: Number of channels for desired mask 59 | """ 60 | freqs = fftfreqnd(*shape) 61 | spectrum = get_spectrum(freqs, decay, ch, *shape)#.reshape((1, *shape[:-1], -1)) 62 | spectrum = spectrum[:, 0] + 1j * spectrum[:, 1] 63 | mask = np.real(np.fft.irfftn(spectrum, shape)) 64 | 65 | if len(shape) == 1: 66 | mask = mask[:1, :shape[0]] 67 | if len(shape) == 2: 68 | mask = mask[:1, :shape[0], :shape[1]] 69 | if len(shape) == 3: 70 | mask = mask[:1, :shape[0], :shape[1], :shape[2]] 71 | 72 | mask = mask 73 | mask = (mask - mask.min()) 74 | mask = mask / mask.max() 75 | return mask 76 | 77 | 78 | def sample_lam(alpha, reformulate=False): 79 | """ Sample a lambda from symmetric beta distribution with given alpha 80 | :param alpha: Alpha value for beta distribution 81 | :param reformulate: If True, uses the reformulation of [1]. 82 | """ 83 | if reformulate: 84 | lam = beta.rvs(alpha+1, alpha) 85 | else: 86 | lam = beta.rvs(alpha, alpha) 87 | 88 | return lam 89 | 90 | 91 | def binarise_mask(mask, lam, in_shape, max_soft=0.0): 92 | """ Binarises a given low frequency image such that it has mean lambda. 93 | :param mask: Low frequency image, usually the result of `make_low_freq_image` 94 | :param lam: Mean value of final mask 95 | :param in_shape: Shape of inputs 96 | :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask. 
97 | :return: 98 | """ 99 | idx = mask.reshape(-1).argsort()[::-1] 100 | mask = mask.reshape(-1) 101 | num = math.ceil(lam * mask.size) if random.random() > 0.5 else math.floor(lam * mask.size) 102 | 103 | eff_soft = max_soft 104 | if max_soft > lam or max_soft > (1-lam): 105 | eff_soft = min(lam, 1-lam) 106 | 107 | soft = int(mask.size * eff_soft) 108 | num_low = num - soft 109 | num_high = num + soft 110 | 111 | mask[idx[:num_high]] = 1 112 | mask[idx[num_low:]] = 0 113 | mask[idx[num_low:num_high]] = np.linspace(1, 0, (num_high - num_low)) 114 | 115 | mask = mask.reshape((1, *in_shape)) 116 | return mask 117 | 118 | 119 | def sample_mask(alpha, decay_power, shape, max_soft=0.0, reformulate=False): 120 | """ Samples a mean lambda from beta distribution parametrised by alpha, creates a low frequency image and binarises 121 | it based on this lambda 122 | :param alpha: Alpha value for beta distribution from which to sample mean of mask 123 | :param decay_power: Decay power for frequency decay prop 1/f**d 124 | :param shape: Shape of desired mask, list up to 3 dims 125 | :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask. 126 | :param reformulate: If True, uses the reformulation of [1]. 127 | """ 128 | if isinstance(shape, int): 129 | shape = (shape,) 130 | 131 | # Choose lambda 132 | lam = sample_lam(alpha, reformulate) 133 | 134 | # Make mask, get mean / std 135 | mask = make_low_freq_image(decay_power, shape) 136 | mask = binarise_mask(mask, lam, shape, max_soft) 137 | 138 | return lam, mask 139 | 140 | 141 | def sample_and_apply(x, alpha, decay_power, shape, max_soft=0.0, reformulate=False): 142 | """ 143 | :param x: Image batch on which to apply fmix of shape [b, c, shape*] 144 | :param alpha: Alpha value for beta distribution from which to sample mean of mask 145 | :param decay_power: Decay power for frequency decay prop 1/f**d 146 | :param shape: Shape of desired mask, list up to 3 dims 147 | :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask. 148 | :param reformulate: If True, uses the reformulation of [1]. 149 | :return: mixed input, permutation indices, lambda value of mix, 150 | """ 151 | lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate) 152 | index = np.random.permutation(x.shape[0]) 153 | 154 | x1, x2 = x * mask, x[index] * (1-mask) 155 | return x1+x2, index, lam 156 | 157 | 158 | class FMixBase: 159 | r""" FMix augmentation 160 | Args: 161 | decay_power (float): Decay power for frequency decay prop 1/f**d 162 | alpha (float): Alpha value for beta distribution from which to sample mean of mask 163 | size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims 164 | max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask. 165 | reformulate (bool): If True, uses the reformulation of [1]. 
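    Example (a minimal sketch; the batch size and mask shape below are assumptions):
        >>> import torch
        >>> lam, mask = sample_mask(alpha=1.0, decay_power=3.0, shape=(32, 32))
        >>> mask.shape                                   # numpy array with values in [0, 1]
        (1, 32, 32)
        >>> x = torch.rand(16, 3, 32, 32)
        >>> mask_t = torch.from_numpy(mask).float()
        >>> mixed = x * mask_t + x[torch.randperm(16)] * (1 - mask_t)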
166 | """ 167 | 168 | def __init__(self, decay_power=3, alpha=1, size=(32, 32), max_soft=0.0, reformulate=False): 169 | super().__init__() 170 | self.decay_power = decay_power 171 | self.reformulate = reformulate 172 | self.size = size 173 | self.alpha = alpha 174 | self.max_soft = max_soft 175 | self.index = None 176 | self.lam = None 177 | 178 | def __call__(self, x): 179 | raise NotImplementedError 180 | 181 | def loss(self, *args, **kwargs): 182 | raise NotImplementedError -------------------------------------------------------------------------------- /datafree/utils/inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | 6 | try: 7 | from torchvision.models.utils import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | # Inception weights ported to Pytorch from 12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 14 | 15 | 16 | class InceptionV3(nn.Module): 17 | """Pretrained InceptionV3 network returning feature maps""" 18 | 19 | # Index of default block of inception to return, 20 | # corresponds to output of final average pooling 21 | DEFAULT_BLOCK_INDEX = 3 22 | 23 | # Maps feature dimensionality to their output blocks indices 24 | BLOCK_INDEX_BY_DIM = { 25 | 64: 0, # First max pooling features 26 | 192: 1, # Second max pooling featurs 27 | 768: 2, # Pre-aux classifier features 28 | 2048: 3 # Final average pooling features 29 | } 30 | 31 | def __init__(self, 32 | output_blocks=(DEFAULT_BLOCK_INDEX,), 33 | resize_input=True, 34 | normalize_input=True, 35 | requires_grad=False, 36 | use_fid_inception=True): 37 | """Build pretrained InceptionV3 38 | Parameters 39 | ---------- 40 | output_blocks : list of int 41 | Indices of blocks to return features of. Possible values are: 42 | - 0: corresponds to output of first max pooling 43 | - 1: corresponds to output of second max pooling 44 | - 2: corresponds to output which is fed to aux classifier 45 | - 3: corresponds to output of final average pooling 46 | resize_input : bool 47 | If true, bilinearly resizes input to width and height 299 before 48 | feeding input to model. As the network without fully connected 49 | layers is fully convolutional, it should be able to handle inputs 50 | of arbitrary size, so resizing might not be strictly needed 51 | normalize_input : bool 52 | If true, scales the input from range (0, 1) to the range the 53 | pretrained Inception network expects, namely (-1, 1) 54 | requires_grad : bool 55 | If true, parameters of the model require gradients. Possibly useful 56 | for finetuning the network 57 | use_fid_inception : bool 58 | If true, uses the pretrained Inception model used in Tensorflow's 59 | FID implementation. If false, uses the pretrained Inception model 60 | available in torchvision. The FID Inception model has different 61 | weights and a slightly different structure from torchvision's 62 | Inception model. If you want to compute FID scores, you are 63 | strongly advised to set this parameter to true to get comparable 64 | results. 
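        Examples
        --------
        An illustrative call (the batch size is an assumption; the FID weights are
        downloaded on first use):

        >>> block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        >>> model = InceptionV3([block_idx]).eval()
        >>> feats = model(torch.rand(4, 3, 299, 299))[0]   # shape [4, 2048, 1, 1]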
65 | """ 66 | super(InceptionV3, self).__init__() 67 | 68 | self.resize_input = resize_input 69 | self.normalize_input = normalize_input 70 | self.output_blocks = sorted(output_blocks) 71 | self.last_needed_block = max(output_blocks) 72 | 73 | assert self.last_needed_block <= 3, \ 74 | 'Last possible output block index is 3' 75 | 76 | self.blocks = nn.ModuleList() 77 | 78 | if use_fid_inception: 79 | inception = fid_inception_v3() 80 | else: 81 | inception = _inception_v3(pretrained=True) 82 | 83 | # Block 0: input to maxpool1 84 | block0 = [ 85 | inception.Conv2d_1a_3x3, 86 | inception.Conv2d_2a_3x3, 87 | inception.Conv2d_2b_3x3, 88 | nn.MaxPool2d(kernel_size=3, stride=2) 89 | ] 90 | self.blocks.append(nn.Sequential(*block0)) 91 | 92 | # Block 1: maxpool1 to maxpool2 93 | if self.last_needed_block >= 1: 94 | block1 = [ 95 | inception.Conv2d_3b_1x1, 96 | inception.Conv2d_4a_3x3, 97 | nn.MaxPool2d(kernel_size=3, stride=2) 98 | ] 99 | self.blocks.append(nn.Sequential(*block1)) 100 | 101 | # Block 2: maxpool2 to aux classifier 102 | if self.last_needed_block >= 2: 103 | block2 = [ 104 | inception.Mixed_5b, 105 | inception.Mixed_5c, 106 | inception.Mixed_5d, 107 | inception.Mixed_6a, 108 | inception.Mixed_6b, 109 | inception.Mixed_6c, 110 | inception.Mixed_6d, 111 | inception.Mixed_6e, 112 | ] 113 | self.blocks.append(nn.Sequential(*block2)) 114 | 115 | # Block 3: aux classifier to final avgpool 116 | if self.last_needed_block >= 3: 117 | block3 = [ 118 | inception.Mixed_7a, 119 | inception.Mixed_7b, 120 | inception.Mixed_7c, 121 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 122 | ] 123 | self.blocks.append(nn.Sequential(*block3)) 124 | 125 | for param in self.parameters(): 126 | param.requires_grad = requires_grad 127 | 128 | def forward(self, inp): 129 | """Get Inception feature maps 130 | Parameters 131 | ---------- 132 | inp : torch.autograd.Variable 133 | Input tensor of shape Bx3xHxW. Values are expected to be in 134 | range (0, 1) 135 | Returns 136 | ------- 137 | List of torch.autograd.Variable, corresponding to the selected output 138 | block, sorted ascending by index 139 | """ 140 | outp = [] 141 | x = inp 142 | 143 | if self.resize_input: 144 | x = F.interpolate(x, 145 | size=(299, 299), 146 | mode='bilinear', 147 | align_corners=False) 148 | 149 | if self.normalize_input: 150 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 151 | 152 | for idx, block in enumerate(self.blocks): 153 | x = block(x) 154 | if idx in self.output_blocks: 155 | outp.append(x) 156 | 157 | if idx == self.last_needed_block: 158 | break 159 | 160 | return outp 161 | 162 | 163 | def _inception_v3(*args, **kwargs): 164 | """Wraps `torchvision.models.inception_v3` 165 | Skips default weight inititialization if supported by torchvision version. 166 | See https://github.com/mseitzer/pytorch-fid/issues/28. 167 | """ 168 | try: 169 | version = tuple(map(int, torchvision.__version__.split('.')[:2])) 170 | except ValueError: 171 | # Just a caution against weird version strings 172 | version = (0,) 173 | 174 | if version >= (0, 6): 175 | kwargs['init_weights'] = False 176 | 177 | return torchvision.models.inception_v3(*args, **kwargs) 178 | 179 | 180 | def fid_inception_v3(): 181 | """Build pretrained Inception model for FID computation 182 | The Inception model for FID computation uses a different set of weights 183 | and has a slightly different structure than torchvision's Inception. 
184 | This method first constructs torchvision's Inception and then patches the 185 | necessary parts that are different in the FID Inception model. 186 | """ 187 | inception = _inception_v3(num_classes=1008, 188 | aux_logits=False, 189 | pretrained=False) 190 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 191 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 192 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 193 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 194 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 195 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 196 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 197 | inception.Mixed_7b = FIDInceptionE_1(1280) 198 | inception.Mixed_7c = FIDInceptionE_2(2048) 199 | 200 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 201 | inception.load_state_dict(state_dict) 202 | return inception 203 | 204 | 205 | class FIDInceptionA(torchvision.models.inception.InceptionA): 206 | """InceptionA block patched for FID computation""" 207 | def __init__(self, in_channels, pool_features): 208 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 209 | 210 | def forward(self, x): 211 | branch1x1 = self.branch1x1(x) 212 | 213 | branch5x5 = self.branch5x5_1(x) 214 | branch5x5 = self.branch5x5_2(branch5x5) 215 | 216 | branch3x3dbl = self.branch3x3dbl_1(x) 217 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 218 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 219 | 220 | # Patch: Tensorflow's average pool does not use the padded zero's in 221 | # its average calculation 222 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 223 | count_include_pad=False) 224 | branch_pool = self.branch_pool(branch_pool) 225 | 226 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 227 | return torch.cat(outputs, 1) 228 | 229 | 230 | class FIDInceptionC(torchvision.models.inception.InceptionC): 231 | """InceptionC block patched for FID computation""" 232 | def __init__(self, in_channels, channels_7x7): 233 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 234 | 235 | def forward(self, x): 236 | branch1x1 = self.branch1x1(x) 237 | 238 | branch7x7 = self.branch7x7_1(x) 239 | branch7x7 = self.branch7x7_2(branch7x7) 240 | branch7x7 = self.branch7x7_3(branch7x7) 241 | 242 | branch7x7dbl = self.branch7x7dbl_1(x) 243 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 244 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 245 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 246 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 247 | 248 | # Patch: Tensorflow's average pool does not use the padded zero's in 249 | # its average calculation 250 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 251 | count_include_pad=False) 252 | branch_pool = self.branch_pool(branch_pool) 253 | 254 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 255 | return torch.cat(outputs, 1) 256 | 257 | 258 | class FIDInceptionE_1(torchvision.models.inception.InceptionE): 259 | """First InceptionE block patched for FID computation""" 260 | def __init__(self, in_channels): 261 | super(FIDInceptionE_1, self).__init__(in_channels) 262 | 263 | def forward(self, x): 264 | branch1x1 = self.branch1x1(x) 265 | 266 | branch3x3 = self.branch3x3_1(x) 267 | branch3x3 = [ 268 | self.branch3x3_2a(branch3x3), 269 | self.branch3x3_2b(branch3x3), 270 | ] 271 | branch3x3 = torch.cat(branch3x3, 1) 272 | 273 | branch3x3dbl = 
self.branch3x3dbl_1(x) 274 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 275 | branch3x3dbl = [ 276 | self.branch3x3dbl_3a(branch3x3dbl), 277 | self.branch3x3dbl_3b(branch3x3dbl), 278 | ] 279 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 280 | 281 | # Patch: Tensorflow's average pool does not use the padded zero's in 282 | # its average calculation 283 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 284 | count_include_pad=False) 285 | branch_pool = self.branch_pool(branch_pool) 286 | 287 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 288 | return torch.cat(outputs, 1) 289 | 290 | 291 | class FIDInceptionE_2(torchvision.models.inception.InceptionE): 292 | """Second InceptionE block patched for FID computation""" 293 | def __init__(self, in_channels): 294 | super(FIDInceptionE_2, self).__init__(in_channels) 295 | 296 | def forward(self, x): 297 | branch1x1 = self.branch1x1(x) 298 | 299 | branch3x3 = self.branch3x3_1(x) 300 | branch3x3 = [ 301 | self.branch3x3_2a(branch3x3), 302 | self.branch3x3_2b(branch3x3), 303 | ] 304 | branch3x3 = torch.cat(branch3x3, 1) 305 | 306 | branch3x3dbl = self.branch3x3dbl_1(x) 307 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 308 | branch3x3dbl = [ 309 | self.branch3x3dbl_3a(branch3x3dbl), 310 | self.branch3x3dbl_3b(branch3x3dbl), 311 | ] 312 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 313 | 314 | # Patch: The FID Inception model uses max pooling instead of average 315 | # pooling. This is likely an error in this specific Inception 316 | # implementation, as other Inception models use average pooling here 317 | # (which matches the description in the paper). 318 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 319 | branch_pool = self.branch_pool(branch_pool) 320 | 321 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 322 | return torch.cat(outputs, 1) -------------------------------------------------------------------------------- /datafree/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os, sys 3 | from termcolor import colored 4 | 5 | 6 | class _ColorfulFormatter(logging.Formatter): 7 | def __init__(self, *args, **kwargs): 8 | super(_ColorfulFormatter, self).__init__(*args, **kwargs) 9 | 10 | def formatMessage(self, record): 11 | log = super(_ColorfulFormatter, self).formatMessage(record) 12 | 13 | if record.levelno == logging.WARNING: 14 | prefix = colored("WARNING", "yellow", attrs=["blink"]) 15 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 16 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 17 | else: 18 | return log 19 | 20 | return prefix + " " + log 21 | 22 | def get_logger(name='train', output=None, color=True): 23 | logger = logging.getLogger(name) 24 | logger.setLevel(logging.DEBUG) 25 | logger.propagate = False 26 | 27 | # STDOUT 28 | stdout_handler = logging.StreamHandler( stream=sys.stdout ) 29 | stdout_handler.setLevel( logging.DEBUG ) 30 | 31 | plain_formatter = logging.Formatter( 32 | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" ) 33 | if color: 34 | formatter = _ColorfulFormatter( 35 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 36 | datefmt="%m/%d %H:%M:%S") 37 | else: 38 | formatter = plain_formatter 39 | stdout_handler.setFormatter(formatter) 40 | 41 | logger.addHandler(stdout_handler) 42 | 43 | # FILE 44 | if output is not None: 45 | if output.endswith('.txt') or output.endswith('.log'): 
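            # A path ending in .txt/.log is used directly as the log file; any other
            # path is treated as a directory and a "log.txt" is created inside it.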
46 | os.makedirs(os.path.dirname(output), exist_ok=True) 47 | filename = output 48 | else: 49 | os.makedirs(output, exist_ok=True) 50 | filename = os.path.join(output, "log.txt") 51 | file_handler = logging.FileHandler(filename) 52 | file_handler.setFormatter(plain_formatter) 53 | file_handler.setLevel(logging.DEBUG) 54 | logger.addHandler(file_handler) 55 | return logger 56 | 57 | 58 | -------------------------------------------------------------------------------- /datafree/utils/pair.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torchvision 3 | from torchvision import datasets 4 | from collections import defaultdict 5 | from torch.utils.data import Sampler, Dataset 6 | 7 | class DatasetWrapper(Dataset): 8 | # Additinoal attributes 9 | # - indices 10 | # - classwise_indices 11 | # - num_classes 12 | # - get_class 13 | 14 | def __init__(self, dataset, indices=None): 15 | self.base_dataset = dataset 16 | if indices is None: 17 | self.indices = list(range(len(dataset))) 18 | else: 19 | self.indices = indices 20 | 21 | # torchvision 0.2.0 compatibility 22 | if torchvision.__version__.startswith('0.2'): 23 | if isinstance(self.base_dataset, datasets.ImageFolder): 24 | self.base_dataset.targets = [s[1] for s in self.base_dataset.imgs] 25 | else: 26 | if self.base_dataset.train: 27 | self.base_dataset.targets = self.base_dataset.train_labels 28 | else: 29 | self.base_dataset.targets = self.base_dataset.test_labels 30 | 31 | self.classwise_indices = defaultdict(list) 32 | for i in range(len(self)): 33 | y = self.base_dataset.targets[self.indices[i]] 34 | self.classwise_indices[y].append(i) 35 | self.num_classes = max(self.classwise_indices.keys())+1 36 | 37 | def __getitem__(self, i): 38 | return self.base_dataset[self.indices[i]] 39 | 40 | def __len__(self): 41 | return len(self.indices) 42 | 43 | def get_class(self, i): 44 | return self.base_dataset.targets[self.indices[i]] 45 | 46 | 47 | class PairBatchSampler(Sampler): 48 | def __init__(self, dataset, batch_size, num_iterations=None): 49 | self.dataset = dataset 50 | self.batch_size = batch_size 51 | self.num_iterations = num_iterations 52 | 53 | def __iter__(self): 54 | indices = list(range(len(self.dataset))) 55 | random.shuffle(indices) 56 | for k in range(len(self)): 57 | if self.num_iterations is None: 58 | offset = k*self.batch_size 59 | batch_indices = indices[offset:offset+self.batch_size] 60 | else: 61 | batch_indices = random.sample(range(len(self.dataset)), 62 | self.batch_size) 63 | 64 | pair_indices = [] 65 | for idx in batch_indices: 66 | y = self.dataset.get_class(idx) 67 | pair_indices.append(random.choice(self.dataset.classwise_indices[y])) 68 | 69 | yield batch_indices + pair_indices 70 | 71 | def __len__(self): 72 | if self.num_iterations is None: 73 | return (len(self.dataset)+self.batch_size-1) // self.batch_size 74 | else: 75 | return self.num_iterations -------------------------------------------------------------------------------- /datafree/utils/sync_transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transforms import * -------------------------------------------------------------------------------- /datafree/utils/vis.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('agg') 3 | import matplotlib.pyplot as plt 4 | from matplotlib import cm 5 | 6 | from sklearn.manifold import TSNE 7 | import seaborn as sns 8 | 
import numpy as np
 9 | import pandas as pd
10 | import os
11 | import io
12 | 
13 | def tsne_features( real_features, real_labels, fake_features, fake_labels, output_file ):
14 |     fig = plt.figure(figsize=(10, 10))
15 |     features = np.concatenate( [real_features, fake_features ], axis=0 )
16 |     labels = np.concatenate( [ real_labels, fake_labels ], axis=0 ).reshape(-1, 1)
17 |     tsne = TSNE( n_components=2, perplexity=10 ).fit_transform( features )
18 |     df = np.concatenate( [tsne, labels], axis=1 )
19 |     df = pd.DataFrame(df, columns=["x", "y", "label"])
20 |     style = [ 'real' for _ in range(len(real_features)) ] + [ 'fake' for _ in range(len(fake_features))]
21 |     sns.scatterplot( x="x", y="y", data=df, hue="label", palette=sns.color_palette("dark", 10), s=50, style=style)
22 |     if output_file is not None:
23 |         dirname = os.path.dirname( output_file )
24 |         if dirname!='':
25 |             os.makedirs( dirname, exist_ok=True )
26 |         plt.savefig( output_file )
27 |     else:
28 |         fig.canvas.draw()
29 |         img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
30 |         img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
31 |         plt.close()
32 |         return img
33 | 
34 | 
35 | 
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | 
 5 | BIG_NUMBER = 1e12
 6 | 
 7 | 
 8 | def pdist(e, squared=False, eps=1e-12):
 9 |     e_square = e.pow(2).sum(dim=1)
10 |     prod = e @ e.t()
11 |     res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)
12 | 
13 |     if not squared:
14 |         res = res.sqrt()
15 | 
16 |     res = res.clone()
17 |     res[range(len(e)), range(len(e))] = 0
18 |     return res
19 | 
20 | 
21 | def pos_neg_mask(labels):
22 |     pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) * \
23 |         (1 - torch.eye(labels.size(0), dtype=torch.uint8, device=labels.device))
24 |     neg_mask = (labels.unsqueeze(0) != labels.unsqueeze(1)) * \
25 |         (1 - torch.eye(labels.size(0), dtype=torch.uint8, device=labels.device))
26 | 
27 |     return pos_mask, neg_mask
28 | 
29 | class BalancedWeighted(nn.Module):
30 |     cut_off = 0.5
31 |     nonzero_loss_cutoff = 1.0
32 |     """
33 |     Balanced distance-weighted sampling; assumes that embeddings are normalized by 2-norm.
34 | """ 35 | 36 | def __init__(self, dist_func=pdist): 37 | self.dist_func = dist_func 38 | super().__init__() 39 | 40 | def forward(self, embeddings, labels): 41 | with torch.no_grad(): 42 | embeddings = F.normalize(embeddings, dim=1, p=2) 43 | pos_mask, neg_mask = pos_neg_mask(labels) 44 | pos_pair_idx = pos_mask.nonzero() 45 | anchor_idx = pos_pair_idx[:, 0] 46 | pos_idx = pos_pair_idx[:, 1] 47 | 48 | d = embeddings.size(1) 49 | dist = (pdist(embeddings, squared=True) + torch.eye(embeddings.size(0), 50 | device=embeddings.device, dtype=torch.float32)).sqrt() 51 | log_weight = ((2.0 - d) * dist.log() - ((d - 3.0)/2.0) 52 | * (1.0 - 0.25 * (dist * dist)).log()) 53 | weight = (log_weight - log_weight.max(dim=1, 54 | keepdim=True)[0]).exp() 55 | weight = weight * \ 56 | (neg_mask * (dist < self.nonzero_loss_cutoff) * (dist > self.cut_off)).float() 57 | 58 | weight = weight + \ 59 | ((weight.sum(dim=1, keepdim=True) == 0) * neg_mask).float() 60 | weight = weight / (weight.sum(dim=1, keepdim=True)) 61 | weight = weight[anchor_idx] 62 | if torch.max(weight) == 0: 63 | return None, None, None 64 | weight[torch.isnan(weight)] = 0 65 | neg_idx = torch.multinomial(weight, 1).squeeze(1) 66 | 67 | return anchor_idx, pos_idx, neg_idx 68 | 69 | 70 | class DistanceWeighted(nn.Module): 71 | cut_off = 0.5 72 | nonzero_loss_cutoff = 1.4 73 | """ 74 | Distance Weighted loss assume that embeddings are normalized py 2-norm. 75 | """ 76 | 77 | def __init__(self, dist_func=pdist): 78 | self.dist_func = dist_func 79 | super().__init__() 80 | 81 | def forward(self, embeddings, labels): 82 | with torch.no_grad(): 83 | embeddings = F.normalize(embeddings, dim=1, p=2) 84 | pos_mask, neg_mask = pos_neg_mask(labels) 85 | pos_pair_idx = pos_mask.nonzero() 86 | anchor_idx = pos_pair_idx[:, 0] 87 | pos_idx = pos_pair_idx[:, 1] 88 | 89 | d = embeddings.size(1) 90 | dist = (pdist(embeddings, squared=True) + torch.eye(embeddings.size(0), 91 | device=embeddings.device, dtype=torch.float32)).sqrt() 92 | dist = dist.clamp(min=self.cut_off) 93 | log_weight = ((2.0 - d) * dist.log() - ((d - 3.0)/2.0) 94 | * (1.0 - 0.25 * (dist * dist)).log()) 95 | weight = (log_weight - log_weight.max(dim=1, 96 | keepdim=True)[0]).exp() 97 | weight = weight * \ 98 | (neg_mask * (dist < self.nonzero_loss_cutoff)).float() 99 | 100 | weight = weight + \ 101 | ((weight.sum(dim=1, keepdim=True) == 0) * neg_mask).float() 102 | weight = weight / (weight.sum(dim=1, keepdim=True)) 103 | weight = weight[anchor_idx] 104 | if torch.max(weight) == 0: 105 | return None, None, None 106 | weight[torch.isnan(weight)] = 0 107 | neg_idx = torch.multinomial(weight, 1).squeeze(1) 108 | 109 | return anchor_idx, pos_idx, neg_idx 110 | 111 | 112 | class TriLoss(nn.Module): 113 | def __init__(self, p=2, margin=0.2, balanced_sampling=False): 114 | super().__init__() 115 | self.p = p 116 | self.margin = margin 117 | 118 | # update distance function accordingly 119 | if balanced_sampling: 120 | self.sampler = BalancedWeighted() 121 | else: 122 | self.sampler = DistanceWeighted() 123 | self.sampler.dist_func = lambda e: pdist(e, squared=(p == 2)) 124 | self.count = 0 125 | 126 | def forward(self, stu_features, logits, labels, negative=False): 127 | if negative: 128 | anchor_idx, neg_idx, pos_idx = self.sampler(logits, labels) 129 | else: 130 | anchor_idx, pos_idx, neg_idx = self.sampler(logits, labels) 131 | 132 | loss = 0. 
133 | if anchor_idx is None: 134 | print('warning: no negative samples found.') 135 | return torch.zeros(1) 136 | self.count += 1 137 | for embeddings in stu_features: 138 | if len(embeddings.shape) > 2: 139 | embeddings = embeddings.mean(dim=(2, 3), keepdim=False) 140 | embeddings = F.normalize(embeddings, p=2, dim=-1) 141 | anchor_embed = embeddings[anchor_idx] 142 | positive_embed = embeddings[pos_idx] 143 | negative_embed = embeddings[neg_idx] 144 | 145 | triloss = F.triplet_margin_loss(anchor_embed, positive_embed, negative_embed, 146 | margin=self.margin, p=self.p, reduction='none') 147 | loss += triloss 148 | 149 | return loss.mean() 150 | 151 | 152 | def prune_fpgm(layers): 153 | 154 | pruned_activations_mask = [] 155 | with torch.no_grad(): 156 | 157 | for layer in layers: 158 | 159 | b, c, h, w = layer.shape 160 | 161 | P = layer.view((b, c, h * w)) 162 | 163 | A = P @ P.transpose(1, 2) 164 | A = torch.sum(A, dim=-1) 165 | max_ = torch.max(A) 166 | min_ = torch.min(A) 167 | A = 1.0 - (A - min_) / (max_ - min_) + 1e-3 168 | 169 | pruned_activations_mask.append(A.to(layer.device)) 170 | 171 | return pruned_activations_mask 172 | 173 | 174 | class CDLoss(nn.Module): 175 | """Channel Distillation Loss""" 176 | 177 | def __init__(self, linears=[]): 178 | super().__init__() 179 | self.linears = linears 180 | 181 | def forward(self, stu_features: list, tea_features: list): 182 | loss = 0. 183 | for i, (s, t) in enumerate(zip(stu_features, tea_features)): 184 | if not self.linears[i] is None: 185 | s = self.linears[i](s) 186 | s = s.mean(dim=(2, 3), keepdim=False) 187 | t = t.mean(dim=(2, 3), keepdim=False) 188 | # loss += F.mse_loss(F.normalize(s, p=2, dim=-1), F.normalize(t, p=2, dim=-1)) 189 | loss += F.mse_loss(s, t) 190 | return loss 191 | 192 | 193 | class GRAMLoss(nn.Module): 194 | """GRAM Loss""" 195 | 196 | def __init__(self, linears=[]): 197 | super().__init__() 198 | self.linears = linears 199 | 200 | def forward(self, stu_features: list, tea_features: list): 201 | loss = 0. 
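        # prune_fpgm() scores each teacher channel by how strongly its activation map
        # correlates with the others; the resulting masks down-weight redundant
        # channels in the squared-error term below.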
202 |         masks = prune_fpgm(tea_features)
203 |         for i, s in enumerate(stu_features):
204 |             t = tea_features[i]
205 |             if not self.linears[i] is None:
206 |                 s = self.linears[i](s)
207 |             b, c = masks[i].shape
208 |             m = masks[i].view((b, c, 1, 1)).detach()
209 |             loss += torch.mean(torch.pow(s - t, 2) * m)
210 |         return loss
211 | 
--------------------------------------------------------------------------------
/misc/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/RGAL/e94cf7b19bff1c1a517592a9d9bcaf521c768e43/misc/framework.png
--------------------------------------------------------------------------------
/train_scratch.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 | import os
 3 | import random
 4 | import time
 5 | import warnings
 6 | 
 7 | import torch
 8 | import torch.nn as nn
 9 | import torch.nn.parallel
10 | import torch.backends.cudnn as cudnn
11 | import torch.distributed as dist
12 | import torch.optim
13 | import torch.multiprocessing as mp
14 | import torch.utils.data
15 | import torch.utils.data.distributed
16 | 
17 | import registry
18 | import datafree
19 | 
20 | parser = argparse.ArgumentParser(description='PyTorch training from scratch')
21 | # Basic Settings
22 | parser.add_argument('--data_root', default='data')
23 | parser.add_argument('--model', default='resnet34_imagenet')
24 | parser.add_argument('--dataset', default='cifar10')
25 | parser.add_argument('--epochs', default=100, type=int, metavar='N',
26 |                     help='number of total epochs to run')
27 | parser.add_argument('-b', '--batch-size', default=128, type=int,
28 |                     metavar='N',
29 |                     help='mini-batch size (default: 128), this is the total '
30 |                          'batch size of all GPUs on the current node when '
31 |                          'using Data Parallel or Distributed Data Parallel')
32 | parser.add_argument('--warm_up_epoches', default=10, type=int,
33 |                     metavar='WPI', help='number of warm-up epochs')
34 | parser.add_argument('--warm_up_lr', default=0.01, type=float,
35 |                     metavar='WPI', help='warm-up learning rate')
36 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
37 |                     metavar='LR', help='initial learning rate', dest='lr')
38 | parser.add_argument('--lr_decay_milestones', default="50,75", type=str,
39 |                     help='milestones for learning rate decay')
40 | parser.add_argument('--evaluate_only', action='store_true',
41 |                     help='evaluate model on validation set')
42 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
43 |                     help='path to latest checkpoint (default: none)')
44 | parser.add_argument('--gpu', default=0, type=int,
45 |                     help='GPU id to use.')
46 | 
47 | # Device & FP16
48 | parser.add_argument('--fp16', action='store_true',
49 |                     help='use fp16')
50 | parser.add_argument('--world-size', default=-1, type=int,
51 |                     help='number of nodes for distributed training')
52 | parser.add_argument('--rank', default=-1, type=int,
53 |                     help='node rank for distributed training')
54 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
55 |                     help='url used to set up distributed training')
56 | parser.add_argument('--dist-backend', default='nccl', type=str,
57 |                     help='distributed backend')
58 | parser.add_argument('--multiprocessing-distributed', action='store_true',
59 |                     help='Use multi-processing distributed training to launch '
60 |                          'N processes per node, which has N GPUs.
This is the ' 61 | 'fastest way to use PyTorch for either single node or ' 62 | 'multi node data parallel training') 63 | 64 | # Misc 65 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 66 | help='number of data loading workers (default: 4)') 67 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 68 | help='manual epoch number (useful on restarts)') 69 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 70 | help='momentum') 71 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 72 | help='use pre-trained model') 73 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 74 | metavar='W', help='weight decay (default: 1e-4)', 75 | dest='weight_decay') 76 | parser.add_argument('-p', '--print-freq', default=0, type=int, 77 | metavar='N', help='print frequency (default: 0)') 78 | parser.add_argument('--seed', default=None, type=int, 79 | help='seed for initializing training.') 80 | 81 | best_acc1 = 0 82 | 83 | 84 | def main(): 85 | args = parser.parse_args() 86 | if args.seed is not None: 87 | random.seed(args.seed) 88 | torch.manual_seed(args.seed) 89 | cudnn.deterministic = True 90 | warnings.warn('You have chosen to seed training. ' 91 | 'This will turn on the CUDNN deterministic setting, ' 92 | 'which can slow down your training considerably! ' 93 | 'You may see unexpected behavior when restarting ' 94 | 'from checkpoints.') 95 | if args.gpu is not None: 96 | warnings.warn('You have chosen a specific GPU. This will completely ' 97 | 'disable data parallelism.') 98 | if args.dist_url == "env://" and args.world_size == -1: 99 | args.world_size = int(os.environ["WORLD_SIZE"]) 100 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 101 | args.ngpus_per_node = ngpus_per_node = torch.cuda.device_count() 102 | if args.multiprocessing_distributed: 103 | # Since we have ngpus_per_node processes per node, the total world_size 104 | # needs to be adjusted accordingly 105 | args.world_size = ngpus_per_node * args.world_size 106 | # Use torch.multiprocessing.spawn to launch distributed processes: the 107 | # main_worker process function 108 | mp.spawn(main_worker, nprocs=ngpus_per_node, 109 | args=(ngpus_per_node, args)) 110 | else: 111 | # Simply call main_worker function 112 | main_worker(args.gpu, ngpus_per_node, args) 113 | 114 | 115 | def main_worker(gpu, ngpus_per_node, args): 116 | global best_acc1 117 | args.gpu = gpu 118 | 119 | ############################################ 120 | # GPU and FP16 121 | ############################################ 122 | if args.gpu is not None: 123 | print("Use GPU: {} for training".format(args.gpu)) 124 | if args.distributed: 125 | if args.dist_url == "env://" and args.rank == -1: 126 | args.rank = int(os.environ["RANK"]) 127 | if args.multiprocessing_distributed: 128 | # For multiprocessing distributed training, rank needs to be the 129 | # global rank among all the processes 130 | args.rank = args.rank * ngpus_per_node + gpu 131 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 132 | world_size=args.world_size, rank=args.rank) 133 | if args.fp16: 134 | from torch.cuda.amp import autocast, GradScaler 135 | args.scaler = GradScaler() if args.fp16 else None 136 | args.autocast = autocast 137 | else: 138 | args.autocast = datafree.utils.dummy_ctx 139 | 140 | ############################################ 141 | # Logger 142 | ############################################ 143 | log_name = 'R%d-%s-%s' % 
(args.rank, args.dataset, 144 | args.model) if args.multiprocessing_distributed else '%s-%s' % (args.dataset, args.model) 145 | args.logger = datafree.utils.logger.get_logger( 146 | log_name, output='checkpoints/scratch/log-%s-%s.txt' % (args.dataset, args.model)) 147 | if args.rank <= 0: 148 | # print args 149 | for k, v in datafree.utils.flatten_dict(vars(args)).items(): 150 | args.logger.info("%s: %s" % (k, v)) 151 | 152 | ############################################ 153 | # Setup dataset 154 | ############################################ 155 | num_classes, train_dataset, val_dataset = registry.get_dataset( 156 | name=args.dataset, data_root=args.data_root) 157 | cudnn.benchmark = True 158 | if args.distributed: 159 | train_sampler = torch.utils.data.distributed.DistributedSampler( 160 | train_dataset) 161 | else: 162 | train_sampler = None 163 | train_loader = torch.utils.data.DataLoader( 164 | train_dataset, batch_size=args.batch_size, shuffle=( 165 | train_sampler is None), 166 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 167 | val_loader = torch.utils.data.DataLoader( 168 | val_dataset, 169 | batch_size=args.batch_size, shuffle=False, 170 | num_workers=args.workers, pin_memory=True) 171 | evaluator = datafree.evaluators.classification_evaluator(val_loader) 172 | args.current_epoch = 0 173 | 174 | ############################################ 175 | # Setup models and datasets 176 | ############################################ 177 | model = registry.get_model( 178 | args.model, num_classes=num_classes, pretrained=args.pretrained) 179 | if not torch.cuda.is_available(): 180 | print('using CPU, this will be slow') 181 | elif args.distributed: 182 | # For multiprocessing distributed, DistributedDataParallel constructor 183 | # should always set the single device scope, otherwise, 184 | # DistributedDataParallel will use all available devices. 
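        # With a specific GPU id, this process is pinned to that device and the batch
        # size / workers are sharded per process; without one, DDP spans all visible GPUs.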
185 | if args.gpu is not None: 186 | torch.cuda.set_device(args.gpu) 187 | model.cuda(args.gpu) 188 | # When using a single GPU per process and per 189 | # DistributedDataParallel, we need to divide the batch size 190 | # ourselves based on the total number of GPUs we have 191 | args.batch_size = int(args.batch_size / ngpus_per_node) 192 | args.workers = int( 193 | (args.workers + ngpus_per_node - 1) / ngpus_per_node) 194 | model = torch.nn.parallel.DistributedDataParallel( 195 | model, device_ids=[args.gpu]) 196 | else: 197 | model.cuda() 198 | # DistributedDataParallel will divide and allocate batch_size to all 199 | # available GPUs if device_ids are not set 200 | model = torch.nn.parallel.DistributedDataParallel(model) 201 | elif args.gpu is not None: 202 | torch.cuda.set_device(args.gpu) 203 | model = model.cuda(args.gpu) 204 | else: 205 | # DataParallel will divide and allocate batch_size to all available GPUs 206 | model = torch.nn.DataParallel(model).cuda() 207 | 208 | ############################################ 209 | # Setup optimizer 210 | ############################################ 211 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 212 | optimizer = torch.optim.SGD(model.parameters( 213 | ), args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 214 | milestones = [int(ms) for ms in args.lr_decay_milestones.split(',')] 215 | # warm_up_with_multistep_lr 216 | def warm_up_with_multistep_lr(epoch): return ( 217 | epoch+1) / args.warm_up_epoches if epoch < args.warm_up_epoches else 0.1**len([m for m in milestones if m <= epoch]) 218 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 219 | optimizer, milestones=milestones, gamma=0.1) 220 | 221 | scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, lr_lambda=warm_up_with_multistep_lr) 222 | 223 | ############################################ 224 | # Resume 225 | ############################################ 226 | if args.resume: 227 | if os.path.isfile(args.resume): 228 | print("=> loading checkpoint '{}'".format(args.resume)) 229 | if args.gpu is None: 230 | checkpoint = torch.load(args.resume) 231 | else: 232 | # Map model to be loaded to specified single gpu. 233 | loc = 'cuda:{}'.format(args.gpu) 234 | checkpoint = torch.load(args.resume, map_location=loc) 235 | 236 | if isinstance(model, nn.Module): 237 | model.load_state_dict(checkpoint['state_dict']) 238 | else: 239 | model.module.load_state_dict(checkpoint['state_dict']) 240 | 241 | try: 242 | if 'best_acc1' in checkpoint: 243 | best_acc1 = checkpoint['best_acc1'] 244 | args.start_epoch = checkpoint['epoch'] 245 | optimizer.load_state_dict(checkpoint['optimizer']) 246 | scheduler.load_state_dict(checkpoint['scheduler']) 247 | except: 248 | print("Fails to load additional information") 249 | print("[!] loaded checkpoint '{}' (epoch {} acc {})" 250 | .format(args.resume, checkpoint['epoch'], best_acc1)) 251 | else: 252 | print("[!] 
no checkpoint found at '{}'".format(args.resume)) 253 | 254 | ############################################ 255 | # Evaluate 256 | ############################################ 257 | if args.evaluate_only: 258 | model.eval() 259 | eval_results = evaluator(model, device=args.gpu) 260 | (acc1, acc5), val_loss = eval_results['Acc'], eval_results['Loss'] 261 | print('[Eval] Acc@1={acc1:.4f} Acc@5={acc5:.4f} Loss={loss:.4f}'.format( 262 | acc1=acc1, acc5=acc5, loss=val_loss)) 263 | return 264 | 265 | ############################################ 266 | # Train Loop 267 | ############################################ 268 | for epoch in range(args.start_epoch, args.epochs): 269 | if args.distributed: 270 | train_sampler.set_epoch(epoch) 271 | args.current_epoch = epoch 272 | train(train_loader, model, criterion, optimizer, args) 273 | model.eval() 274 | eval_results = evaluator(model, device=args.gpu) 275 | (acc1, acc5), val_loss = eval_results['Acc'], eval_results['Loss'] 276 | args.logger.info('[Eval] Epoch={current_epoch} Acc@1={acc1:.4f} Acc@5={acc5:.4f} Loss={loss:.4f} Lr={lr:.4f}' 277 | .format(current_epoch=args.current_epoch, acc1=acc1, acc5=acc5, loss=val_loss, lr=optimizer.param_groups[0]['lr'])) 278 | scheduler.step() 279 | is_best = acc1 > best_acc1 280 | best_acc1 = max(acc1, best_acc1) 281 | _best_ckpt = 'checkpoints/scratch/%s_%s.pth' % ( 282 | args.dataset, args.model) 283 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 284 | and args.rank % ngpus_per_node == 0): 285 | save_checkpoint({ 286 | 'epoch': epoch + 1, 287 | 'arch': args.model, 288 | 'state_dict': model.state_dict(), 289 | 'best_acc1': float(best_acc1), 290 | 'optimizer': optimizer.state_dict(), 291 | 'scheduler': scheduler.state_dict() 292 | }, is_best, _best_ckpt) 293 | if args.rank <= 0: 294 | args.logger.info("Best: %.4f" % best_acc1) 295 | 296 | 297 | def train(train_loader, model, criterion, optimizer, args): 298 | global best_acc1 299 | loss_metric = datafree.metrics.RunningLoss( 300 | nn.CrossEntropyLoss(reduction='sum')) 301 | acc_metric = datafree.metrics.TopkAccuracy(topk=(1, 5)) 302 | model.train() 303 | for i, (images, target) in enumerate(train_loader): 304 | if args.gpu is not None: 305 | images = images.cuda(args.gpu, non_blocking=True) 306 | if torch.cuda.is_available(): 307 | target = target.cuda(args.gpu, non_blocking=True) 308 | with args.autocast(enabled=args.fp16): 309 | output = model(images) 310 | loss = criterion(output, target) 311 | # measure accuracy and record loss 312 | acc_metric.update(output, target) 313 | loss_metric.update(output, target) 314 | optimizer.zero_grad() 315 | if args.fp16: 316 | scaler = args.scaler 317 | scaler.scale(loss).backward() 318 | scaler.step(optimizer) 319 | scaler.update() 320 | else: 321 | loss.backward() 322 | optimizer.step() 323 | if args.print_freq > 0 and i % args.print_freq == 0: 324 | (train_acc1, train_acc5), train_loss = acc_metric.get_results( 325 | ), loss_metric.get_results() 326 | args.logger.info('[Train] Epoch={current_epoch} Iter={i}/{total_iters}, train_acc@1={train_acc1:.4f}, train_acc@5={train_acc5:.4f}, train_Loss={train_loss:.4f}, Lr={lr:.4f}' 327 | .format(current_epoch=args.current_epoch, i=i, total_iters=len(train_loader), train_acc1=train_acc1, train_acc5=train_acc5, train_loss=train_loss, lr=optimizer.param_groups[0]['lr'])) 328 | loss_metric.reset(), acc_metric.reset() 329 | 330 | 331 | def save_checkpoint(state, is_best, filename='checkpoint.pth'): 332 | if is_best: 333 | torch.save(state, filename) 334 | 
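# Example invocation (hypothetical flag values; the model names accepted here are
# whatever registry.get_model knows about):
#   python train_scratch.py --model wrn40_2 --dataset cifar10 --batch-size 128 \
#       --lr 0.1 --epochs 200 --lr_decay_milestones 120,150,180 --gpu 0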
335 | 336 | if __name__ == '__main__': 337 | main() 338 | --------------------------------------------------------------------------------
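As a quick sanity check of the loss code above, the following sketch exercises `TriLoss` from `losses.py` on random tensors. The batch size, feature-map shapes, and number of classes are illustrative assumptions only; in the actual pipeline the features, logits, and labels come from the student network and the synthesized batch.

```
import torch
from losses import TriLoss

criterion = TriLoss(p=2, margin=0.2, balanced_sampling=False)

# A list of student feature maps (B x C x H x W) plus the logits and (pseudo-)labels
# that the distance-weighted sampler uses to pick (anchor, positive, negative) triplets.
stu_features = [torch.randn(64, 128, 8, 8), torch.randn(64, 256, 4, 4)]
logits = torch.randn(64, 10)
labels = torch.randint(0, 10, (64,))

loss = criterion(stu_features, logits, labels)                      # standard triplet assignment
adv_loss = criterion(stu_features, logits, labels, negative=True)   # positive/negative roles swapped
print(loss.item(), adv_loss.item())
```

As `TriLoss.forward` shows, `negative=True` simply swaps which sampled indices are treated as positives and negatives, so the same module serves both optimization directions.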