├── README.md
├── best_model
│   ├── cifar100_wrn40_2.pth
│   └── cifar10_wrn40_2.pth
├── datafree
│   ├── __init__.py
│   ├── criterions.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── nyu.py
│   │   ├── tiny_imagenet.py
│   │   └── utils.py
│   ├── evaluators.py
│   ├── hooks.py
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── accuracy.py
│   │   ├── confusion_matrix.py
│   │   ├── running_average.py
│   │   └── stream_metrics.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── classifiers
│   │   │   ├── __init__.py
│   │   │   ├── lenet.py
│   │   │   ├── mobilenetv2.py
│   │   │   ├── resnet.py
│   │   │   ├── resnet_in.py
│   │   │   ├── resnet_tiny.py
│   │   │   ├── shufflenetv2.py
│   │   │   ├── vgg.py
│   │   │   └── wresnet.py
│   │   ├── deeplab
│   │   │   ├── __init__.py
│   │   │   ├── _deeplab.py
│   │   │   ├── backbone
│   │   │   │   ├── __init__.py
│   │   │   │   ├── mobilenetv2.py
│   │   │   │   └── resnet.py
│   │   │   ├── modeling.py
│   │   │   └── utils.py
│   │   ├── generator.py
│   │   └── stylegan_generator.py
│   ├── rep_transfer.py
│   ├── synthesis
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── contrastive.py
│   │   └── triplet.py
│   └── utils
│       ├── __init__.py
│       ├── _utils.py
│       ├── fmix.py
│       ├── inception.py
│       ├── logger.py
│       ├── pair.py
│       ├── sync_transforms
│       │   ├── __init__.py
│       │   ├── functional.py
│       │   └── transforms.py
│       └── vis.py
├── losses.py
├── main.py
├── misc
│   └── framework.png
├── registry.py
└── train_scratch.py
/README.md:
--------------------------------------------------------------------------------
1 | ## News
2 | * `2024/12/20` We release the code for the *data-free knowledge distillation* tasks.
3 |
4 | # RGAL
5 |
6 | This is a PyTorch implementation of the following paper:
7 |
8 | **Relation-Guided Adversarial Learning for Data-Free Knowledge Transfer**, IJCV 2024.
9 |
10 | Yingping Liang and Ying Fu
11 |
12 | [Paper](https://link.springer.com/article/10.1007/s11263-024-02303-4)
13 |
14 |
15 |
16 | **Abstract**: *Data-free knowledge distillation transfers knowledge by recovering training data from a pre-trained model. Despite the recent success of seeking global data diversity, the diversity within each class and the similarity among different classes are largely overlooked, resulting in data homogeneity and limited performance. In this paper, we introduce a novel Relation-Guided Adversarial Learning method with triplet losses, which solves the homogeneity problem from two aspects. To be specific, our method aims to promote both intra-class diversity and inter-class confusion of the generated samples. To this end, we design two phases, an image synthesis phase and a student training phase. In the image synthesis phase, we construct an optimization process to push away samples with the same labels and pull close samples with different labels, leading to intra-class diversity and inter-class confusion, respectively. Then, in the student training phase, we perform an opposite optimization, which adversarially attempts to reduce the distance of samples of the same classes and enlarge the distance of samples of different classes. To mitigate the conflict of seeking high global diversity and keeping inter-class confusion, we propose a focal weighted sampling strategy by selecting the negative in the triplets unevenly within a finite range of distance. RGAL shows significant improvement over previous state-of-the-art methods in accuracy and data efficiency. Besides, RGAL can be inserted into state-of-the-art methods on various data-free knowledge transfer applications. Experiments on various benchmarks demonstrate the effectiveness and generalizability of our proposed method on various tasks, especially data-free knowledge distillation, data-free quantization, and non-exemplar incremental learning.*
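As a rough illustration of the push/pull objective described in the abstract (a minimal sketch, not the loss actually used in this repository), a batch-wise relation loss for the synthesis phase could look like:

```python
import torch
import torch.nn.functional as F

def relation_guided_sketch(features, labels, margin=1.0):
    """Toy synthesis-phase objective: push same-class samples apart,
    pull different-class samples together (the signs would be flipped
    for the student-training phase)."""
    features = F.normalize(features, dim=1)
    dist = torch.cdist(features, features, p=2)        # (B, B) pairwise distances
    same = labels.unsqueeze(0) == labels.unsqueeze(1)   # same-label pairs
    same.fill_diagonal_(False)                          # drop self-pairs
    diff = labels.unsqueeze(0) != labels.unsqueeze(1)   # different-label pairs
    push = F.relu(margin - dist[same]).mean()           # intra-class diversity
    pull = dist[diff].mean()                            # inter-class confusion
    return push + pull

# Toy usage: random features standing in for generated-sample embeddings.
feats = torch.randn(8, 64)
labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])
print(relation_guided_sketch(feats, labels))
```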
17 |
18 |
19 |
20 |
21 | https://github.com/user-attachments/assets/eb78306f-1fbe-465a-9996-7315716f0b55
22 |
23 |
24 |
25 |
26 |
27 | ## Installation
28 |
29 | ```
30 | conda create -n rgal python=3.9
31 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
32 | pip install scipy tqdm pillow kornia
33 | ```
34 |
35 | ## Run
36 |
37 | The CIFAR-10/-100 datasets will be downloaded automatically on the first run.
38 |
39 | We provide an example run script:
40 | ```
41 | python main.py \
42 | --epochs 200 \
43 | --dataset cifar10 \
44 | --batch_size 128 \
45 | --synthesis_batch_size 256 \
46 | --teacher wrn40_2 \
47 | --student wrn16_1 \
48 | --lr 0.1 \
49 | --kd_steps 400 \
50 | --ep_steps 400 \
51 | --g_steps 400 \
52 | --lr_g 1e-3 \
53 | --adv 1.0 \
54 | --bn 1.0 \
55 | --oh 1.0 \
56 | --act 0.001 \
57 | --gpu 0 \
58 | --seed 0 \
59 | --T 20 \
60 | --save_dir run/scratch1 \
61 | --log_tag scratch1 \
62 | --cd_loss 0.1 \
63 | --gram_loss 0 \
64 | --teacher_weights best_model/cifar10_wrn40_2.pth \
65 | --custom_steps 1.0 \
66 | --print_freq 50 \
67 | --triplet_target student \
68 | --pair_sample \
69 | --striplet_feature global \
70 | --start_layer 2 \
71 | --triplet 0.1 \
72 | --striplet 0.1 \
73 | --balanced_sampling \
74 | --balance 0.1
75 | ```
76 |
77 | where "--triplet" and "--striplet" indicates the loss weights of our proposed in the data generation stage and distillation stage, separately.
78 |
79 | To run our method with different teacher and student models, modify "--teacher" and "--student".
80 |
81 | "--balanced_sampling" indicates the paired sampling strategy as in our paper.
82 |
83 | Pretrained teacher checkpoints for the examples are available at [best_model](https://github.com/Sharpiless/RGAL/tree/main/best_model).
84 |
85 | 
86 |
87 |
88 | ## Visualization
89 |
90 | Please refer to [ZSKT](https://github.com/polo5/ZeroShotKnowledgeTransfer).
91 |
92 | ## License and Citation
93 | This repository can only be used for personal/research/non-commercial purposes.
94 | Please cite the following paper if this model helps your research:
95 |
96 | ```
97 | @article{liang2024relation,
98 | title={Relation-Guided Adversarial Learning for Data-Free Knowledge Transfer},
99 | author={Liang, Yingping and Fu, Ying},
100 | journal={International Journal of Computer Vision},
101 | pages={1--18},
102 | year={2024},
103 | publisher={Springer}
104 | }
105 | ```
106 |
107 | ## Acknowledgments
108 | * The code for training and inference is heavily borrowed from [CMI](https://github.com/zju-vipa/CMI); we thank the authors for their great effort.
109 |
--------------------------------------------------------------------------------
/best_model/cifar100_wrn40_2.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/RGAL/e94cf7b19bff1c1a517592a9d9bcaf521c768e43/best_model/cifar100_wrn40_2.pth
--------------------------------------------------------------------------------
/best_model/cifar10_wrn40_2.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/RGAL/e94cf7b19bff1c1a517592a9d9bcaf521c768e43/best_model/cifar10_wrn40_2.pth
--------------------------------------------------------------------------------
/datafree/__init__.py:
--------------------------------------------------------------------------------
1 | from . import criterions, utils, metrics, hooks, rep_transfer, evaluators, synthesis, datasets
2 |
--------------------------------------------------------------------------------
/datafree/criterions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 |
5 | def kldiv(logits, targets, T=1.0, reduction='batchmean'):
6 | q = F.log_softmax(logits/T, dim=1)
7 | p = F.softmax(targets/T, dim=1)
8 | return F.kl_div(q, p, reduction=reduction) * (T*T)
9 |
10 |
11 | class KLDiv(nn.Module):
12 | def __init__(self, T=1.0, reduction='batchmean'):
13 | super().__init__()
14 | self.T = T
15 | self.reduction = reduction
16 |
17 | def forward(self, logits, targets):
18 | return kldiv(logits, targets, T=self.T, reduction=self.reduction)
19 |
20 | def jsdiv( logits, targets, T=1.0, reduction='batchmean' ):
21 | P = F.softmax(logits / T, dim=1)
22 | Q = F.softmax(targets / T, dim=1)
23 | M = 0.5 * (P + Q)
24 | P = torch.clamp(P, 0.01, 0.99)
25 | Q = torch.clamp(Q, 0.01, 0.99)
26 | M = torch.clamp(M, 0.01, 0.99)
27 | return 0.5 * F.kl_div(torch.log(P), M, reduction=reduction) + 0.5 * F.kl_div(torch.log(Q), M, reduction=reduction)
28 |
29 | def cross_entropy(logits, targets, reduction='mean'):
30 | return F.cross_entropy(logits, targets, reduction=reduction)
31 |
32 | def class_balance_loss(logits):
33 | prob = torch.softmax(logits, dim=1)
34 | avg_prob = prob.mean(dim=0)
35 | return (avg_prob * torch.log(avg_prob)).sum()
36 |
37 | def onehot_loss(logits, targets=None):
38 | if targets is None:
39 | targets = logits.max(1)[1]
40 | return cross_entropy(logits, targets)
41 |
42 | def get_image_prior_losses(inputs_jit):
43 | # COMPUTE total variation regularization loss
44 | diff1 = inputs_jit[:, :, :, :-1] - inputs_jit[:, :, :, 1:]
45 | diff2 = inputs_jit[:, :, :-1, :] - inputs_jit[:, :, 1:, :]
46 | diff3 = inputs_jit[:, :, 1:, :-1] - inputs_jit[:, :, :-1, 1:]
47 | diff4 = inputs_jit[:, :, :-1, :-1] - inputs_jit[:, :, 1:, 1:]
48 | #loss_var_l2 = torch.norm(diff1) + torch.norm(diff2) + torch.norm(diff3) + torch.norm(diff4)
49 | loss_var_l1 = (diff1.abs() / 255.0).mean() + (diff2.abs() / 255.0).mean() + (
50 | diff3.abs() / 255.0).mean() + (diff4.abs() / 255.0).mean()
51 | loss_var_l1 = loss_var_l1 * 255.0
52 | return loss_var_l1
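
A minimal usage sketch for the distillation criterion above (random logits stand in for teacher/student outputs):

```python
import torch
from datafree.criterions import KLDiv, kldiv

student_logits = torch.randn(4, 100)   # e.g. CIFAR-100 outputs
teacher_logits = torch.randn(4, 100)

criterion = KLDiv(T=20.0)              # temperature, as in "--T 20" in the README example
loss = criterion(student_logits, teacher_logits)

# The functional form is equivalent.
assert torch.allclose(loss, kldiv(student_logits, teacher_logits, T=20.0))
print(loss)
```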
--------------------------------------------------------------------------------
/datafree/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .nyu import NYUv2
2 | from .tiny_imagenet import TinyImageNet
--------------------------------------------------------------------------------
/datafree/datasets/nyu.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/VainF/nyuv2-python-toolkit
2 | import os
3 | import torch
4 | import torch.utils.data as data
5 | from PIL import Image
6 | from scipy.io import loadmat
7 | import numpy as np
8 | import glob
9 | from torchvision import transforms
10 | from torchvision.datasets import VisionDataset
11 | import random
12 |
13 | from .utils import colormap
14 |
15 | class NYUv2(VisionDataset):
16 | """NYUv2 dataset
17 | See https://github.com/VainF/nyuv2-python-toolkit for more details.
18 |
19 | Args:
20 | root (string): Root directory path.
21 | split (string, optional): 'train' for training set, and 'test' for test set. Default: 'train'.
22 | target_type (string, optional): Type of target to use, ``semantic``, ``depth`` or ``normal``.
23 | num_classes (int, optional): The number of classes, must be 40 or 13. Default:13.
24 | transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version.
25 | target_transform (callable, optional): A function/transform that takes in the target and transforms it.
26 | transforms (callable, optional): A function/transform that takes input sample and its target as entry and returns a transformed version.
27 | """
28 | cmap = colormap()
29 | def __init__(self,
30 | root,
31 | split='train',
32 | target_type='semantic',
33 | num_classes=13,
34 | transforms=None,
35 | transform=None,
36 | target_transform=None):
37 | super( NYUv2, self ).__init__(root, transforms=transforms, transform=transform, target_transform=target_transform)
38 | assert(split in ('train', 'test'))
39 |
40 | self.root = root
41 | self.split = split
42 | self.target_type = target_type
43 | self.num_classes = num_classes
44 |
45 | split_mat = loadmat(os.path.join(self.root, 'splits.mat'))
46 | idxs = split_mat[self.split+'Ndxs'].reshape(-1) - 1
47 |
48 | img_names = os.listdir( os.path.join(self.root, 'image', self.split) )
49 | img_names.sort()
50 | images_dir = os.path.join(self.root, 'image', self.split)
51 | self.images = [os.path.join(images_dir, name) for name in img_names]
52 |
53 | self._is_depth = False
54 | if self.target_type=='semantic':
55 | semantic_dir = os.path.join(self.root, 'seg%d'%self.num_classes, self.split)
56 | self.labels = [os.path.join(semantic_dir, name) for name in img_names]
57 | self.targets = self.labels
58 |
59 | if self.target_type=='depth':
60 | depth_dir = os.path.join(self.root, 'depth', self.split)
61 | self.depths = [os.path.join(depth_dir, name) for name in img_names]
62 | self.targets = self.depths
63 | self._is_depth = True
64 |
65 | if self.target_type=='normal':
66 | normal_dir = os.path.join(self.root, 'normal', self.split)
67 | self.normals = [os.path.join(normal_dir, name) for name in img_names]
68 | self.targets = self.normals
69 |
70 | def __getitem__(self, idx):
71 | image = Image.open(self.images[idx])
72 | target = Image.open(self.targets[idx])
73 | if self.transforms is not None:
74 | image, target = self.transforms( image, target )
75 | return image, target
76 |
77 | def __len__(self):
78 | return len(self.images)
79 |
80 | @classmethod
81 | def decode_fn(cls, mask: np.ndarray):
82 | """decode semantic mask to RGB image"""
83 | mask = mask.astype('uint8') + 1 # 255 => 0
84 | return cls.cmap[mask]
85 |
--------------------------------------------------------------------------------
/datafree/datasets/tiny_imagenet.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import glob
3 | import numpy as np
4 | import os
5 | from torchvision.datasets.folder import pil_loader
6 | from torchvision.datasets.utils import download_and_extract_archive
7 |
8 | class TinyImageNet(Dataset):
9 | def __init__(self, root, split, transform, download=True):
10 |
11 | self.url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
12 | self.root = root
13 | if download:
14 | if os.path.exists(f'{self.root}/tiny-imagenet-200/'):
15 | print(f'{self.root}/tiny-imagenet-200/, File already downloaded')
16 | else:
17 | print(f'{self.root}/tiny-imagenet-200/, File isn\'t downloaded')
18 | download_and_extract_archive(self.url, root, filename="tiny-imagenet-200.zip")
19 |
20 | self.root = os.path.join(self.root, "tiny-imagenet-200")
21 | self.train = split == "train"
22 | self.transform = transform
23 | self.ids_string = np.sort(np.loadtxt(f"{self.root}/wnids.txt", "str"))
24 | self.ids = {class_string: i for i, class_string in enumerate(self.ids_string)}
25 | if self.train:
26 | self.paths = glob.glob(f"{self.root}/train/*/images/*")
27 | self.targets = [self.ids[path.split("/")[-3]] for path in self.paths]
28 | else:
29 | self.paths = glob.glob(f"{self.root}/val/*/images/*")
30 | self.targets = [self.ids[path.split("/")[-3]] for path in self.paths]
31 |
32 | def __len__(self):
33 | return len(self.paths)
34 |
35 | def __getitem__(self, idx):
36 | image = pil_loader(self.paths[idx])
37 |
38 | if self.transform is not None:
39 | image = self.transform(image)
40 |
41 | return image, self.targets[idx]
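
A hedged construction example for this dataset (path and transform are illustrative; the first call downloads and extracts the archive):

```python
from torchvision import transforms
from datafree.datasets import TinyImageNet

transform = transforms.ToTensor()   # placeholder; training would add augmentation/normalization
train_set = TinyImageNet(root='./data', split='train', transform=transform, download=True)

image, target = train_set[0]
print(len(train_set), image.shape, target)   # 100000 training images, 3x64x64, class index
```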
--------------------------------------------------------------------------------
/datafree/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | def colormap(N=256, normalized=False):
5 | def bitget(byteval, idx):
6 | return ((byteval & (1 << idx)) != 0)
7 |
8 | dtype = 'float32' if normalized else 'uint8'
9 | cmap = np.zeros((N, 3), dtype=dtype)
10 | for i in range(N):
11 | r = g = b = 0
12 | c = i
13 | for j in range(8):
14 | r = r | (bitget(c, 0) << 7-j)
15 | g = g | (bitget(c, 1) << 7-j)
16 | b = b | (bitget(c, 2) << 7-j)
17 | c = c >> 3
18 |
19 | cmap[i] = np.array([r, g, b])
20 |
21 | cmap = cmap/255 if normalized else cmap
22 | return cmap
--------------------------------------------------------------------------------
/datafree/evaluators.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import torch.nn.functional as F
3 | import torch
4 | from . import metrics
5 |
6 | class Evaluator(object):
7 | def __init__(self, metric, dataloader):
8 | self.dataloader = dataloader
9 | self.metric = metric
10 |
11 | def eval(self, model, device=None, progress=False):
12 | self.metric.reset()
13 | with torch.no_grad():
14 | for i, (inputs, targets) in enumerate( tqdm(self.dataloader, disable=not progress) ):
15 | inputs, targets = inputs.to(device), targets.to(device)
16 | outputs = model( inputs )
17 | self.metric.update(outputs, targets)
18 | return self.metric.get_results()
19 |
20 | def __call__(self, *args, **kwargs):
21 | return self.eval(*args, **kwargs)
22 |
23 | class AdvEvaluator(object):
24 | def __init__(self, metric, dataloader, adversary):
25 | self.dataloader = dataloader
26 | self.metric = metric
27 | self.adversary = adversary
28 |
29 | def eval(self, model, device=None, progress=False):
30 | self.metric.reset()
31 | for i, (inputs, targets) in enumerate( tqdm(self.dataloader, disable=not progress) ):
32 | inputs, targets = inputs.to(device), targets.to(device)
33 | inputs = self.adversary.perturb(inputs, targets)
34 | with torch.no_grad():
35 | outputs = model( inputs )
36 | self.metric.update(outputs, targets)
37 | return self.metric.get_results()
38 |
39 | def __call__(self, *args, **kwargs):
40 | return self.eval(*args, **kwargs)
41 |
42 | def classification_evaluator(dataloader):
43 | metric = metrics.MetricCompose({
44 | 'Acc': metrics.TopkAccuracy(),
45 | 'Loss': metrics.RunningLoss(torch.nn.CrossEntropyLoss(reduction='sum'))
46 | })
47 | return Evaluator( metric, dataloader=dataloader)
48 |
49 | def advarsarial_classification_evaluator(dataloader, adversary):
50 | metric = metrics.MetricCompose({
51 | 'Acc': metrics.TopkAccuracy(),
52 | 'Loss': metrics.RunningLoss(torch.nn.CrossEntropyLoss(reduction='sum'))
53 | })
54 | return AdvEvaluator( metric, dataloader=dataloader, adversary=adversary)
55 |
56 |
57 | def segmentation_evaluator(dataloader, num_classes, ignore_idx=255):
58 | cm = metrics.ConfusionMatrix(num_classes, ignore_idx=ignore_idx)
59 | metric = metrics.MetricCompose({
60 | 'mIoU': metrics.mIoU(cm),
61 | 'Acc': metrics.Accuracy(),
62 | 'Loss': metrics.RunningLoss(torch.nn.CrossEntropyLoss(reduction='sum'))
63 | })
64 | return Evaluator( metric, dataloader=dataloader)
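
A sketch of how these evaluators might be driven with a CIFAR-10 validation loader (the model choice here is illustrative):

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from datafree.evaluators import classification_evaluator
from datafree.models.classifiers.resnet import resnet18

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
val_loader = DataLoader(
    datasets.CIFAR10('./data', train=False, download=True, transform=transforms.ToTensor()),
    batch_size=128, shuffle=False)

model = resnet18(num_classes=10).to(device).eval()   # eval() is not called inside Evaluator
evaluator = classification_evaluator(val_loader)
results = evaluator(model, device=device, progress=True)
top1, top5 = results['Acc']
print('Top-1 %.2f%%, Top-5 %.2f%%, Loss %.4f' % (top1, top5, results['Loss']))
```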
--------------------------------------------------------------------------------
/datafree/hooks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | def register_hooks(modules):
6 | hooks = []
7 | for m in modules:
8 | hooks.append( FeatureHook(m) )
9 | return hooks
10 |
11 | class InstanceMeanHook(object):
12 | def __init__(self, module):
13 | self.hook = module.register_forward_hook(self.hook_fn)
14 | self.module = module
15 |
16 | def hook_fn(self, module, input, output):
17 | self.instance_mean = torch.mean(input[0], dim=[2, 3])
18 |
19 | def remove(self):
20 | self.hook.remove()
21 |
22 | def __repr__(self):
23 | return ": %s"%(self.module)
24 |
25 | class FeatureHook(object):
26 | def __init__(self, module):
27 | self.hook = module.register_forward_hook(self.hook_fn)
28 | self.module = module
29 |
30 | def hook_fn(self, module, input, output):
31 | self.output = output
32 | self.input = input[0]
33 |
34 | def remove(self):
35 | self.hook.remove()
36 |
37 | def __repr__(self):
38 | return ": %s"%(self.module)
39 |
40 |
41 | class FeatureMeanHook(object):
42 | def __init__(self, module):
43 | self.hook = module.register_forward_hook(self.hook_fn)
44 | self.module = module
45 |
46 | def hook_fn(self, module, input, output):
47 | self.instance_mean = torch.mean(input[0], dim=[2, 3])
48 |
49 | def remove(self):
50 | self.hook.remove()
51 |
52 | def __repr__(self):
53 | return ": %s"%(self.module)
54 |
55 |
56 | class FeatureMeanVarHook():
57 | def __init__(self, module, on_input=True, dim=[0,2,3]):
58 | self.hook = module.register_forward_hook(self.hook_fn)
59 | self.on_input = on_input
60 | self.module = module
61 | self.dim = dim
62 |
63 | def hook_fn(self, module, input, output):
64 | # To avoid inplace modification
65 | if self.on_input:
66 | feature = input[0].clone()
67 | else:
68 | feature = output.clone()
69 | self.var, self.mean = torch.var_mean( feature, dim=self.dim, unbiased=True )
70 |
71 | def remove(self):
72 | self.hook.remove()
73 | self.output=None
74 |
75 |
76 | class DeepInversionHook():
77 | '''
78 | Implementation of the forward hook to track feature statistics and compute a loss on them.
79 | Will compute mean and variance, and will use l2 as a loss
80 | '''
81 | def __init__(self, module):
82 | self.hook = module.register_forward_hook(self.hook_fn)
83 | self.module = module
84 |
85 | def hook_fn(self, module, input, output):
86 | # hook to compute DeepInversion's feature distribution regularization
87 | nch = input[0].shape[1]
88 | mean = input[0].mean([0, 2, 3])
89 | var = input[0].permute(1, 0, 2, 3).contiguous().view([nch, -1]).var(1, unbiased=False)
90 | # forcing mean and variance to match between the two distributions
91 | # other ways might work better, e.g. KL divergence
92 | r_feature = torch.norm(module.running_var.data - var, 2) + torch.norm(
93 | module.running_mean.data - mean, 2)
94 | self.r_feature = r_feature
95 |
96 | def remove(self):
97 | self.hook.remove()
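
A minimal sketch of the BatchNorm-statistics loss that `DeepInversionHook` enables (the term weighted by `--bn` in the README example); the teacher and inputs here are placeholders:

```python
import torch
import torch.nn as nn
from datafree.hooks import DeepInversionHook
from datafree.models.classifiers.resnet import resnet18

teacher = resnet18(num_classes=10).eval()   # placeholder for a pretrained teacher
hooks = [DeepInversionHook(m) for m in teacher.modules() if isinstance(m, nn.BatchNorm2d)]

x = torch.randn(16, 3, 32, 32, requires_grad=True)   # synthetic images being optimized
_ = teacher(x)
loss_bn = sum(h.r_feature for h in hooks)            # match the teacher's BN running statistics
loss_bn.backward()                                    # gradients flow back to the synthetic images
print(loss_bn.item(), x.grad.abs().mean().item())

for h in hooks:
    h.remove()
```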
--------------------------------------------------------------------------------
/datafree/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .stream_metrics import Metric, MetricCompose
2 | from .accuracy import Accuracy, TopkAccuracy
3 | from .confusion_matrix import ConfusionMatrix, IoU, mIoU
4 | from .running_average import RunningLoss
5 |
6 |
--------------------------------------------------------------------------------
/datafree/metrics/accuracy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from .stream_metrics import Metric
4 | from typing import Callable
5 |
6 | __all__=['Accuracy', 'TopkAccuracy']
7 |
8 | class Accuracy(Metric):
9 | def __init__(self):
10 | self.reset()
11 |
12 | @torch.no_grad()
13 | def update(self, outputs, targets):
14 | outputs = outputs.max(1)[1]
15 | self._correct += ( outputs.view(-1)==targets.view(-1) ).sum()
16 | self._cnt += torch.numel( targets )
17 |
18 | def get_results(self):
19 | return (self._correct / self._cnt * 100.).detach().cpu()
20 |
21 | def reset(self):
22 | self._correct = self._cnt = 0.0
23 |
24 |
25 | class TopkAccuracy(Metric):
26 | def __init__(self, topk=(1, 5)):
27 | self._topk = topk
28 | self.reset()
29 |
30 | @torch.no_grad()
31 | def update(self, outputs, targets):
32 | for k in self._topk:
33 | _, topk_outputs = outputs.topk(k, dim=1, largest=True, sorted=True)
34 | correct = topk_outputs.eq( targets.view(-1, 1).expand_as(topk_outputs) )
35 | self._correct[k] += correct[:, :k].view(-1).float().sum(0).item()
36 | self._cnt += len(targets)
37 |
38 | def get_results(self):
39 | return tuple( self._correct[k] / self._cnt * 100. for k in self._topk )
40 |
41 | def reset(self):
42 | self._correct = {k: 0 for k in self._topk}
43 | self._cnt = 0.0
--------------------------------------------------------------------------------
/datafree/metrics/confusion_matrix.py:
--------------------------------------------------------------------------------
1 | from .stream_metrics import Metric
2 | import torch
3 | from typing import Callable
4 |
5 | class ConfusionMatrix(Metric):
6 | def __init__(self, num_classes, ignore_idx=None):
7 | super(ConfusionMatrix, self).__init__()
8 | self._num_classes = num_classes
9 | self._ignore_idx = ignore_idx
10 | self.reset()
11 |
12 | @torch.no_grad()
13 | def update(self, outputs, targets):
14 | if self.confusion_matrix.device != outputs.device:
15 | self.confusion_matrix = self.confusion_matrix.to(device=outputs.device)
16 | preds = outputs.max(1)[1].flatten()
17 | targets = targets.flatten()
18 | mask = (targets >= 0)
19 | if self._ignore_idx is not None:
20 | mask = mask & (targets != self._ignore_idx)
21 | preds, targets = preds[mask], targets[mask]
22 | hist = torch.bincount( self._num_classes * targets + preds,
23 | minlength=self._num_classes ** 2 ).view(self._num_classes, self._num_classes)
24 | self.confusion_matrix += hist
25 |
26 | def get_results(self):
27 | return self.confusion_matrix.detach().cpu()
28 |
29 | def reset(self):
30 | self._cnt = 0
31 | self.confusion_matrix = torch.zeros(self._num_classes, self._num_classes, dtype=torch.int64, requires_grad=False)
32 |
33 | class IoU(Metric):
34 | def __init__(self, confusion_matrix: ConfusionMatrix):
35 | self._confusion_matrix = confusion_matrix
36 |
37 | def update(self, outputs, targets):
38 | self._confusion_matrix.update(outputs, targets)
39 |
40 | def reset(self):
41 | self._confusion_matrix.reset()
42 |
43 | def get_results(self):
44 | cm = self._confusion_matrix.get_results()
45 | iou = cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) - cm.diag() + 1e-9)
46 | return iou
47 |
48 | class mIoU(IoU):
49 | def get_results(self):
50 | return super(mIoU, self).get_results().mean()
51 |
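A toy sketch of the IoU metrics above on a handful of "pixels" (values chosen only to make the arithmetic easy to follow):

```python
import torch
from datafree.metrics import ConfusionMatrix, mIoU

cm = ConfusionMatrix(num_classes=3, ignore_idx=255)
metric = mIoU(cm)

# Logits for 4 pixels over 3 classes, and their ground-truth labels (255 = ignored).
outputs = torch.tensor([[2.0, 0.1, 0.1],
                        [0.1, 2.0, 0.1],
                        [0.1, 2.0, 0.1],
                        [0.1, 0.1, 2.0]])
targets = torch.tensor([0, 1, 2, 255])

metric.update(outputs, targets)
print(metric.get_results())   # mean IoU over 3 classes: (1 + 0.5 + 0) / 3 = 0.5
```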
--------------------------------------------------------------------------------
/datafree/metrics/running_average.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from .stream_metrics import Metric
4 |
5 | __all__=['RunningLoss']
6 |
7 | class RunningLoss(Metric):
8 | def __init__(self, loss_fn, is_batch_average=False):
9 | self.reset()
10 | self.loss_fn = loss_fn
11 | self.is_batch_average = is_batch_average
12 |
13 | @torch.no_grad()
14 | def update(self, outputs, targets):
15 | self._accum_loss += self.loss_fn(outputs, targets)
16 | if self.is_batch_average:
17 | self._cnt += 1
18 | else:
19 | self._cnt += len(outputs)
20 |
21 | def get_results(self):
22 | return (self._accum_loss / self._cnt).detach().cpu()
23 |
24 | def reset(self):
25 | self._accum_loss = self._cnt = 0.0
26 |
--------------------------------------------------------------------------------
/datafree/metrics/stream_metrics.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | from abc import ABC, abstractmethod
4 | from typing import Callable, Union, Any, Mapping, Sequence
5 | import numbers
6 | import numpy as np
7 |
8 | class Metric(ABC):
9 | @abstractmethod
10 | def update(self, pred, target):
11 | """ Overridden by subclasses """
12 | raise NotImplementedError()
13 |
14 | @abstractmethod
15 | def get_results(self):
16 | """ Overridden by subclasses """
17 | raise NotImplementedError()
18 |
19 | @abstractmethod
20 | def reset(self):
21 | """ Overridden by subclasses """
22 | raise NotImplementedError()
23 |
24 |
25 | class MetricCompose(dict):
26 | def __init__(self, metric_dict: Mapping):
27 | self._metric_dict = metric_dict
28 |
29 | @property
30 | def metrics(self):
31 | return self._metric_dict
32 |
33 | @torch.no_grad()
34 | def update(self, outputs, targets):
35 | for key, metric in self._metric_dict.items():
36 | if isinstance(metric, Metric):
37 | metric.update(outputs, targets)
38 |
39 | def get_results(self):
40 | results = {}
41 | for key, metric in self._metric_dict.items():
42 | if isinstance(metric, Metric):
43 | results[key] = metric.get_results()
44 | return results
45 |
46 | def reset(self):
47 | for key, metric in self._metric_dict.items():
48 | if isinstance(metric, Metric):
49 | metric.reset()
50 |
51 | def __getitem__(self, name):
52 | return self._metric_dict[name]
53 |
54 |
55 |
--------------------------------------------------------------------------------
/datafree/models/__init__.py:
--------------------------------------------------------------------------------
1 | from . import classifiers
2 | from . import generator
3 | from . import deeplab
--------------------------------------------------------------------------------
/datafree/models/classifiers/__init__.py:
--------------------------------------------------------------------------------
1 | from . import lenet, wresnet, vgg, resnet, mobilenetv2, shufflenetv2, resnet_tiny, resnet_in
--------------------------------------------------------------------------------
/datafree/models/classifiers/lenet.py:
--------------------------------------------------------------------------------
1 | # https://github.com/huawei-noah/Data-Efficient-Model-Compression
2 | import torch.nn as nn
3 |
4 | class LeNet5(nn.Module):
5 |
6 | def __init__(self, nc=1, num_classes=10):
7 | super(LeNet5, self).__init__()
8 | self.features = nn.Sequential(
9 | nn.Conv2d(1, 6, kernel_size=(5, 5)),
10 | nn.ReLU(inplace=True),
11 | nn.MaxPool2d(kernel_size=(2, 2), stride=2),
12 | nn.Conv2d(6, 16, kernel_size=(5, 5)),
13 | nn.ReLU(inplace=True),
14 | nn.MaxPool2d(kernel_size=(2, 2), stride=2),
15 | nn.Conv2d(16, 120, kernel_size=(5, 5)),
16 | nn.ReLU(inplace=True),
17 | )
18 | self.fc = nn.Sequential(
19 | nn.Linear(120, 84),
20 | nn.ReLU(inplace=True),
21 | nn.Linear(84, num_classes)
22 | )
23 |
24 | def forward(self, img, return_features=False):
25 | features = self.features( img ).view(-1, 120)
26 | output = self.fc( features )
27 | if return_features:
28 | return output, features
29 | return output
30 |
31 |
32 | class LeNet5Half(nn.Module):
33 |
34 | def __init__(self, nc=1, num_classes=10):
35 | super(LeNet5Half, self).__init__()
36 | self.features = nn.Sequential(
37 | nn.Conv2d(1, 3, kernel_size=(5, 5)),
38 | nn.ReLU(inplace=True),
39 | nn.MaxPool2d(kernel_size=(2, 2), stride=2),
40 | nn.Conv2d(3, 8, kernel_size=(5, 5)),
41 | nn.ReLU(inplace=True),
42 | nn.MaxPool2d(kernel_size=(2, 2), stride=2),
43 | nn.Conv2d(8, 60, kernel_size=(5, 5)),
44 | nn.ReLU(inplace=True),
45 | )
46 | self.fc = nn.Sequential(
47 | nn.Linear(60, 42),
48 | nn.ReLU(inplace=True),
49 | nn.Linear(42, num_classes)
50 | )
51 |
52 | def forward(self, img, return_features=False):
53 | features = self.features( img ).view(-1, 60)
54 | output = self.fc( features )
55 | if return_features:
56 | return output, features
57 | return output
--------------------------------------------------------------------------------
/datafree/models/classifiers/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch import Tensor
3 | from torchvision.models.utils import load_state_dict_from_url
4 | from typing import Callable, Any, Optional, List
5 |
6 |
7 | __all__ = ['MobileNetV2', 'mobilenet_v2']
8 |
9 |
10 | model_urls = {
11 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
12 | }
13 |
14 |
15 | def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
16 | """
17 | This function is taken from the original tf repo.
18 | It ensures that all layers have a channel number that is divisible by 8
19 | It can be seen here:
20 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
21 | :param v:
22 | :param divisor:
23 | :param min_value:
24 | :return:
25 | """
26 | if min_value is None:
27 | min_value = divisor
28 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
29 | # Make sure that round down does not go down by more than 10%.
30 | if new_v < 0.9 * v:
31 | new_v += divisor
32 | return new_v
33 |
34 |
35 | class ConvBNActivation(nn.Sequential):
36 | def __init__(
37 | self,
38 | in_planes: int,
39 | out_planes: int,
40 | kernel_size: int = 3,
41 | stride: int = 1,
42 | groups: int = 1,
43 | norm_layer: Optional[Callable[..., nn.Module]] = None,
44 | activation_layer: Optional[Callable[..., nn.Module]] = None,
45 | ) -> None:
46 | padding = (kernel_size - 1) // 2
47 | if norm_layer is None:
48 | norm_layer = nn.BatchNorm2d
49 | if activation_layer is None:
50 | activation_layer = nn.ReLU6
51 | super(ConvBNReLU, self).__init__(
52 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
53 | norm_layer(out_planes),
54 | activation_layer(inplace=True)
55 | )
56 |
57 |
58 | # necessary for backwards compatibility
59 | ConvBNReLU = ConvBNActivation
60 |
61 |
62 | class InvertedResidual(nn.Module):
63 | def __init__(
64 | self,
65 | inp: int,
66 | oup: int,
67 | stride: int,
68 | expand_ratio: int,
69 | norm_layer: Optional[Callable[..., nn.Module]] = None
70 | ) -> None:
71 | super(InvertedResidual, self).__init__()
72 | self.stride = stride
73 | assert stride in [1, 2]
74 |
75 | if norm_layer is None:
76 | norm_layer = nn.BatchNorm2d
77 |
78 | hidden_dim = int(round(inp * expand_ratio))
79 | self.use_res_connect = self.stride == 1 and inp == oup
80 |
81 | layers: List[nn.Module] = []
82 | if expand_ratio != 1:
83 | # pw
84 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
85 | layers.extend([
86 | # dw
87 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
88 | # pw-linear
89 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
90 | norm_layer(oup),
91 | ])
92 | self.conv = nn.Sequential(*layers)
93 |
94 | def forward(self, x: Tensor) -> Tensor:
95 | if self.use_res_connect:
96 | return x + self.conv(x)
97 | else:
98 | return self.conv(x)
99 |
100 |
101 | class MobileNetV2(nn.Module):
102 | def __init__(
103 | self,
104 | num_classes: int = 1000,
105 | width_mult: float = 1.0,
106 | inverted_residual_setting: Optional[List[List[int]]] = None,
107 | round_nearest: int = 8,
108 | block: Optional[Callable[..., nn.Module]] = None,
109 | norm_layer: Optional[Callable[..., nn.Module]] = None
110 | ) -> None:
111 | """
112 | MobileNet V2 main class
113 | Args:
114 | num_classes (int): Number of classes
115 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
116 | inverted_residual_setting: Network structure
117 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number
118 | Set to 1 to turn off rounding
119 | block: Module specifying inverted residual building block for mobilenet
120 | norm_layer: Module specifying the normalization layer to use
121 | """
122 | super(MobileNetV2, self).__init__()
123 |
124 | if block is None:
125 | block = InvertedResidual
126 |
127 | if norm_layer is None:
128 | norm_layer = nn.BatchNorm2d
129 |
130 | input_channel = 32
131 | last_channel = 1280
132 |
133 | if inverted_residual_setting is None:
134 | inverted_residual_setting = [
135 | # t, c, n, s
136 | [1, 16, 1, 1],
137 | [6, 24, 2, 2],
138 | [6, 32, 3, 2],
139 | [6, 64, 4, 2],
140 | [6, 96, 3, 1],
141 | [6, 160, 3, 2],
142 | [6, 320, 1, 1],
143 | ]
144 |
145 | # only check the first element, assuming user knows t,c,n,s are required
146 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
147 | raise ValueError("inverted_residual_setting should be non-empty "
148 | "or a 4-element list, got {}".format(inverted_residual_setting))
149 |
150 | # building first layer
151 | input_channel = _make_divisible(input_channel * width_mult, round_nearest)
152 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
153 | features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
154 | # building inverted residual blocks
155 | for t, c, n, s in inverted_residual_setting:
156 | output_channel = _make_divisible(c * width_mult, round_nearest)
157 | for i in range(n):
158 | stride = s if i == 0 else 1
159 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
160 | input_channel = output_channel
161 | # building last several layers
162 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
163 | # make it nn.Sequential
164 | self.features = nn.Sequential(*features)
165 |
166 | # building classifier
167 | self.classifier = nn.Sequential(
168 | nn.Dropout(0.2),
169 | nn.Linear(self.last_channel, num_classes),
170 | )
171 |
172 | # weight initialization
173 | for m in self.modules():
174 | if isinstance(m, nn.Conv2d):
175 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
176 | if m.bias is not None:
177 | nn.init.zeros_(m.bias)
178 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
179 | nn.init.ones_(m.weight)
180 | nn.init.zeros_(m.bias)
181 | elif isinstance(m, nn.Linear):
182 | nn.init.normal_(m.weight, 0, 0.01)
183 | nn.init.zeros_(m.bias)
184 |
185 | def _forward_impl(self, x: Tensor) -> Tensor:
186 | # This exists since TorchScript doesn't support inheritance, so the superclass method
187 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass
188 | x = self.features(x)
189 | # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
190 | x = nn.functional.adaptive_avg_pool2d(x, (1, 1)).reshape(x.shape[0], -1)
191 | x = self.classifier(x)
192 | return x
193 |
194 | def forward(self, x: Tensor) -> Tensor:
195 | return self._forward_impl(x)
196 |
197 |
198 | def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV2:
199 | """
200 | Constructs a MobileNetV2 architecture from
201 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_.
202 | Args:
203 | pretrained (bool): If True, returns a model pre-trained on ImageNet
204 | progress (bool): If True, displays a progress bar of the download to stderr
205 | """
206 | model = MobileNetV2(**kwargs)
207 | if pretrained:
208 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
209 | progress=progress)
210 | model.load_state_dict(state_dict)
211 | return model
--------------------------------------------------------------------------------
/datafree/models/classifiers/resnet.py:
--------------------------------------------------------------------------------
1 | # ResNet for CIFAR (32x32)
2 | # 2019.07.24-Changed output of forward function
3 | # Huawei Technologies Co., Ltd.
4 | # taken from https://github.com/huawei-noah/Data-Efficient-Model-Compression/blob/master/DAFL/resnet.py
5 | # for comparison with DAFL
6 |
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | class BasicBlock(nn.Module):
13 | expansion = 1
14 |
15 | def __init__(self, in_planes, planes, stride=1):
16 | super(BasicBlock, self).__init__()
17 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
20 | self.bn2 = nn.BatchNorm2d(planes)
21 |
22 | self.shortcut = nn.Sequential()
23 | if stride != 1 or in_planes != self.expansion*planes:
24 | self.shortcut = nn.Sequential(
25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
26 | nn.BatchNorm2d(self.expansion*planes)
27 | )
28 |
29 | def forward(self, x):
30 | out = F.relu(self.bn1(self.conv1(x)))
31 | out = self.bn2(self.conv2(out))
32 | out += self.shortcut(x)
33 | out = F.relu(out)
34 | return out
35 |
36 |
37 | class Bottleneck(nn.Module):
38 | expansion = 4
39 |
40 | def __init__(self, in_planes, planes, stride=1):
41 | super(Bottleneck, self).__init__()
42 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
43 | self.bn1 = nn.BatchNorm2d(planes)
44 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
45 | self.bn2 = nn.BatchNorm2d(planes)
46 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
47 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
48 |
49 | self.shortcut = nn.Sequential()
50 | if stride != 1 or in_planes != self.expansion*planes:
51 | self.shortcut = nn.Sequential(
52 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
53 | nn.BatchNorm2d(self.expansion*planes)
54 | )
55 |
56 | def forward(self, x):
57 | out = F.relu(self.bn1(self.conv1(x)))
58 | out = F.relu(self.bn2(self.conv2(out)))
59 | out = self.bn3(self.conv3(out))
60 | out += self.shortcut(x)
61 | out = F.relu(out)
62 | return out
63 |
64 |
65 | class ResNet(nn.Module):
66 | def __init__(self, block, num_blocks, num_classes=10):
67 | super(ResNet, self).__init__()
68 | self.in_planes = 64
69 |
70 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
71 | self.bn1 = nn.BatchNorm2d(64)
72 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
73 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
74 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
75 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
76 | self.linear = nn.Linear(512*block.expansion, num_classes)
77 |
78 | def _make_layer(self, block, planes, num_blocks, stride):
79 | strides = [stride] + [1]*(num_blocks-1)
80 | layers = []
81 | for stride in strides:
82 | layers.append(block(self.in_planes, planes, stride))
83 | self.in_planes = planes * block.expansion
84 | return nn.Sequential(*layers)
85 |
86 | def forward(self, x, return_features=False):
87 | x = self.conv1(x)
88 | x = self.bn1(x)
89 | x = F.relu(x)
90 | x1 = self.layer1(x)
91 | x2 = self.layer2(x1)
92 | x3 = self.layer3(x2)
93 | x4 = self.layer4(x3)
94 | out = F.adaptive_avg_pool2d(x4, (1,1))
95 | feature = out.view(out.size(0), -1)
96 | out = self.linear(feature)
97 |
98 | if return_features:
99 | return out, feature, [x1, x2, x3, x4]
100 | return out
101 |
102 | def resnet18(num_classes=10):
103 | return ResNet(BasicBlock, [2,2,2,2], num_classes)
104 |
105 | def resnet34(num_classes=10):
106 | return ResNet(BasicBlock, [3,4,6,3], num_classes)
107 |
108 | def resnet50(num_classes=10):
109 | return ResNet(Bottleneck, [3,4,6,3], num_classes)
110 |
111 | def resnet101(num_classes=10):
112 | return ResNet(Bottleneck, [3,4,23,3], num_classes)
113 |
114 | def resnet152(num_classes=10):
115 | return ResNet(Bottleneck, [3,8,36,3], num_classes)
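
A quick shape check for this CIFAR-style ResNet (illustrative):

```python
import torch
from datafree.models.classifiers.resnet import resnet18

model = resnet18(num_classes=10)
x = torch.randn(2, 3, 32, 32)
logits, feature, (x1, x2, x3, x4) = model(x, return_features=True)
print(logits.shape)    # torch.Size([2, 10])
print(feature.shape)   # torch.Size([2, 512])
print([f.shape for f in (x1, x2, x3, x4)])   # stage outputs at 32, 16, 8 and 4 px
```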
--------------------------------------------------------------------------------
/datafree/models/classifiers/resnet_in.py:
--------------------------------------------------------------------------------
1 | # ResNet for ImageNet (224x224)
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torchvision.models.utils import load_state_dict_from_url
6 |
7 |
8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
9 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
10 | 'wide_resnet50_2', 'wide_resnet101_2']
11 |
12 |
13 | model_urls = {
14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
17 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
18 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
19 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
20 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
21 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
22 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
23 | }
24 |
25 |
26 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
27 | """3x3 convolution with padding"""
28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
29 | padding=dilation, groups=groups, bias=False, dilation=dilation)
30 |
31 |
32 | def conv1x1(in_planes, out_planes, stride=1):
33 | """1x1 convolution"""
34 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
35 |
36 |
37 | class BasicBlock(nn.Module):
38 | expansion = 1
39 |
40 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
41 | base_width=64, dilation=1, norm_layer=None):
42 | super(BasicBlock, self).__init__()
43 | if norm_layer is None:
44 | norm_layer = nn.BatchNorm2d
45 | if groups != 1 or base_width != 64:
46 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
47 | if dilation > 1:
48 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
49 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
50 | self.conv1 = conv3x3(inplanes, planes, stride)
51 | self.bn1 = norm_layer(planes)
52 | self.relu = nn.ReLU(inplace=True)
53 | self.conv2 = conv3x3(planes, planes)
54 | self.bn2 = norm_layer(planes)
55 | self.downsample = downsample
56 | self.stride = stride
57 |
58 | def forward(self, x):
59 | identity = x
60 |
61 | out = self.conv1(x)
62 | out = self.bn1(out)
63 | out = self.relu(out)
64 |
65 | out = self.conv2(out)
66 | out = self.bn2(out)
67 |
68 | if self.downsample is not None:
69 | identity = self.downsample(x)
70 |
71 | out += identity
72 | out = self.relu(out)
73 |
74 | return out
75 |
76 |
77 | class Bottleneck(nn.Module):
78 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
79 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
80 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
81 | # This variant is also known as ResNet V1.5 and improves accuracy according to
82 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
83 |
84 | expansion = 4
85 |
86 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
87 | base_width=64, dilation=1, norm_layer=None):
88 | super(Bottleneck, self).__init__()
89 | if norm_layer is None:
90 | norm_layer = nn.BatchNorm2d
91 | width = int(planes * (base_width / 64.)) * groups
92 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
93 | self.conv1 = conv1x1(inplanes, width)
94 | self.bn1 = norm_layer(width)
95 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
96 | self.bn2 = norm_layer(width)
97 | self.conv3 = conv1x1(width, planes * self.expansion)
98 | self.bn3 = norm_layer(planes * self.expansion)
99 | self.relu = nn.ReLU(inplace=True)
100 | self.downsample = downsample
101 | self.stride = stride
102 |
103 | def forward(self, x):
104 | identity = x
105 |
106 | out = self.conv1(x)
107 | out = self.bn1(out)
108 | out = self.relu(out)
109 |
110 | out = self.conv2(out)
111 | out = self.bn2(out)
112 | out = self.relu(out)
113 |
114 | out = self.conv3(out)
115 | out = self.bn3(out)
116 |
117 | if self.downsample is not None:
118 | identity = self.downsample(x)
119 |
120 | out += identity
121 | out = self.relu(out)
122 |
123 | return out
124 |
125 |
126 | class ResNet(nn.Module):
127 |
128 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
129 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
130 | norm_layer=None):
131 | super(ResNet, self).__init__()
132 | if norm_layer is None:
133 | norm_layer = nn.BatchNorm2d
134 | self._norm_layer = norm_layer
135 |
136 | self.inplanes = 64
137 | self.dilation = 1
138 | if replace_stride_with_dilation is None:
139 | # each element in the tuple indicates if we should replace
140 | # the 2x2 stride with a dilated convolution instead
141 | replace_stride_with_dilation = [False, False, False]
142 | if len(replace_stride_with_dilation) != 3:
143 | raise ValueError("replace_stride_with_dilation should be None "
144 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
145 | self.groups = groups
146 | self.base_width = width_per_group
147 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
148 | bias=False)
149 | self.bn1 = norm_layer(self.inplanes)
150 | self.relu = nn.ReLU(inplace=True)
151 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
152 | self.layer1 = self._make_layer(block, 64, layers[0])
153 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
154 | dilate=replace_stride_with_dilation[0])
155 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
156 | dilate=replace_stride_with_dilation[1])
157 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
158 | dilate=replace_stride_with_dilation[2])
159 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
160 | self.fc = nn.Linear(512 * block.expansion, num_classes)
161 |
162 | for m in self.modules():
163 | if isinstance(m, nn.Conv2d):
164 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
165 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
166 | nn.init.constant_(m.weight, 1)
167 | nn.init.constant_(m.bias, 0)
168 |
169 | # Zero-initialize the last BN in each residual branch,
170 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
171 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
172 | if zero_init_residual:
173 | for m in self.modules():
174 | if isinstance(m, Bottleneck):
175 | nn.init.constant_(m.bn3.weight, 0)
176 | elif isinstance(m, BasicBlock):
177 | nn.init.constant_(m.bn2.weight, 0)
178 |
179 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
180 | norm_layer = self._norm_layer
181 | downsample = None
182 | previous_dilation = self.dilation
183 | if dilate:
184 | self.dilation *= stride
185 | stride = 1
186 | if stride != 1 or self.inplanes != planes * block.expansion:
187 | downsample = nn.Sequential(
188 | conv1x1(self.inplanes, planes * block.expansion, stride),
189 | norm_layer(planes * block.expansion),
190 | )
191 |
192 | layers = []
193 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
194 | self.base_width, previous_dilation, norm_layer))
195 | self.inplanes = planes * block.expansion
196 | for _ in range(1, blocks):
197 | layers.append(block(self.inplanes, planes, groups=self.groups,
198 | base_width=self.base_width, dilation=self.dilation,
199 | norm_layer=norm_layer))
200 |
201 | return nn.Sequential(*layers)
202 |
203 | def _forward_impl(self, x, return_features):
204 | # See note [TorchScript super()]
205 | x = self.conv1(x)
206 | x = self.bn1(x)
207 | x = self.relu(x)
208 | x = self.maxpool(x)
209 |
210 | x1 = self.layer1(x)
211 | x2 = self.layer2(x1)
212 | x3 = self.layer3(x2)
213 | x4 = self.layer4(x3)
214 |
215 | x = self.avgpool(x4)
216 | feat = torch.flatten(x, 1)
217 | x = self.fc(feat)
218 | if return_features:
219 | return x, feat, [x1, x2, x3, x4]
220 | return x
221 |
222 | def forward(self, x, return_features=False):
223 | return self._forward_impl(x, return_features=return_features)
224 |
225 |
226 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
227 | model = ResNet(block, layers, **kwargs)
228 | if pretrained:
229 | state_dict = load_state_dict_from_url(model_urls[arch],
230 | progress=progress)
231 | print('load from', model_urls[arch])
232 | model.load_state_dict(state_dict)
233 | return model
234 |
235 |
236 | def resnet18(pretrained=False, progress=True, **kwargs):
237 | r"""ResNet-18 model from
238 | `"Deep Residual Learning for Image Recognition" `_
239 | Args:
240 | pretrained (bool): If True, returns a model pre-trained on ImageNet
241 | progress (bool): If True, displays a progress bar of the download to stderr
242 | """
243 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
244 | **kwargs)
245 |
246 |
247 | def resnet34(pretrained=False, progress=True, **kwargs):
248 | r"""ResNet-34 model from
249 | `"Deep Residual Learning for Image Recognition" `_
250 | Args:
251 | pretrained (bool): If True, returns a model pre-trained on ImageNet
252 | progress (bool): If True, displays a progress bar of the download to stderr
253 | """
254 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
255 | **kwargs)
256 |
257 |
258 | def resnet50(pretrained=False, progress=True, **kwargs):
259 | r"""ResNet-50 model from
260 | `"Deep Residual Learning for Image Recognition" `_
261 | Args:
262 | pretrained (bool): If True, returns a model pre-trained on ImageNet
263 | progress (bool): If True, displays a progress bar of the download to stderr
264 | """
265 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
266 | **kwargs)
267 |
268 |
269 | def resnet101(pretrained=False, progress=True, **kwargs):
270 | r"""ResNet-101 model from
271 | `"Deep Residual Learning for Image Recognition" `_
272 | Args:
273 | pretrained (bool): If True, returns a model pre-trained on ImageNet
274 | progress (bool): If True, displays a progress bar of the download to stderr
275 | """
276 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
277 | **kwargs)
278 |
279 |
280 | def resnet152(pretrained=False, progress=True, **kwargs):
281 | r"""ResNet-152 model from
282 | `"Deep Residual Learning for Image Recognition" `_
283 | Args:
284 | pretrained (bool): If True, returns a model pre-trained on ImageNet
285 | progress (bool): If True, displays a progress bar of the download to stderr
286 | """
287 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
288 | **kwargs)
289 |
290 |
291 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
292 | r"""ResNeXt-50 32x4d model from
293 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
294 | Args:
295 | pretrained (bool): If True, returns a model pre-trained on ImageNet
296 | progress (bool): If True, displays a progress bar of the download to stderr
297 | """
298 | kwargs['groups'] = 32
299 | kwargs['width_per_group'] = 4
300 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
301 | pretrained, progress, **kwargs)
302 |
303 |
304 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
305 | r"""ResNeXt-101 32x8d model from
306 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
307 | Args:
308 | pretrained (bool): If True, returns a model pre-trained on ImageNet
309 | progress (bool): If True, displays a progress bar of the download to stderr
310 | """
311 | kwargs['groups'] = 32
312 | kwargs['width_per_group'] = 8
313 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
314 | pretrained, progress, **kwargs)
315 |
316 |
317 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
318 | r"""Wide ResNet-50-2 model from
319 | `"Wide Residual Networks" `_
320 | The model is the same as ResNet except for the bottleneck number of channels
321 | which is twice larger in every block. The number of channels in outer 1x1
322 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
323 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
324 | Args:
325 | pretrained (bool): If True, returns a model pre-trained on ImageNet
326 | progress (bool): If True, displays a progress bar of the download to stderr
327 | """
328 | kwargs['width_per_group'] = 64 * 2
329 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
330 | pretrained, progress, **kwargs)
331 |
332 |
333 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
334 | r"""Wide ResNet-101-2 model from
335 | `"Wide Residual Networks" `_
336 | The model is the same as ResNet except for the bottleneck number of channels
337 | which is twice larger in every block. The number of channels in outer 1x1
338 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
339 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
340 | Args:
341 | pretrained (bool): If True, returns a model pre-trained on ImageNet
342 | progress (bool): If True, displays a progress bar of the download to stderr
343 | """
344 | kwargs['width_per_group'] = 64 * 2
345 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
346 | pretrained, progress, **kwargs)
--------------------------------------------------------------------------------
/datafree/models/classifiers/resnet_tiny.py:
--------------------------------------------------------------------------------
1 | # Tiny ResNet for CIFAR (32x32)
2 |
3 | from __future__ import absolute_import
4 |
5 | '''Resnet for cifar dataset.
6 | https://github.com/HobbitLong/RepDistiller/blob/master/models/resnet.py
7 | Ported form
8 | https://github.com/facebook/fb.resnet.torch
9 | and
10 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
11 | (c) YANG, Wei
12 | '''
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | import math
16 |
17 |
18 | __all__ = ['resnet']
19 |
20 |
21 | def conv3x3(in_planes, out_planes, stride=1):
22 | """3x3 convolution with padding"""
23 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
24 | padding=1, bias=False)
25 |
26 |
27 | class BasicBlock(nn.Module):
28 | expansion = 1
29 |
30 | def __init__(self, inplanes, planes, stride=1, downsample=None):
31 | super(BasicBlock, self).__init__()
32 | self.conv1 = conv3x3(inplanes, planes, stride)
33 | self.bn1 = nn.BatchNorm2d(planes)
34 | self.relu = nn.ReLU(inplace=True)
35 | self.conv2 = conv3x3(planes, planes)
36 | self.bn2 = nn.BatchNorm2d(planes)
37 | self.downsample = downsample
38 | self.stride = stride
39 |
40 | def forward(self, x):
41 | residual = x
42 |
43 | out = self.conv1(x)
44 | out = self.bn1(out)
45 | out = self.relu(out)
46 |
47 | out = self.conv2(out)
48 | out = self.bn2(out)
49 |
50 | if self.downsample is not None:
51 | residual = self.downsample(x)
52 |
53 | out += residual
54 | preact = out
55 | out = F.relu(out)
56 | return out
57 |
58 |
59 | class Bottleneck(nn.Module):
60 | expansion = 4
61 |
62 | def __init__(self, inplanes, planes, stride=1, downsample=None):
63 | super(Bottleneck, self).__init__()
64 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
65 | self.bn1 = nn.BatchNorm2d(planes)
66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
67 | padding=1, bias=False)
68 | self.bn2 = nn.BatchNorm2d(planes)
69 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
70 | self.bn3 = nn.BatchNorm2d(planes * 4)
71 | self.relu = nn.ReLU(inplace=True)
72 | self.downsample = downsample
73 | self.stride = stride
74 |
75 | def forward(self, x):
76 | residual = x
77 |
78 | out = self.conv1(x)
79 | out = self.bn1(out)
80 | out = self.relu(out)
81 |
82 | out = self.conv2(out)
83 | out = self.bn2(out)
84 | out = self.relu(out)
85 |
86 | out = self.conv3(out)
87 | out = self.bn3(out)
88 |
89 | if self.downsample is not None:
90 | residual = self.downsample(x)
91 |
92 | out += residual
93 | preact = out
94 | out = F.relu(out)
95 | return out
96 |
97 |
98 | class ResNet(nn.Module):
99 |
100 | def __init__(self, depth, num_filters, block_name='BasicBlock', num_classes=10):
101 | super(ResNet, self).__init__()
102 | if block_name.lower() == 'basicblock':
103 | assert (depth - 2) % 6 == 0, 'When using basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
104 | n = (depth - 2) // 6
105 | block = BasicBlock
106 | elif block_name.lower() == 'bottleneck':
107 | assert (depth - 2) % 9 == 0, 'When using bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
108 | n = (depth - 2) // 9
109 | block = Bottleneck
110 | else:
111 | raise ValueError('block_name should be Basicblock or Bottleneck')
112 |
113 | self.inplanes = num_filters[0]
114 | self.conv1 = nn.Conv2d(3, num_filters[0], kernel_size=3, padding=1,
115 | bias=False)
116 | self.bn1 = nn.BatchNorm2d(num_filters[0])
117 | self.relu = nn.ReLU(inplace=True)
118 | self.layer1 = self._make_layer(block, num_filters[1], n)
119 | self.layer2 = self._make_layer(block, num_filters[2], n, stride=2)
120 | self.layer3 = self._make_layer(block, num_filters[3], n, stride=2)
121 | self.avgpool = nn.AvgPool2d(8)
122 | self.fc = nn.Linear(num_filters[3] * block.expansion, num_classes)
123 |
124 | for m in self.modules():
125 | if isinstance(m, nn.Conv2d):
126 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
127 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
128 | nn.init.constant_(m.weight, 1)
129 | nn.init.constant_(m.bias, 0)
130 |
131 | def _make_layer(self, block, planes, blocks, stride=1):
132 | downsample = None
133 | if stride != 1 or self.inplanes != planes * block.expansion:
134 | downsample = nn.Sequential(
135 | nn.Conv2d(self.inplanes, planes * block.expansion,
136 | kernel_size=1, stride=stride, bias=False),
137 | nn.BatchNorm2d(planes * block.expansion),
138 | )
139 |
140 | layers = list([])
141 | layers.append(block(self.inplanes, planes, stride, downsample))
142 | self.inplanes = planes * block.expansion
143 | for i in range(1, blocks):
144 | layers.append(block(self.inplanes, planes))
145 |
146 | return nn.Sequential(*layers)
147 |
148 | def forward(self, x, return_features=False):
149 | x = self.conv1(x)
150 | x = self.bn1(x)
151 | x = self.relu(x) # 32x32
152 | x1 = self.layer1(x)
153 | x2 = self.layer2(x1)
154 | x3 = self.layer3(x2)
155 | x = self.avgpool(x3)
156 | features = x.view(x.size(0), -1)
157 | x = self.fc(features)
158 |
159 | if return_features:
160 | return x, features, [x1, x2, x3]
161 | return x
162 |
163 |
164 | def resnet8(num_classes):
165 | return ResNet(8, [16, 16, 32, 64], 'basicblock', num_classes=num_classes)
166 |
167 |
168 | def resnet14(num_classes):
169 | return ResNet(14, [16, 16, 32, 64], 'basicblock', num_classes=num_classes)
170 |
171 |
172 | def resnet20(num_classes):
173 | return ResNet(20, [16, 16, 32, 64], 'basicblock', num_classes=num_classes)
174 |
175 |
176 | def resnet32(num_classes):
177 | return ResNet(32, [16, 16, 32, 64], 'basicblock', num_classes=num_classes)
178 |
179 |
180 | def resnet44(num_classes):
181 | return ResNet(44, [16, 16, 32, 64], 'basicblock', num_classes=num_classes)
182 |
183 |
184 | def resnet56(num_classes):
185 | return ResNet(56, [16, 16, 32, 64], 'basicblock', num_classes=num_classes)
186 |
187 |
188 | def resnet110(num_classes):
189 | return ResNet(110, [16, 16, 32, 64], 'basicblock', num_classes=num_classes)
190 |
191 |
192 | def resnet8x4(num_classes):
193 | return ResNet(8, [32, 64, 128, 256], 'basicblock', num_classes=num_classes)
194 |
195 |
196 | def resnet32x4(num_classes):
197 | return ResNet(32, [32, 64, 128, 256], 'basicblock', num_classes=num_classes)
--------------------------------------------------------------------------------
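A minimal usage sketch for the CIFAR-style factories defined above; the batch size is illustrative and the import assumes the package imports cleanly from the repository root:

import torch
from datafree.models.classifiers.resnet_tiny import resnet20

model = resnet20(num_classes=10).eval()
x = torch.randn(4, 3, 32, 32)                    # CIFAR-sized input
logits, feats, (x1, x2, x3) = model(x, return_features=True)
print(logits.shape, feats.shape)                 # torch.Size([4, 10]) torch.Size([4, 64])
print(x1.shape, x2.shape, x3.shape)              # 16@32x32, 32@16x16, 64@8x8 feature maps
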
/datafree/models/classifiers/shufflenetv2.py:
--------------------------------------------------------------------------------
1 | '''ShuffleNetV2 in PyTorch.
2 | https://github.com/HobbitLong/RepDistiller/blob/34557d2728/models/ShuffleNetv2.py
3 | See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details.
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class ShuffleBlock(nn.Module):
11 | def __init__(self, groups=2):
12 | super(ShuffleBlock, self).__init__()
13 | self.groups = groups
14 |
15 | def forward(self, x):
16 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
17 | N, C, H, W = x.size()
18 | g = self.groups
19 | return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W)
20 |
21 |
22 | class SplitBlock(nn.Module):
23 | def __init__(self, ratio):
24 | super(SplitBlock, self).__init__()
25 | self.ratio = ratio
26 |
27 | def forward(self, x):
28 | c = int(x.size(1) * self.ratio)
29 | return x[:, :c, :, :], x[:, c:, :, :]
30 |
31 |
32 | class BasicBlock(nn.Module):
33 | def __init__(self, in_channels, split_ratio=0.5, is_last=False):
34 | super(BasicBlock, self).__init__()
35 | self.is_last = is_last
36 | self.split = SplitBlock(split_ratio)
37 | in_channels = int(in_channels * split_ratio)
38 | self.conv1 = nn.Conv2d(in_channels, in_channels,
39 | kernel_size=1, bias=False)
40 | self.bn1 = nn.BatchNorm2d(in_channels)
41 | self.conv2 = nn.Conv2d(in_channels, in_channels,
42 | kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False)
43 | self.bn2 = nn.BatchNorm2d(in_channels)
44 | self.conv3 = nn.Conv2d(in_channels, in_channels,
45 | kernel_size=1, bias=False)
46 | self.bn3 = nn.BatchNorm2d(in_channels)
47 | self.shuffle = ShuffleBlock()
48 |
49 | def forward(self, x):
50 | x1, x2 = self.split(x)
51 | out = F.relu(self.bn1(self.conv1(x2)))
52 | out = self.bn2(self.conv2(out))
53 | preact = self.bn3(self.conv3(out))
54 | out = F.relu(preact)
55 | # out = F.relu(self.bn3(self.conv3(out)))
56 | preact = torch.cat([x1, preact], 1)
57 | out = torch.cat([x1, out], 1)
58 | out = self.shuffle(out)
59 | if self.is_last:
60 | return out, preact
61 | else:
62 | return out
63 |
64 |
65 | class DownBlock(nn.Module):
66 | def __init__(self, in_channels, out_channels):
67 | super(DownBlock, self).__init__()
68 | mid_channels = out_channels // 2
69 | # left
70 | self.conv1 = nn.Conv2d(in_channels, in_channels,
71 | kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False)
72 | self.bn1 = nn.BatchNorm2d(in_channels)
73 | self.conv2 = nn.Conv2d(in_channels, mid_channels,
74 | kernel_size=1, bias=False)
75 | self.bn2 = nn.BatchNorm2d(mid_channels)
76 | # right
77 | self.conv3 = nn.Conv2d(in_channels, mid_channels,
78 | kernel_size=1, bias=False)
79 | self.bn3 = nn.BatchNorm2d(mid_channels)
80 | self.conv4 = nn.Conv2d(mid_channels, mid_channels,
81 | kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False)
82 | self.bn4 = nn.BatchNorm2d(mid_channels)
83 | self.conv5 = nn.Conv2d(mid_channels, mid_channels,
84 | kernel_size=1, bias=False)
85 | self.bn5 = nn.BatchNorm2d(mid_channels)
86 |
87 | self.shuffle = ShuffleBlock()
88 |
89 | def forward(self, x):
90 | # left
91 | out1 = self.bn1(self.conv1(x))
92 | out1 = F.relu(self.bn2(self.conv2(out1)))
93 | # right
94 | out2 = F.relu(self.bn3(self.conv3(x)))
95 | out2 = self.bn4(self.conv4(out2))
96 | out2 = F.relu(self.bn5(self.conv5(out2)))
97 | # concat
98 | out = torch.cat([out1, out2], 1)
99 | out = self.shuffle(out)
100 | return out
101 |
102 |
103 | class ShuffleNetV2(nn.Module):
104 | def __init__(self, net_size, num_classes=10):
105 | super(ShuffleNetV2, self).__init__()
106 | out_channels = configs[net_size]['out_channels']
107 | num_blocks = configs[net_size]['num_blocks']
108 |
109 | # self.conv1 = nn.Conv2d(3, 24, kernel_size=3,
110 | # stride=1, padding=1, bias=False)
111 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False)
112 | self.bn1 = nn.BatchNorm2d(24)
113 | self.in_channels = 24
114 | self.layer1 = self._make_layer(out_channels[0], num_blocks[0])
115 | self.layer2 = self._make_layer(out_channels[1], num_blocks[1])
116 | self.layer3 = self._make_layer(out_channels[2], num_blocks[2])
117 | self.conv2 = nn.Conv2d(out_channels[2], out_channels[3],
118 | kernel_size=1, stride=1, padding=0, bias=False)
119 | self.bn2 = nn.BatchNorm2d(out_channels[3])
120 | self.linear = nn.Linear(out_channels[3], num_classes)
121 |
122 | def _make_layer(self, out_channels, num_blocks):
123 | layers = [DownBlock(self.in_channels, out_channels)]
124 | for i in range(num_blocks):
125 | layers.append(BasicBlock(out_channels, is_last=(i == num_blocks - 1)))
126 | self.in_channels = out_channels
127 | return nn.Sequential(*layers)
128 |
129 | def get_feat_modules(self):
130 | feat_m = nn.ModuleList([])
131 | feat_m.append(self.conv1)
132 | feat_m.append(self.bn1)
133 | feat_m.append(self.layer1)
134 | feat_m.append(self.layer2)
135 | feat_m.append(self.layer3)
136 | return feat_m
137 |
138 | def get_bn_before_relu(self):
139 | raise NotImplementedError('ShuffleNetV2 currently is not supported for "Overhaul" teacher')
140 |
141 | def forward(self, x, return_features=False):
142 | out = F.relu(self.bn1(self.conv1(x)))
143 | out, f1_pre = self.layer1(out)
144 | out, f2_pre = self.layer2(out)
145 | out, f3_pre = self.layer3(out)
146 | out = F.relu(self.bn2(self.conv2(out)))
147 | out = F.avg_pool2d(out, 4)
148 | features = out.view(out.size(0), -1)
149 | out = self.linear(features)
150 | if return_features:
151 | return out, features
152 | else:
153 | return out
154 |
155 | configs = {
156 | 0.2: {
157 | 'out_channels': (40, 80, 160, 512),
158 | 'num_blocks': (3, 3, 3)
159 | },
160 |
161 | 0.3: {
162 | 'out_channels': (40, 80, 160, 512),
163 | 'num_blocks': (3, 7, 3)
164 | },
165 |
166 | 0.5: {
167 | 'out_channels': (48, 96, 192, 1024),
168 | 'num_blocks': (3, 7, 3)
169 | },
170 |
171 | 1: {
172 | 'out_channels': (116, 232, 464, 1024),
173 | 'num_blocks': (3, 7, 3)
174 | },
175 | 1.5: {
176 | 'out_channels': (176, 352, 704, 1024),
177 | 'num_blocks': (3, 7, 3)
178 | },
179 | 2: {
180 | 'out_channels': (224, 488, 976, 2048),
181 | 'num_blocks': (3, 7, 3)
182 | }
183 | }
184 |
185 |
186 | def shuffle_v2(num_classes):
187 | model = ShuffleNetV2(net_size=1, num_classes=num_classes)
188 | return model
--------------------------------------------------------------------------------
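The channel shuffle in ShuffleBlock is the only non-obvious operation in the file above; here is a tiny sketch (illustrative values) that makes the reordering visible:

import torch

x = torch.arange(4).view(1, 4, 1, 1)           # channels [0, 1, 2, 3]
N, C, H, W = x.shape
g = 2                                          # default ShuffleBlock groups
shuffled = x.view(N, g, C // g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W)
print(shuffled.view(-1).tolist())              # [0, 2, 1, 3]: channels interleaved across the two groups
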
/datafree/models/classifiers/vgg.py:
--------------------------------------------------------------------------------
1 | """https://github.com/HobbitLong/RepDistiller/blob/master/models/vgg.py
2 | """
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import math
6 |
7 |
8 | __all__ = [
9 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
10 | 'vgg19_bn', 'vgg19',
11 | ]
12 |
13 |
14 | model_urls = {
15 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
16 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
17 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
18 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
19 | }
20 |
21 |
22 | class VGG(nn.Module):
23 |
24 | def __init__(self, cfg, batch_norm=False, num_classes=1000):
25 | super(VGG, self).__init__()
26 | self.block0 = self._make_layers(cfg[0], batch_norm, 3)
27 | self.block1 = self._make_layers(cfg[1], batch_norm, cfg[0][-1])
28 | self.block2 = self._make_layers(cfg[2], batch_norm, cfg[1][-1])
29 | self.block3 = self._make_layers(cfg[3], batch_norm, cfg[2][-1])
30 | self.block4 = self._make_layers(cfg[4], batch_norm, cfg[3][-1])
31 |
32 | self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2)
33 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
34 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
35 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
36 | self.pool4 = nn.AdaptiveAvgPool2d((1, 1))
37 | # self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
38 |
39 | self.classifier = nn.Linear(512, num_classes)
40 | self._initialize_weights()
41 |
42 | def get_feat_modules(self):
43 | feat_m = nn.ModuleList([])
44 | feat_m.append(self.block0)
45 | feat_m.append(self.pool0)
46 | feat_m.append(self.block1)
47 | feat_m.append(self.pool1)
48 | feat_m.append(self.block2)
49 | feat_m.append(self.pool2)
50 | feat_m.append(self.block3)
51 | feat_m.append(self.pool3)
52 | feat_m.append(self.block4)
53 | feat_m.append(self.pool4)
54 | return feat_m
55 |
56 | def get_bn_before_relu(self):
57 | bn1 = self.block1[-1]
58 | bn2 = self.block2[-1]
59 | bn3 = self.block3[-1]
60 | bn4 = self.block4[-1]
61 | return [bn1, bn2, bn3, bn4]
62 |
63 | def forward(self, x, return_features=False):
64 | h = x.shape[2]
65 | x = F.relu(self.block0(x))
66 | x = self.pool0(x)
67 | x = self.block1(x)
68 | x = F.relu(x)
69 | x = self.pool1(x)
70 | x = self.block2(x)
71 | x = F.relu(x)
72 | x = self.pool2(x)
73 | x = self.block3(x)
74 | x = F.relu(x)
75 | if h == 64:
76 | x = self.pool3(x)
77 | x = self.block4(x)
78 | x = F.relu(x)
79 | x = self.pool4(x)
80 | features = x.view(x.size(0), -1)
81 | x = self.classifier(features)
82 | if return_features:
83 | return x, features
84 | else:
85 | return x
86 |
87 | @staticmethod
88 | def _make_layers(cfg, batch_norm=False, in_channels=3):
89 | layers = []
90 | for v in cfg:
91 | if v == 'M':
92 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
93 | else:
94 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
95 | if batch_norm:
96 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
97 | else:
98 | layers += [conv2d, nn.ReLU(inplace=True)]
99 | in_channels = v
100 | layers = layers[:-1]
101 | return nn.Sequential(*layers)
102 |
103 | def _initialize_weights(self):
104 | for m in self.modules():
105 | if isinstance(m, nn.Conv2d):
106 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
107 | m.weight.data.normal_(0, math.sqrt(2. / n))
108 | if m.bias is not None:
109 | m.bias.data.zero_()
110 | elif isinstance(m, nn.BatchNorm2d):
111 | m.weight.data.fill_(1)
112 | m.bias.data.zero_()
113 | elif isinstance(m, nn.Linear):
114 | n = m.weight.size(1)
115 | m.weight.data.normal_(0, 0.01)
116 | m.bias.data.zero_()
117 |
118 |
119 | cfg = {
120 | 'A': [[64], [128], [256, 256], [512, 512], [512, 512]],
121 | 'B': [[64, 64], [128, 128], [256, 256], [512, 512], [512, 512]],
122 | 'D': [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]],
123 | 'E': [[64, 64], [128, 128], [256, 256, 256, 256], [512, 512, 512, 512], [512, 512, 512, 512]],
124 | 'S': [[64], [128], [256], [512], [512]],
125 | }
126 |
127 |
128 | def vgg8(**kwargs):
129 | """VGG 8-layer model (configuration "S")
130 | Args:
131 | **kwargs: keyword arguments forwarded to the VGG constructor (e.g. num_classes)
132 | """
133 | model = VGG(cfg['S'], **kwargs)
134 | return model
135 |
136 |
137 | def vgg8_bn(**kwargs):
138 | """VGG 8-layer model (configuration "S")
139 | Args:
140 | **kwargs: keyword arguments forwarded to the VGG constructor (e.g. num_classes)
141 | """
142 | model = VGG(cfg['S'], batch_norm=True, **kwargs)
143 | return model
144 |
145 |
146 | def vgg11(**kwargs):
147 | """VGG 11-layer model (configuration "A")
148 | Args:
149 | **kwargs: keyword arguments forwarded to the VGG constructor (e.g. num_classes)
150 | """
151 | model = VGG(cfg['A'], **kwargs)
152 | return model
153 |
154 |
155 | def vgg11_bn(**kwargs):
156 | """VGG 11-layer model (configuration "A") with batch normalization"""
157 | model = VGG(cfg['A'], batch_norm=True, **kwargs)
158 | return model
159 |
160 |
161 | def vgg13(**kwargs):
162 | """VGG 13-layer model (configuration "B")
163 | Args:
164 | **kwargs: keyword arguments forwarded to the VGG constructor (e.g. num_classes)
165 | """
166 | model = VGG(cfg['B'], **kwargs)
167 | return model
168 |
169 |
170 | def vgg13_bn(**kwargs):
171 | """VGG 13-layer model (configuration "B") with batch normalization"""
172 | model = VGG(cfg['B'], batch_norm=True, **kwargs)
173 | return model
174 |
175 |
176 | def vgg16(**kwargs):
177 | """VGG 16-layer model (configuration "D")
178 | Args:
179 | **kwargs: keyword arguments forwarded to the VGG constructor (e.g. num_classes)
180 | """
181 | model = VGG(cfg['D'], **kwargs)
182 | return model
183 |
184 |
185 | def vgg16_bn(**kwargs):
186 | """VGG 16-layer model (configuration "D") with batch normalization"""
187 | model = VGG(cfg['D'], batch_norm=True, **kwargs)
188 | return model
189 |
190 |
191 | def vgg19(**kwargs):
192 | """VGG 19-layer model (configuration "E")
193 | Args:
194 | **kwargs: keyword arguments forwarded to the VGG constructor (e.g. num_classes)
195 | """
196 | model = VGG(cfg['E'], **kwargs)
197 | return model
198 |
199 |
200 | def vgg19_bn(**kwargs):
201 | """VGG 19-layer model (configuration 'E') with batch normalization"""
202 | model = VGG(cfg['E'], batch_norm=True, **kwargs)
203 | return model
204 |
205 |
206 | if __name__ == '__main__':
207 | import torch
208 |
209 | x = torch.randn(2, 3, 32, 32)
210 | net = vgg19_bn(num_classes=100)
211 | logit, features = net(x, return_features=True)
212 |
213 | # `features` is the pooled (batch, 512) tensor returned by forward()
214 | print(features.shape, features.min().item())
215 | print(logit.shape)
216 |
217 | for m in net.get_bn_before_relu():
218 | if isinstance(m, nn.BatchNorm2d):
219 | print('pass')
220 | else:
221 | print('warning')
--------------------------------------------------------------------------------
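A short sketch of how the config-driven VGG above behaves on the two input sizes its forward() distinguishes (32x32 skips pool3, 64x64 applies it); the AdaptiveAvgPool2d head keeps the classifier shape-independent. Values are illustrative:

import torch
from datafree.models.classifiers.vgg import vgg8_bn

net = vgg8_bn(num_classes=10).eval()           # cfg['S']: one conv per block
out32, feat32 = net(torch.randn(2, 3, 32, 32), return_features=True)
out64, feat64 = net(torch.randn(2, 3, 64, 64), return_features=True)
print(out32.shape, feat32.shape)               # torch.Size([2, 10]) torch.Size([2, 512])
print(out64.shape, feat64.shape)               # same shapes for the larger input
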
/datafree/models/classifiers/wresnet.py:
--------------------------------------------------------------------------------
1 | '''https://github.com/polo5/ZeroShotKnowledgeTransfer/blob/master/models/wresnet.py
2 | '''
3 |
4 | import math
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | """
10 | Original Author: Wei Yang
11 | """
12 |
13 | __all__ = ['wrn_16_1', 'wrn_16_2', 'wrn_40_1', 'wrn_40_2']
14 |
15 |
16 | class BasicBlock(nn.Module):
17 | def __init__(self, in_planes, out_planes, stride, dropout_rate=0.0):
18 | super(BasicBlock, self).__init__()
19 | self.bn1 = nn.BatchNorm2d(in_planes)
20 | self.relu1 = nn.ReLU(inplace=True)
21 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
22 | padding=1, bias=False)
23 | self.bn2 = nn.BatchNorm2d(out_planes)
24 | self.relu2 = nn.ReLU(inplace=True)
25 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
26 | padding=1, bias=False)
27 | self.dropout = nn.Dropout( dropout_rate )
28 | self.equalInOut = (in_planes == out_planes)
29 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
30 | padding=0, bias=False) or None
31 |
32 | def forward(self, x):
33 | if not self.equalInOut:
34 | x = self.relu1(self.bn1(x))
35 | else:
36 | out = self.relu1(self.bn1(x))
37 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
38 | out = self.dropout(out)
39 | out = self.conv2(out)
40 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
41 |
42 |
43 | class NetworkBlock(nn.Module):
44 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropout_rate=0.0):
45 | super(NetworkBlock, self).__init__()
46 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropout_rate)
47 |
48 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropout_rate):
49 | layers = []
50 | for i in range(nb_layers):
51 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropout_rate))
52 | return nn.Sequential(*layers)
53 |
54 | def forward(self, x):
55 | return self.layer(x)
56 |
57 |
58 | class WideResNet(nn.Module):
59 | def __init__(self, depth, num_classes, widen_factor=1, dropout_rate=0.0):
60 | super(WideResNet, self).__init__()
61 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
62 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
63 | n = (depth - 4) // 6
64 | block = BasicBlock
65 | # 1st conv before any network block
66 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
67 | padding=1, bias=False)
68 | # 1st block
69 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropout_rate)
70 | # 2nd block
71 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropout_rate)
72 | # 3rd block
73 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropout_rate)
74 | # global average pooling and classifier
75 | self.bn1 = nn.BatchNorm2d(nChannels[3])
76 | self.relu = nn.ReLU(inplace=True)
77 | self.fc = nn.Linear(nChannels[3], num_classes)
78 | self.nChannels = nChannels[3]
79 |
80 | for m in self.modules():
81 | if isinstance(m, nn.Conv2d):
82 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
83 | m.weight.data.normal_(0, math.sqrt(2. / n))
84 | elif isinstance(m, nn.BatchNorm2d):
85 | m.weight.data.fill_(1)
86 | m.bias.data.zero_()
87 | elif isinstance(m, nn.Linear):
88 | m.bias.data.zero_()
89 |
90 | def forward(self, x, return_features=False):
91 | out = self.conv1(x)
92 | x1 = self.block1(out)
93 | x2 = self.block2(x1)
94 | x3 = self.block3(x2)
95 | out = self.relu(self.bn1(x3))
96 | out = F.adaptive_avg_pool2d(out, (1,1))
97 | features = out.view(-1, self.nChannels)
98 | out = self.fc(features)
99 |
100 | if return_features:
101 | return out, features, [x1, x2, x3]
102 | else:
103 | return out
104 |
105 | def wrn_16_1(num_classes, dropout_rate=0):
106 | return WideResNet(depth=16, num_classes=num_classes, widen_factor=1, dropout_rate=dropout_rate)
107 |
108 | def wrn_16_2(num_classes, dropout_rate=0):
109 | return WideResNet(depth=16, num_classes=num_classes, widen_factor=2, dropout_rate=dropout_rate)
110 |
111 | def wrn_40_1(num_classes, dropout_rate=0):
112 | return WideResNet(depth=40, num_classes=num_classes, widen_factor=1, dropout_rate=dropout_rate)
113 |
114 | def wrn_40_2(num_classes, dropout_rate=0):
115 | return WideResNet(depth=40, num_classes=num_classes, widen_factor=2, dropout_rate=dropout_rate)
--------------------------------------------------------------------------------
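The depth/width arithmetic behind the factories above, shown for wrn_40_2 (illustrative batch): depth 40 gives n = (40 - 4) / 6 = 6 BasicBlocks per stage, and widen_factor 2 gives stage widths [16, 32, 64, 128].

import torch
from datafree.models.classifiers.wresnet import wrn_40_2

model = wrn_40_2(num_classes=100).eval()
logits, feats, (x1, x2, x3) = model(torch.randn(2, 3, 32, 32), return_features=True)
print(x1.shape, x2.shape, x3.shape)   # 32@32x32, 64@16x16, 128@8x8 feature maps
print(feats.shape, logits.shape)      # torch.Size([2, 128]) torch.Size([2, 100])
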
/datafree/models/deeplab/__init__.py:
--------------------------------------------------------------------------------
1 | from .modeling import *
2 | from ._deeplab import convert_to_separable_conv
--------------------------------------------------------------------------------
/datafree/models/deeplab/_deeplab.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | from .utils import _SimpleSegmentationModel
6 |
7 |
8 | __all__ = ["DeepLabV3"]
9 |
10 |
11 | class DeepLabV3(_SimpleSegmentationModel):
12 | """
13 | Implements DeepLabV3 model from
14 | `"Rethinking Atrous Convolution for Semantic Image Segmentation"
15 | <https://arxiv.org/abs/1706.05587>`_.
16 |
17 | Arguments:
18 | backbone (nn.Module): the network used to compute the features for the model.
19 | The backbone should return an OrderedDict[Tensor], with the key being
20 | "out" for the last feature map used, and "aux" if an auxiliary classifier
21 | is used.
22 | classifier (nn.Module): module that takes the "out" element returned from
23 | the backbone and returns a dense prediction.
24 | aux_classifier (nn.Module, optional): auxiliary classifier used during training
25 | """
26 | pass
27 |
28 | class DeepLabHeadV3Plus(nn.Module):
29 | def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]):
30 | super(DeepLabHeadV3Plus, self).__init__()
31 | self.project = nn.Sequential(
32 | nn.Conv2d(low_level_channels, 48, 1, bias=False),
33 | nn.BatchNorm2d(48),
34 | nn.ReLU(inplace=True),
35 | )
36 |
37 | self.aspp = ASPP(in_channels, aspp_dilate)
38 |
39 | self.classifier = nn.Sequential(
40 | nn.Conv2d(304, 256, 3, padding=1, bias=False),
41 | nn.BatchNorm2d(256),
42 | nn.ReLU(inplace=True),
43 | nn.Conv2d(256, num_classes, 1)
44 | )
45 | self._init_weight()
46 |
47 | def forward(self, feature):
48 | low_level_feature = self.project( feature['low_level'] )
49 | output_feature = self.aspp(feature['out'])
50 | output_feature = F.interpolate(output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False)
51 | return self.classifier( torch.cat( [ low_level_feature, output_feature ], dim=1 ) )
52 |
53 | def _init_weight(self):
54 | for m in self.modules():
55 | if isinstance(m, nn.Conv2d):
56 | nn.init.kaiming_normal_(m.weight)
57 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
58 | nn.init.constant_(m.weight, 1)
59 | nn.init.constant_(m.bias, 0)
60 |
61 | class DeepLabHead(nn.Module):
62 | def __init__(self, in_channels, num_classes, aspp_dilate=[12, 24, 36]):
63 | super(DeepLabHead, self).__init__()
64 |
65 | self.classifier = nn.Sequential(
66 | ASPP(in_channels, aspp_dilate),
67 | nn.Conv2d(256, 256, 3, padding=1, bias=False),
68 | nn.BatchNorm2d(256),
69 | nn.ReLU(inplace=True),
70 | nn.Conv2d(256, num_classes, 1)
71 | )
72 | self._init_weight()
73 |
74 | def forward(self, feature):
75 | return self.classifier( feature['out'] )
76 |
77 | def _init_weight(self):
78 | for m in self.modules():
79 | if isinstance(m, nn.Conv2d):
80 | nn.init.kaiming_normal_(m.weight)
81 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
82 | nn.init.constant_(m.weight, 1)
83 | nn.init.constant_(m.bias, 0)
84 |
85 | class AtrousSeparableConvolution(nn.Module):
86 | """ Atrous Separable Convolution
87 | """
88 | def __init__(self, in_channels, out_channels, kernel_size,
89 | stride=1, padding=0, dilation=1, bias=True):
90 | super(AtrousSeparableConvolution, self).__init__()
91 | self.body = nn.Sequential(
92 | # Separable Conv
93 | nn.Conv2d( in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, groups=in_channels ),
94 | # PointWise Conv
95 | nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
96 | )
97 |
98 | self._init_weight()
99 |
100 | def forward(self, x):
101 | return self.body(x)
102 |
103 | def _init_weight(self):
104 | for m in self.modules():
105 | if isinstance(m, nn.Conv2d):
106 | nn.init.kaiming_normal_(m.weight)
107 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
108 | nn.init.constant_(m.weight, 1)
109 | nn.init.constant_(m.bias, 0)
110 |
111 | class ASPPConv(nn.Sequential):
112 | def __init__(self, in_channels, out_channels, dilation):
113 | modules = [
114 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
115 | nn.BatchNorm2d(out_channels),
116 | nn.ReLU(inplace=True)
117 | ]
118 | super(ASPPConv, self).__init__(*modules)
119 |
120 | class ASPPPooling(nn.Sequential):
121 | def __init__(self, in_channels, out_channels):
122 | super(ASPPPooling, self).__init__(
123 | nn.AdaptiveAvgPool2d(1),
124 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
125 | nn.BatchNorm2d(out_channels),
126 | nn.ReLU(inplace=True))
127 |
128 | def forward(self, x):
129 | size = x.shape[-2:]
130 | x = super(ASPPPooling, self).forward(x)
131 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
132 |
133 | class ASPP(nn.Module):
134 | def __init__(self, in_channels, atrous_rates):
135 | super(ASPP, self).__init__()
136 | out_channels = 256
137 | modules = []
138 | modules.append(nn.Sequential(
139 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
140 | nn.BatchNorm2d(out_channels),
141 | nn.ReLU(inplace=True)))
142 |
143 | rate1, rate2, rate3 = tuple(atrous_rates)
144 | modules.append(ASPPConv(in_channels, out_channels, rate1))
145 | modules.append(ASPPConv(in_channels, out_channels, rate2))
146 | modules.append(ASPPConv(in_channels, out_channels, rate3))
147 | modules.append(ASPPPooling(in_channels, out_channels))
148 |
149 | self.convs = nn.ModuleList(modules)
150 |
151 | self.project = nn.Sequential(
152 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
153 | nn.BatchNorm2d(out_channels),
154 | nn.ReLU(inplace=True),
155 | nn.Dropout(0.1),)
156 |
157 | def forward(self, x):
158 | res = []
159 | for conv in self.convs:
160 | res.append(conv(x))
161 | res = torch.cat(res, dim=1)
162 | return self.project(res)
163 |
164 |
165 |
166 | def convert_to_separable_conv(module):
167 | new_module = module
168 | if isinstance(module, nn.Conv2d) and module.kernel_size[0]>1:
169 | new_module = AtrousSeparableConvolution(module.in_channels,
170 | module.out_channels,
171 | module.kernel_size,
172 | module.stride,
173 | module.padding,
174 | module.dilation,
175 | module.bias is not None)  # nn.Conv2d expects a bool flag here, not the Parameter
176 | for name, child in module.named_children():
177 | new_module.add_module(name, convert_to_separable_conv(child))
178 | return new_module
--------------------------------------------------------------------------------
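A shape sketch for the ASPP module above: five parallel branches (one 1x1 conv, three dilated 3x3 convs, one image-pooling branch) each emit 256 channels, are concatenated to 1280 and projected back to 256 at the input resolution. Channel count and spatial size below are illustrative:

import torch
from datafree.models.deeplab._deeplab import ASPP

aspp = ASPP(in_channels=2048, atrous_rates=[12, 24, 36]).eval()
y = aspp(torch.randn(1, 2048, 33, 33))
print(y.shape)   # torch.Size([1, 256, 33, 33]): spatial size is preserved
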
/datafree/models/deeplab/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from . import resnet
2 | from . import mobilenetv2
3 |
--------------------------------------------------------------------------------
/datafree/models/deeplab/backbone/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.hub import load_state_dict_from_url  # torchvision.models.utils was removed in newer torchvision releases
3 | import torch.nn.functional as F
4 |
5 | __all__ = ['MobileNetV2', 'mobilenet_v2']
6 |
7 |
8 | model_urls = {
9 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
10 | }
11 |
12 |
13 | def _make_divisible(v, divisor, min_value=None):
14 | """
15 | This function is taken from the original tf repo.
16 | It ensures that all layers have a channel number that is divisible by 8
17 | It can be seen here:
18 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
19 | :param v:
20 | :param divisor:
21 | :param min_value:
22 | :return:
23 | """
24 | if min_value is None:
25 | min_value = divisor
26 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
27 | # Make sure that round down does not go down by more than 10%.
28 | if new_v < 0.9 * v:
29 | new_v += divisor
30 | return new_v
31 |
32 |
33 | class ConvBNReLU(nn.Sequential):
34 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, dilation=1, groups=1):
35 | #padding = (kernel_size - 1) // 2
36 | super(ConvBNReLU, self).__init__(
37 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, 0, dilation=dilation, groups=groups, bias=False),
38 | nn.BatchNorm2d(out_planes),
39 | nn.ReLU6(inplace=True)
40 | )
41 |
42 | def fixed_padding(kernel_size, dilation):
43 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
44 | pad_total = kernel_size_effective - 1
45 | pad_beg = pad_total // 2
46 | pad_end = pad_total - pad_beg
47 | return (pad_beg, pad_end, pad_beg, pad_end)
48 |
49 | class InvertedResidual(nn.Module):
50 | def __init__(self, inp, oup, stride, dilation, expand_ratio):
51 | super(InvertedResidual, self).__init__()
52 | self.stride = stride
53 | assert stride in [1, 2]
54 |
55 | hidden_dim = int(round(inp * expand_ratio))
56 | self.use_res_connect = self.stride == 1 and inp == oup
57 |
58 | layers = []
59 | if expand_ratio != 1:
60 | # pw
61 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
62 |
63 | layers.extend([
64 | # dw
65 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, dilation=dilation, groups=hidden_dim),
66 | # pw-linear
67 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
68 | nn.BatchNorm2d(oup),
69 | ])
70 | self.conv = nn.Sequential(*layers)
71 |
72 | self.input_padding = fixed_padding( 3, dilation )
73 |
74 | def forward(self, x):
75 | x_pad = F.pad(x, self.input_padding)
76 | if self.use_res_connect:
77 | return x + self.conv(x_pad)
78 | else:
79 | return self.conv(x_pad)
80 |
81 | class MobileNetV2(nn.Module):
82 | def __init__(self, num_classes=1000, output_stride=8, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
83 | """
84 | MobileNet V2 main class
85 |
86 | Args:
87 | num_classes (int): Number of classes
88 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
89 | inverted_residual_setting: Network structure
90 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number
91 | Set to 1 to turn off rounding
92 | """
93 | super(MobileNetV2, self).__init__()
94 | block = InvertedResidual
95 | input_channel = 32
96 | last_channel = 1280
97 | self.output_stride = output_stride
98 | current_stride = 1
99 | if inverted_residual_setting is None:
100 | inverted_residual_setting = [
101 | # t, c, n, s
102 | [1, 16, 1, 1],
103 | [6, 24, 2, 2],
104 | [6, 32, 3, 2],
105 | [6, 64, 4, 2],
106 | [6, 96, 3, 1],
107 | [6, 160, 3, 2],
108 | [6, 320, 1, 1],
109 | ]
110 |
111 | # only check the first element, assuming user knows t,c,n,s are required
112 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
113 | raise ValueError("inverted_residual_setting should be non-empty "
114 | "or a 4-element list, got {}".format(inverted_residual_setting))
115 |
116 | # building first layer
117 | input_channel = _make_divisible(input_channel * width_mult, round_nearest)
118 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
119 | features = [ConvBNReLU(3, input_channel, stride=2)]
120 | current_stride *= 2
121 | dilation=1
122 | previous_dilation = 1
123 |
124 | # building inverted residual blocks
125 | for t, c, n, s in inverted_residual_setting:
126 | output_channel = _make_divisible(c * width_mult, round_nearest)
127 | previous_dilation = dilation
128 | if current_stride == output_stride:
129 | stride = 1
130 | dilation *= s
131 | else:
132 | stride = s
133 | current_stride *= s
134 | output_channel = int(c * width_mult)
135 |
136 | for i in range(n):
137 | if i==0:
138 | features.append(block(input_channel, output_channel, stride, previous_dilation, expand_ratio=t))
139 | else:
140 | features.append(block(input_channel, output_channel, 1, dilation, expand_ratio=t))
141 | input_channel = output_channel
142 | # building last several layers
143 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
144 | # make it nn.Sequential
145 | self.features = nn.Sequential(*features)
146 |
147 | # building classifier
148 | self.classifier = nn.Sequential(
149 | nn.Dropout(0.2),
150 | nn.Linear(self.last_channel, num_classes),
151 | )
152 |
153 | # weight initialization
154 | for m in self.modules():
155 | if isinstance(m, nn.Conv2d):
156 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
157 | if m.bias is not None:
158 | nn.init.zeros_(m.bias)
159 | elif isinstance(m, nn.BatchNorm2d):
160 | nn.init.ones_(m.weight)
161 | nn.init.zeros_(m.bias)
162 | elif isinstance(m, nn.Linear):
163 | nn.init.normal_(m.weight, 0, 0.01)
164 | nn.init.zeros_(m.bias)
165 |
166 | def forward(self, x):
167 | x = self.features(x)
168 | x = x.mean([2, 3])
169 | x = self.classifier(x)
170 | return x
171 |
172 |
173 | def mobilenet_v2(pretrained=False, progress=True, **kwargs):
174 | """
175 | Constructs a MobileNetV2 architecture from
176 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_.
177 |
178 | Args:
179 | pretrained (bool): If True, returns a model pre-trained on ImageNet
180 | progress (bool): If True, displays a progress bar of the download to stderr
181 | """
182 | model = MobileNetV2(**kwargs)
183 | if pretrained:
184 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
185 | progress=progress)
186 | model.load_state_dict(state_dict)
187 | return model
188 |
--------------------------------------------------------------------------------
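A sketch of the two knobs that matter for DeepLab in the file above: `_make_divisible` rounds channel counts to multiples of 8, and `output_stride=8` turns the later stride-2 stages into dilated ones so the deepest features stay at 1/8 of the input resolution. Input size is illustrative:

import torch
from datafree.models.deeplab.backbone.mobilenetv2 import MobileNetV2, _make_divisible

print(_make_divisible(32 * 0.75, 8), _make_divisible(90, 8))   # 24 88

net = MobileNetV2(num_classes=1000, output_stride=8).eval()
feats = net.features(torch.randn(1, 3, 224, 224))
print(feats.shape)   # torch.Size([1, 1280, 28, 28]) -> 224 / 8 = 28
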
/datafree/models/deeplab/backbone/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.hub import load_state_dict_from_url  # torchvision.models.utils was removed in newer torchvision releases
4 |
5 |
6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
8 | 'wide_resnet50_2', 'wide_resnet101_2']
9 |
10 |
11 | model_urls = {
12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
17 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
18 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
19 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
20 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
21 | }
22 |
23 |
24 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
25 | """3x3 convolution with padding"""
26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
27 | padding=dilation, groups=groups, bias=False, dilation=dilation)
28 |
29 |
30 | def conv1x1(in_planes, out_planes, stride=1):
31 | """1x1 convolution"""
32 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
33 |
34 |
35 | class BasicBlock(nn.Module):
36 | expansion = 1
37 |
38 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
39 | base_width=64, dilation=1, norm_layer=None):
40 | super(BasicBlock, self).__init__()
41 | if norm_layer is None:
42 | norm_layer = nn.BatchNorm2d
43 | if groups != 1 or base_width != 64:
44 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
45 | if dilation > 1:
46 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
47 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
48 | self.conv1 = conv3x3(inplanes, planes, stride)
49 | self.bn1 = norm_layer(planes)
50 | self.relu = nn.ReLU(inplace=True)
51 | self.conv2 = conv3x3(planes, planes)
52 | self.bn2 = norm_layer(planes)
53 | self.downsample = downsample
54 | self.stride = stride
55 |
56 | def forward(self, x):
57 | identity = x
58 |
59 | out = self.conv1(x)
60 | out = self.bn1(out)
61 | out = self.relu(out)
62 |
63 | out = self.conv2(out)
64 | out = self.bn2(out)
65 |
66 | if self.downsample is not None:
67 | identity = self.downsample(x)
68 |
69 | out += identity
70 | out = self.relu(out)
71 |
72 | return out
73 |
74 |
75 | class Bottleneck(nn.Module):
76 | expansion = 4
77 |
78 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
79 | base_width=64, dilation=1, norm_layer=None):
80 | super(Bottleneck, self).__init__()
81 | if norm_layer is None:
82 | norm_layer = nn.BatchNorm2d
83 | width = int(planes * (base_width / 64.)) * groups
84 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
85 | self.conv1 = conv1x1(inplanes, width)
86 | self.bn1 = norm_layer(width)
87 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
88 | self.bn2 = norm_layer(width)
89 | self.conv3 = conv1x1(width, planes * self.expansion)
90 | self.bn3 = norm_layer(planes * self.expansion)
91 | self.relu = nn.ReLU(inplace=True)
92 | self.downsample = downsample
93 | self.stride = stride
94 |
95 | def forward(self, x):
96 | identity = x
97 |
98 | out = self.conv1(x)
99 | out = self.bn1(out)
100 | out = self.relu(out)
101 |
102 | out = self.conv2(out)
103 | out = self.bn2(out)
104 | out = self.relu(out)
105 |
106 | out = self.conv3(out)
107 | out = self.bn3(out)
108 |
109 | if self.downsample is not None:
110 | identity = self.downsample(x)
111 |
112 | out += identity
113 | out = self.relu(out)
114 |
115 | return out
116 |
117 |
118 | class ResNet(nn.Module):
119 |
120 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
121 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
122 | norm_layer=None):
123 | super(ResNet, self).__init__()
124 | if norm_layer is None:
125 | norm_layer = nn.BatchNorm2d
126 | self._norm_layer = norm_layer
127 |
128 | self.inplanes = 64
129 | self.dilation = 1
130 | if replace_stride_with_dilation is None:
131 | # each element in the tuple indicates if we should replace
132 | # the 2x2 stride with a dilated convolution instead
133 | replace_stride_with_dilation = [False, False, False]
134 | if len(replace_stride_with_dilation) != 3:
135 | raise ValueError("replace_stride_with_dilation should be None "
136 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
137 | self.groups = groups
138 | self.base_width = width_per_group
139 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
140 | bias=False)
141 | self.bn1 = norm_layer(self.inplanes)
142 | self.relu = nn.ReLU(inplace=True)
143 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
144 | self.layer1 = self._make_layer(block, 64, layers[0])
145 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
146 | dilate=replace_stride_with_dilation[0])
147 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
148 | dilate=replace_stride_with_dilation[1])
149 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
150 | dilate=replace_stride_with_dilation[2])
151 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
152 | self.fc = nn.Linear(512 * block.expansion, num_classes)
153 |
154 | for m in self.modules():
155 | if isinstance(m, nn.Conv2d):
156 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
157 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
158 | nn.init.constant_(m.weight, 1)
159 | nn.init.constant_(m.bias, 0)
160 |
161 | # Zero-initialize the last BN in each residual branch,
162 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
163 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
164 | if zero_init_residual:
165 | for m in self.modules():
166 | if isinstance(m, Bottleneck):
167 | nn.init.constant_(m.bn3.weight, 0)
168 | elif isinstance(m, BasicBlock):
169 | nn.init.constant_(m.bn2.weight, 0)
170 |
171 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
172 | norm_layer = self._norm_layer
173 | downsample = None
174 | previous_dilation = self.dilation
175 | if dilate:
176 | self.dilation *= stride
177 | stride = 1
178 | if stride != 1 or self.inplanes != planes * block.expansion:
179 | downsample = nn.Sequential(
180 | conv1x1(self.inplanes, planes * block.expansion, stride),
181 | norm_layer(planes * block.expansion),
182 | )
183 |
184 | layers = []
185 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
186 | self.base_width, previous_dilation, norm_layer))
187 | self.inplanes = planes * block.expansion
188 | for _ in range(1, blocks):
189 | layers.append(block(self.inplanes, planes, groups=self.groups,
190 | base_width=self.base_width, dilation=self.dilation,
191 | norm_layer=norm_layer))
192 |
193 | return nn.Sequential(*layers)
194 |
195 | def forward(self, x):
196 | x = self.conv1(x)
197 | x = self.bn1(x)
198 | x = self.relu(x)
199 | x = self.maxpool(x)
200 |
201 | x = self.layer1(x)
202 | x = self.layer2(x)
203 | x = self.layer3(x)
204 | x = self.layer4(x)
205 |
206 | x = self.avgpool(x)
207 | x = torch.flatten(x, 1)
208 | x = self.fc(x)
209 |
210 | return x
211 |
212 |
213 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
214 | model = ResNet(block, layers, **kwargs)
215 | if pretrained:
216 | state_dict = load_state_dict_from_url(model_urls[arch],
217 | progress=progress)
218 | model.load_state_dict(state_dict)
219 | return model
220 |
221 |
222 | def resnet18(pretrained=False, progress=True, **kwargs):
223 | r"""ResNet-18 model from
224 | `"Deep Residual Learning for Image Recognition" `_
225 |
226 | Args:
227 | pretrained (bool): If True, returns a model pre-trained on ImageNet
228 | progress (bool): If True, displays a progress bar of the download to stderr
229 | """
230 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
231 | **kwargs)
232 |
233 |
234 | def resnet34(pretrained=False, progress=True, **kwargs):
235 | r"""ResNet-34 model from
236 | `"Deep Residual Learning for Image Recognition" `_
237 |
238 | Args:
239 | pretrained (bool): If True, returns a model pre-trained on ImageNet
240 | progress (bool): If True, displays a progress bar of the download to stderr
241 | """
242 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
243 | **kwargs)
244 |
245 |
246 | def resnet50(pretrained=False, progress=True, **kwargs):
247 | r"""ResNet-50 model from
248 | `"Deep Residual Learning for Image Recognition" `_
249 |
250 | Args:
251 | pretrained (bool): If True, returns a model pre-trained on ImageNet
252 | progress (bool): If True, displays a progress bar of the download to stderr
253 | """
254 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
255 | **kwargs)
256 |
257 |
258 | def resnet101(pretrained=False, progress=True, **kwargs):
259 | r"""ResNet-101 model from
260 | `"Deep Residual Learning for Image Recognition" `_
261 |
262 | Args:
263 | pretrained (bool): If True, returns a model pre-trained on ImageNet
264 | progress (bool): If True, displays a progress bar of the download to stderr
265 | """
266 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
267 | **kwargs)
268 |
269 |
270 | def resnet152(pretrained=False, progress=True, **kwargs):
271 | r"""ResNet-152 model from
272 | `"Deep Residual Learning for Image Recognition" `_
273 |
274 | Args:
275 | pretrained (bool): If True, returns a model pre-trained on ImageNet
276 | progress (bool): If True, displays a progress bar of the download to stderr
277 | """
278 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
279 | **kwargs)
280 |
281 |
282 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
283 | r"""ResNeXt-50 32x4d model from
284 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
285 |
286 | Args:
287 | pretrained (bool): If True, returns a model pre-trained on ImageNet
288 | progress (bool): If True, displays a progress bar of the download to stderr
289 | """
290 | kwargs['groups'] = 32
291 | kwargs['width_per_group'] = 4
292 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
293 | pretrained, progress, **kwargs)
294 |
295 |
296 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
297 | r"""ResNeXt-101 32x8d model from
298 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
299 |
300 | Args:
301 | pretrained (bool): If True, returns a model pre-trained on ImageNet
302 | progress (bool): If True, displays a progress bar of the download to stderr
303 | """
304 | kwargs['groups'] = 32
305 | kwargs['width_per_group'] = 8
306 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
307 | pretrained, progress, **kwargs)
308 |
309 |
310 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
311 | r"""Wide ResNet-50-2 model from
312 | `"Wide Residual Networks" `_
313 |
314 | The model is the same as ResNet except for the bottleneck number of channels
315 | which is twice larger in every block. The number of channels in outer 1x1
316 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
317 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
318 |
319 | Args:
320 | pretrained (bool): If True, returns a model pre-trained on ImageNet
321 | progress (bool): If True, displays a progress bar of the download to stderr
322 | """
323 | kwargs['width_per_group'] = 64 * 2
324 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
325 | pretrained, progress, **kwargs)
326 |
327 |
328 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
329 | r"""Wide ResNet-101-2 model from
330 | `"Wide Residual Networks" `_
331 |
332 | The model is the same as ResNet except for the bottleneck number of channels
333 | which is twice larger in every block. The number of channels in outer 1x1
334 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
335 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
336 |
337 | Args:
338 | pretrained (bool): If True, returns a model pre-trained on ImageNet
339 | progress (bool): If True, displays a progress bar of the download to stderr
340 | """
341 | kwargs['width_per_group'] = 64 * 2
342 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
343 | pretrained, progress, **kwargs)
344 |
--------------------------------------------------------------------------------
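The `replace_stride_with_dilation` flags are what the DeepLab builders in modeling.py (next file) rely on: stride-2 stages become dilated, so the layer4 features keep 1/8 of the input resolution. A minimal sketch, with an illustrative input size:

import torch
from datafree.models.deeplab.backbone.resnet import resnet50

backbone = resnet50(pretrained=False,
                    replace_stride_with_dilation=[False, True, True]).eval()
x = torch.randn(1, 3, 224, 224)
x = backbone.maxpool(backbone.relu(backbone.bn1(backbone.conv1(x))))   # 1/4 resolution
x = backbone.layer4(backbone.layer3(backbone.layer2(backbone.layer1(x))))
print(x.shape)   # torch.Size([1, 2048, 28, 28]) -> 224 / 8 = 28
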
/datafree/models/deeplab/modeling.py:
--------------------------------------------------------------------------------
1 | from .utils import IntermediateLayerGetter
2 | from ._deeplab import DeepLabHead, DeepLabHeadV3Plus, DeepLabV3
3 | from .backbone import resnet
4 | from .backbone import mobilenetv2
5 |
6 | def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_backbone):
7 |
8 | if output_stride==8:
9 | replace_stride_with_dilation=[False, True, True]
10 | aspp_dilate = [12, 24, 36]
11 | else:
12 | replace_stride_with_dilation=[False, False, True]
13 | aspp_dilate = [6, 12, 18]
14 |
15 | backbone = resnet.__dict__[backbone_name](
16 | pretrained=pretrained_backbone,
17 | replace_stride_with_dilation=replace_stride_with_dilation)
18 |
19 | inplanes = 2048
20 | low_level_planes = 256
21 |
22 | if name=='deeplabv3plus':
23 | return_layers = {'layer4': 'out', 'layer1': 'low_level'}
24 | classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
25 | elif name=='deeplabv3':
26 | return_layers = {'layer4': 'out'}
27 | classifier = DeepLabHead(inplanes , num_classes, aspp_dilate)
28 | backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
29 |
30 | model = DeepLabV3(backbone, classifier)
31 | return model
32 |
33 | def _segm_mobilenet(name, backbone_name, num_classes, output_stride, pretrained_backbone):
34 | if output_stride==8:
35 | aspp_dilate = [12, 24, 36]
36 | else:
37 | aspp_dilate = [6, 12, 18]
38 |
39 | backbone = mobilenetv2.mobilenet_v2(pretrained=pretrained_backbone, output_stride=output_stride)
40 |
41 | # rename layers
42 | backbone.low_level_features = backbone.features[0:4]
43 | backbone.high_level_features = backbone.features[4:-1]
44 | backbone.features = None
45 | backbone.classifier = None
46 |
47 | inplanes = 320
48 | low_level_planes = 24
49 |
50 | if name=='deeplabv3plus':
51 | return_layers = {'high_level_features': 'out', 'low_level_features': 'low_level'}
52 | classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
53 | elif name=='deeplabv3':
54 | return_layers = {'high_level_features': 'out'}
55 | classifier = DeepLabHead(inplanes , num_classes, aspp_dilate)
56 | backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
57 |
58 | model = DeepLabV3(backbone, classifier)
59 | return model
60 |
61 | def _load_model(arch_type, backbone, num_classes, output_stride, pretrained_backbone):
62 |
63 | if backbone=='mobilenetv2':
64 | model = _segm_mobilenet(arch_type, backbone, num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
65 | elif backbone.startswith('resnet'):
66 | model = _segm_resnet(arch_type, backbone, num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
67 | else:
68 | raise NotImplementedError
69 | return model
70 |
71 |
72 | # Deeplab v3
73 |
74 | def deeplabv3_resnet50(num_classes=21, output_stride=8, pretrained_backbone=True):
75 | """Constructs a DeepLabV3 model with a ResNet-50 backbone.
76 |
77 | Args:
78 | num_classes (int): number of classes.
79 | output_stride (int): output stride for deeplab.
80 | pretrained_backbone (bool): If True, use the pretrained backbone.
81 | """
82 | return _load_model('deeplabv3', 'resnet50', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
83 |
84 | def deeplabv3_resnet101(num_classes=21, output_stride=8, pretrained_backbone=True):
85 | """Constructs a DeepLabV3 model with a ResNet-101 backbone.
86 |
87 | Args:
88 | num_classes (int): number of classes.
89 | output_stride (int): output stride for deeplab.
90 | pretrained_backbone (bool): If True, use the pretrained backbone.
91 | """
92 | return _load_model('deeplabv3', 'resnet101', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
93 |
94 | def deeplabv3_mobilenet(num_classes=21, output_stride=8, pretrained_backbone=True, **kwargs):
95 | """Constructs a DeepLabV3 model with a MobileNetv2 backbone.
96 |
97 | Args:
98 | num_classes (int): number of classes.
99 | output_stride (int): output stride for deeplab.
100 | pretrained_backbone (bool): If True, use the pretrained backbone.
101 | """
102 | return _load_model('deeplabv3', 'mobilenetv2', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
103 |
104 |
105 | # Deeplab v3+
106 |
107 | def deeplabv3plus_resnet50(num_classes=21, output_stride=8, pretrained_backbone=True):
108 | """Constructs a DeepLabV3 model with a ResNet-50 backbone.
109 |
110 | Args:
111 | num_classes (int): number of classes.
112 | output_stride (int): output stride for deeplab.
113 | pretrained_backbone (bool): If True, use the pretrained backbone.
114 | """
115 | return _load_model('deeplabv3plus', 'resnet50', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
116 |
117 |
118 | def deeplabv3plus_resnet101(num_classes=21, output_stride=8, pretrained_backbone=True):
119 | """Constructs a DeepLabV3+ model with a ResNet-101 backbone.
120 |
121 | Args:
122 | num_classes (int): number of classes.
123 | output_stride (int): output stride for deeplab.
124 | pretrained_backbone (bool): If True, use the pretrained backbone.
125 | """
126 | return _load_model('deeplabv3plus', 'resnet101', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
127 |
128 |
129 | def deeplabv3plus_mobilenet(num_classes=21, output_stride=8, pretrained_backbone=True):
130 | """Constructs a DeepLabV3+ model with a MobileNetv2 backbone.
131 |
132 | Args:
133 | num_classes (int): number of classes.
134 | output_stride (int): output stride for deeplab.
135 | pretrained_backbone (bool): If True, use the pretrained backbone.
136 | """
137 | return _load_model('deeplabv3plus', 'mobilenetv2', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
--------------------------------------------------------------------------------
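A minimal end-to-end sketch for the factory functions above (class count and input size are illustrative): the returned model maps an image batch to per-pixel class logits at the input resolution.

import torch
from datafree.models.deeplab.modeling import deeplabv3plus_mobilenet

model = deeplabv3plus_mobilenet(num_classes=21, output_stride=8,
                                pretrained_backbone=False).eval()
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)   # torch.Size([1, 21, 224, 224])
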
/datafree/models/deeplab/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from collections import OrderedDict
6 |
7 | class _SimpleSegmentationModel(nn.Module):
8 | def __init__(self, backbone, classifier):
9 | super(_SimpleSegmentationModel, self).__init__()
10 | self.backbone = backbone
11 | self.classifier = classifier
12 |
13 | def forward(self, x):
14 | input_shape = x.shape[-2:]
15 | features = self.backbone(x)
16 | x = self.classifier(features)
17 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
18 | return x
19 |
20 |
21 | class IntermediateLayerGetter(nn.ModuleDict):
22 | """
23 | Module wrapper that returns intermediate layers from a model
24 |
25 | It has a strong assumption that the modules have been registered
26 | into the model in the same order as they are used.
27 | This means that one should **not** reuse the same nn.Module
28 | twice in the forward if you want this to work.
29 |
30 | Additionally, it is only able to query submodules that are directly
31 | assigned to the model. So if `model` is passed, `model.feature1` can
32 | be returned, but not `model.feature1.layer2`.
33 |
34 | Arguments:
35 | model (nn.Module): model on which we will extract the features
36 | return_layers (Dict[name, new_name]): a dict containing the names
37 | of the modules for which the activations will be returned as
38 | the key of the dict, and the value of the dict is the name
39 | of the returned activation (which the user can specify).
40 |
41 | Examples::
42 |
43 | >>> m = torchvision.models.resnet18(pretrained=True)
44 |         >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
45 | >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
46 | >>> {'layer1': 'feat1', 'layer3': 'feat2'})
47 | >>> out = new_m(torch.rand(1, 3, 224, 224))
48 | >>> print([(k, v.shape) for k, v in out.items()])
49 | >>> [('feat1', torch.Size([1, 64, 56, 56])),
50 | >>> ('feat2', torch.Size([1, 256, 14, 14]))]
51 | """
52 | def __init__(self, model, return_layers):
53 | if not set(return_layers).issubset([name for name, _ in model.named_children()]):
54 | raise ValueError("return_layers are not present in model")
55 |
56 | orig_return_layers = return_layers
57 | return_layers = {k: v for k, v in return_layers.items()}
58 | layers = OrderedDict()
59 | for name, module in model.named_children():
60 | layers[name] = module
61 | if name in return_layers:
62 | del return_layers[name]
63 | if not return_layers:
64 | break
65 |
66 | super(IntermediateLayerGetter, self).__init__(layers)
67 | self.return_layers = orig_return_layers
68 |
69 | def forward(self, x):
70 | out = OrderedDict()
71 | for name, module in self.named_children():
72 | x = module(x)
73 | if name in self.return_layers:
74 | out_name = self.return_layers[name]
75 | out[out_name] = x
76 | return out
77 |
--------------------------------------------------------------------------------
/datafree/models/generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class Flatten(nn.Module):
6 | def __init__(self):
7 | super(Flatten, self).__init__()
8 |
9 | def forward(self, x):
10 | return torch.flatten(x, 1)
11 |
12 | class Generator(nn.Module):
13 | def __init__(self, nz=100, ngf=64, img_size=32, nc=3):
14 | super(Generator, self).__init__()
15 |
16 | self.init_size = img_size // 4
17 | self.l1 = nn.Sequential(nn.Linear(nz, ngf * 2 * self.init_size ** 2))
18 |
19 | self.conv_blocks = nn.Sequential(
20 | nn.BatchNorm2d(ngf * 2),
21 | nn.Upsample(scale_factor=2),
22 |
23 | nn.Conv2d(ngf*2, ngf*2, 3, stride=1, padding=1, bias=False),
24 | nn.BatchNorm2d(ngf*2),
25 | nn.LeakyReLU(0.2, inplace=True),
26 | nn.Upsample(scale_factor=2),
27 |
28 | nn.Conv2d(ngf*2, ngf, 3, stride=1, padding=1, bias=False),
29 | nn.BatchNorm2d(ngf),
30 | nn.LeakyReLU(0.2, inplace=True),
31 | nn.Conv2d(ngf, nc, 3, stride=1, padding=1),
32 | nn.Sigmoid(),
33 | )
34 |
35 | def forward(self, z):
36 | out = self.l1(z)
37 | out = out.view(out.shape[0], -1, self.init_size, self.init_size)
38 | img = self.conv_blocks(out)
39 | return img
40 |
41 |
42 | class LargeGenerator(nn.Module):
43 | def __init__(self, nz=100, ngf=64, img_size=32, nc=3):
44 | super(LargeGenerator, self).__init__()
45 |
46 | self.init_size = img_size // 4
47 | self.l1 = nn.Sequential(nn.Linear(nz, ngf * 4 * self.init_size ** 2))
48 |
49 | self.conv_blocks = nn.Sequential(
50 | nn.BatchNorm2d(ngf * 4),
51 | nn.Upsample(scale_factor=2),
52 |
53 | nn.Conv2d(ngf*4, ngf*2, 3, stride=1, padding=1, bias=False),
54 | nn.BatchNorm2d(ngf*2),
55 | nn.LeakyReLU(0.2, inplace=True),
56 | nn.Upsample(scale_factor=2),
57 |
58 | nn.Conv2d(ngf*2, ngf, 3, stride=1, padding=1, bias=False),
59 | nn.BatchNorm2d(ngf),
60 | nn.LeakyReLU(0.2, inplace=True),
61 | nn.Conv2d(ngf, nc, 3, stride=1, padding=1),
62 | nn.Sigmoid(),
63 | )
64 |
65 | def forward(self, z):
66 | out = self.l1(z)
67 | out = out.view(out.shape[0], -1, self.init_size, self.init_size)
68 | img = self.conv_blocks(out)
69 | return img
70 |
71 |
72 | class DCGAN_Generator(nn.Module):
73 | """ Generator from DCGAN: https://arxiv.org/abs/1511.06434
74 | """
75 | def __init__(self, nz=100, ngf=64, nc=3, img_size=64, slope=0.2):
76 | super(DCGAN_Generator, self).__init__()
77 | self.nz = nz
78 | if isinstance(img_size, (list, tuple)):
79 | self.init_size = ( img_size[0]//16, img_size[1]//16 )
80 | else:
81 | self.init_size = ( img_size // 16, img_size // 16)
82 |
83 | self.project = nn.Sequential(
84 | Flatten(),
85 | nn.Linear(nz, ngf*8*self.init_size[0]*self.init_size[1]),
86 | )
87 |
88 | self.main = nn.Sequential(
89 | nn.BatchNorm2d(ngf*8),
90 |
91 | nn.ConvTranspose2d(ngf*8, ngf*4, 4, 2, 1, bias=False),
92 | nn.BatchNorm2d(ngf*4),
93 | nn.LeakyReLU(slope, inplace=True),
94 | # 2x
95 |
96 | nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False),
97 | nn.BatchNorm2d(ngf*2),
98 | nn.LeakyReLU(slope, inplace=True),
99 | # 4x
100 |
101 | nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False),
102 | nn.BatchNorm2d(ngf),
103 | nn.LeakyReLU(slope, inplace=True),
104 | # 8x
105 |
106 | nn.ConvTranspose2d(ngf, ngf, 4, 2, 1, bias=False),
107 | nn.BatchNorm2d(ngf),
108 | nn.LeakyReLU(slope, inplace=True),
109 | # 16x
110 |
111 | nn.Conv2d(ngf, nc, 3, 1,1),
112 | nn.Sigmoid(),
113 | #nn.Sigmoid()
114 | )
115 |
116 | def forward(self, z):
117 | proj = self.project(z)
118 | proj = proj.view(proj.shape[0], -1, self.init_size[0], self.init_size[1])
119 | output = self.main(proj)
120 | return output
121 |
122 | class DCGAN_CondGenerator(nn.Module):
123 | """ Generator from DCGAN: https://arxiv.org/abs/1511.06434
124 | """
125 | def __init__(self, num_classes, nz=100, n_emb=50, ngf=64, nc=3, img_size=64, slope=0.2):
126 | super(DCGAN_CondGenerator, self).__init__()
127 | self.nz = nz
128 | self.emb = nn.Embedding(num_classes, n_emb)
129 | if isinstance(img_size, (list, tuple)):
130 | self.init_size = ( img_size[0]//16, img_size[1]//16 )
131 | else:
132 | self.init_size = ( img_size // 16, img_size // 16)
133 |
134 | self.project = nn.Sequential(
135 | Flatten(),
136 | nn.Linear(nz+n_emb, ngf*8*self.init_size[0]*self.init_size[1]),
137 | )
138 |
139 | self.main = nn.Sequential(
140 | nn.BatchNorm2d(ngf*8),
141 |
142 | nn.ConvTranspose2d(ngf*8, ngf*4, 4, 2, 1, bias=False),
143 | nn.BatchNorm2d(ngf*4),
144 | nn.LeakyReLU(slope, inplace=True),
145 | # 2x
146 |
147 | nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False),
148 | nn.BatchNorm2d(ngf*2),
149 | nn.LeakyReLU(slope, inplace=True),
150 | # 4x
151 |
152 | nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False),
153 | nn.BatchNorm2d(ngf),
154 | nn.LeakyReLU(slope, inplace=True),
155 | # 8x
156 |
157 | nn.ConvTranspose2d(ngf, ngf, 4, 2, 1, bias=False),
158 | nn.BatchNorm2d(ngf),
159 | nn.LeakyReLU(slope, inplace=True),
160 | # 16x
161 |
162 | nn.Conv2d(ngf, nc, 3, 1,1),
163 | #nn.Tanh(),
164 | nn.Sigmoid()
165 | )
166 |
167 | def forward(self, z, y):
168 | y = self.emb(y)
169 | z = torch.cat([z, y], dim=1)
170 | proj = self.project(z)
171 | proj = proj.view(proj.shape[0], -1, self.init_size[0], self.init_size[1])
172 | output = self.main(proj)
173 | return output
174 |
175 | class Discriminator(nn.Module):
176 | def __init__(self, nc=3, img_size=32):
177 | super(Discriminator, self).__init__()
178 |
179 | def discriminator_block(in_filters, out_filters, bn=True):
180 | block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
181 | if bn:
182 | block.append(nn.BatchNorm2d(out_filters, 0.8))
183 | return block
184 |
185 | self.model = nn.Sequential(
186 | *discriminator_block(nc, 16, bn=False),
187 | *discriminator_block(16, 32),
188 | *discriminator_block(32, 64),
189 | *discriminator_block(64, 128),
190 | )
191 |
192 | # The height and width of downsampled image
193 | ds_size = img_size // 2 ** 4
194 | self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())
195 |
196 | def forward(self, img):
197 | out = self.model(img)
198 | out = out.view(out.shape[0], -1)
199 | validity = self.adv_layer(out)
200 | return validity
201 |
202 | class DCGAN_Discriminator(nn.Module):
203 | def __init__(self, nc=3, ndf=64):
204 | super(DCGAN_Discriminator, self).__init__()
205 | self.main = nn.Sequential(
206 | # input is (nc) x 64 x 64
207 | nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
208 | nn.LeakyReLU(0.2, inplace=True),
209 | # state size. (ndf) x 32 x 32
210 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
211 | nn.BatchNorm2d(ndf * 2),
212 | nn.LeakyReLU(0.2, inplace=True),
213 | # state size. (ndf*2) x 16 x 16
214 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
215 | nn.BatchNorm2d(ndf * 4),
216 | nn.LeakyReLU(0.2, inplace=True),
217 | # state size. (ndf*4) x 8 x 8
218 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
219 | nn.BatchNorm2d(ndf * 8),
220 | nn.LeakyReLU(0.2, inplace=True),
221 | # state size. (ndf*8) x 4 x 4
222 | nn.Conv2d(ndf * 8, 1, 2, 1, 0, bias=False),
223 | nn.Sigmoid()
224 | )
225 |
226 | def forward(self, input):
227 | return self.main(input)
--------------------------------------------------------------------------------
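
A quick shape check for the generators above (an illustrative sketch assuming the repository root is on PYTHONPATH). `Generator` upsamples twice (img_size // 4 -> img_size) while `DCGAN_Generator` upsamples four times (img_size // 16 -> img_size); both end with a Sigmoid, so outputs lie in [0, 1].

    import torch
    from datafree.models.generator import Generator, DCGAN_Generator

    z = torch.randn(8, 100)                                # latent batch, nz=100 (the default)
    print(Generator(nz=100, img_size=32)(z).shape)         # torch.Size([8, 3, 32, 32])
    print(DCGAN_Generator(nz=100, img_size=64)(z).shape)   # torch.Size([8, 3, 64, 64])
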
/datafree/rep_transfer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | class HintLoss(nn.Module):
7 | """Convolutional regression for FitNet"""
8 | def __init__(self, s_shapes, t_shapes, use_relu=False, loss_fn=F.mse_loss):
9 | super(HintLoss, self).__init__()
10 | self.use_relu = use_relu
11 | self.loss_fn = loss_fn
12 | regs = []
13 | for s_shape, t_shape in zip(s_shapes, t_shapes):
14 | s_N, s_C, s_H, s_W = s_shape
15 | t_N, t_C, t_H, t_W = t_shape
16 | if s_H == 2 * t_H:
17 | conv = nn.Conv2d(s_C, t_C, kernel_size=3, stride=2, padding=1)
18 | elif s_H * 2 == t_H:
19 | conv = nn.ConvTranspose2d(s_C, t_C, kernel_size=4, stride=2, padding=1)
20 | elif s_H >= t_H:
21 | conv = nn.Conv2d(s_C, t_C, kernel_size=(1+s_H-t_H, 1+s_W-t_W))
22 | else:
23 |                 raise NotImplementedError('student size {}, teacher size {}'.format(s_H, t_H))
24 | reg = [conv, nn.BatchNorm2d(t_C)]
25 | if use_relu:
26 | reg.append( nn.ReLU(inplace=True) )
27 | regs.append(nn.Sequential(*reg))
28 | self.regs = nn.ModuleList(regs)
29 |
30 | def forward(self, s_features, t_features):
31 | loss = []
32 | for reg, s_feat, t_feat in zip(self.regs, s_features, t_features):
33 | s_feat = reg(s_feat)
34 | loss.append( self.loss_fn( s_feat, t_feat ) )
35 | return loss
36 |
37 |
38 | class ABLoss(nn.Module):
39 | """Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons
40 | code: https://github.com/bhheo/AB_distillation
41 | """
42 | def __init__(self, s_shapes, t_shapes, margin=1.0, use_relu=False):
43 | super(ABLoss, self).__init__()
44 |
45 | regs = []
46 | for s_shape, t_shape in zip(s_shapes, t_shapes):
47 | s_N, s_C, s_H, s_W = s_shape
48 | t_N, t_C, t_H, t_W = t_shape
49 | if s_H == 2 * t_H:
50 | conv = nn.Conv2d(s_C, t_C, kernel_size=3, stride=2, padding=1)
51 | elif s_H * 2 == t_H:
52 | conv = nn.ConvTranspose2d(s_C, t_C, kernel_size=4, stride=2, padding=1)
53 | elif s_H >= t_H:
54 | conv = nn.Conv2d(s_C, t_C, kernel_size=(1+s_H-t_H, 1+s_W-t_W))
55 | else:
56 |                 raise NotImplementedError('student size {}, teacher size {}'.format(s_H, t_H))
57 | reg = [conv, nn.BatchNorm2d(t_C)]
58 | if use_relu:
59 | reg.append( nn.ReLU(inplace=True) )
60 | regs.append(nn.Sequential(*reg))
61 | self.regs = nn.ModuleList(regs)
62 | feat_num = len(self.regs)
63 | self.w = [2**(i-feat_num+1) for i in range(feat_num)]
64 | self.margin = margin
65 |
66 | def forward(self, s_features, t_features, reverse=False):
67 | s_features = [ reg(s_feat) for (reg, s_feat) in zip(self.regs, s_features) ]
68 | bsz = s_features[0].shape[0]
69 | losses = [self.criterion_alternative_l2(s, t, reverse=reverse) for s, t in zip(s_features, t_features)]
70 | losses = [w * l for w, l in zip(self.w, losses)]
71 | losses = [l / bsz for l in losses]
72 | losses = [l / 1000 * 3 for l in losses]
73 | return losses
74 |
75 | def criterion_alternative_l2(self, source, target, reverse):
76 | if reverse:
77 | loss = ((source - self.margin) ** 2 * ((source < self.margin) & (target <= 0)).float() +
78 | (source + self.margin) ** 2 * ((source > -self.margin) & (target > 0)).float() +
79 | (target - self.margin) ** 2 * ((target < self.margin) & (source <= 0)).float() +
80 | (target + self.margin) ** 2 * ((target > -self.margin) & (source > 0)).float())
81 | else:
82 | loss = ((source + self.margin) ** 2 * ((source > -self.margin) & (target <= 0)).float() +
83 | (source - self.margin) ** 2 * ((source <= self.margin) & (target > 0)).float())
84 | return torch.abs(loss).sum()
85 |
86 |
87 | class RKDLoss(nn.Module):
88 |     """Relational Knowledge Distillation, CVPR2019"""
89 | def __init__(self, w_d=25, w_a=50, angle=True):
90 | super(RKDLoss, self).__init__()
91 | self.w_d = w_d
92 | self.w_a = w_a
93 | self.angle = angle
94 |
95 | def forward(self, s_features, t_features):
96 | losses = []
97 | for f_s, f_t in zip(s_features, t_features):
98 | student = f_s.view(f_s.shape[0], -1)
99 | teacher = f_t.view(f_t.shape[0], -1)
100 |
101 | # RKD distance loss
102 | with torch.no_grad():
103 | t_d = self.pdist(teacher, squared=False)
104 | mean_td = t_d[t_d > 0].mean()
105 | t_d = t_d / mean_td
106 |
107 | d = self.pdist(student, squared=False)
108 | mean_d = d[d > 0].mean()
109 | d = d / mean_d
110 |
111 | loss_d = F.smooth_l1_loss(d, t_d)
112 |
113 | if self.angle:
114 | # RKD Angle loss
115 | with torch.no_grad():
116 | td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
117 | norm_td = F.normalize(td, p=2, dim=2)
118 | t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)
119 |
120 | sd = (student.unsqueeze(0) - student.unsqueeze(1))
121 | norm_sd = F.normalize(sd, p=2, dim=2)
122 | s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1)
123 |
124 | loss_a = F.smooth_l1_loss(s_angle, t_angle)
125 | else:
126 | loss_a = 0
127 | loss = self.w_d * loss_d + self.w_a * loss_a
128 | losses.append(loss)
129 | return losses
130 |
131 | @staticmethod
132 | def pdist(e, squared=False, eps=1e-12):
133 | e_square = e.pow(2).sum(dim=1)
134 | prod = e @ e.t()
135 | res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)
136 |
137 | if not squared:
138 | res = res.sqrt()
139 |
140 | res = res.clone()
141 | res[range(len(e)), range(len(e))] = 0
142 | return res
143 |
144 |
145 | class FSP(nn.Module):
146 | """A Gift from Knowledge Distillation:
147 | Fast Optimization, Network Minimization and Transfer Learning"""
148 | def __init__(self, s_shapes, t_shapes):
149 | super(FSP, self).__init__()
150 | assert len(s_shapes) == len(t_shapes), 'unequal length of feat list'
151 | s_c = [s[1] for s in s_shapes]
152 | t_c = [t[1] for t in t_shapes]
153 | if np.any(np.asarray(s_c) != np.asarray(t_c)):
154 | raise ValueError('num of channels not equal (error in FSP)')
155 |
156 | def forward(self, g_s, g_t):
157 | s_fsp = self.compute_fsp(g_s)
158 | t_fsp = self.compute_fsp(g_t)
159 | loss_group = [self.compute_loss(s, t) for s, t in zip(s_fsp, t_fsp)]
160 | return loss_group
161 |
162 | @staticmethod
163 | def compute_loss(s, t):
164 | return (s - t).pow(2).mean()
165 |
166 | @staticmethod
167 | def compute_fsp(g):
168 | fsp_list = []
169 | for i in range(len(g) - 1):
170 | bot, top = g[i], g[i + 1]
171 | b_H, t_H = bot.shape[2], top.shape[2]
172 | if b_H > t_H:
173 | bot = F.adaptive_avg_pool2d(bot, (t_H, t_H))
174 | elif b_H < t_H:
175 | top = F.adaptive_avg_pool2d(top, (b_H, b_H))
176 | else:
177 | pass
178 | bot = bot.unsqueeze(1)
179 | top = top.unsqueeze(2)
180 | bot = bot.view(bot.shape[0], bot.shape[1], bot.shape[2], -1)
181 | top = top.view(top.shape[0], top.shape[1], top.shape[2], -1)
182 |
183 | fsp = (bot * top).mean(-1)
184 | fsp_list.append(fsp)
185 | return fsp_list
--------------------------------------------------------------------------------
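
A small sketch of `RKDLoss` on dummy features (illustrative only). Because RKD compares pairwise distances and angles within a batch rather than raw activations, the student and teacher feature maps do not need matching channel counts.

    import torch
    from datafree.rep_transfer import RKDLoss

    s_feats = [torch.randn(16, 64, 8, 8)]    # one student feature map
    t_feats = [torch.randn(16, 128, 8, 8)]   # the corresponding teacher feature map
    criterion = RKDLoss(w_d=25, w_a=50)
    losses = criterion(s_feats, t_feats)     # one scalar loss per feature pair
    print(losses[0].item())
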
/datafree/synthesis/__init__.py:
--------------------------------------------------------------------------------
1 | from .triplet import AdvTripletSynthesizer
2 | from .contrastive import CMISynthesizer
3 | from .base import BaseSynthesis
--------------------------------------------------------------------------------
/datafree/synthesis/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from abc import ABC, abstractmethod
4 | from typing import Dict
5 |
6 | class BaseSynthesis(ABC):
7 | def __init__(self, teacher, student):
8 | super(BaseSynthesis, self).__init__()
9 | self.teacher = teacher
10 | self.student = student
11 |
12 |     @abstractmethod
13 | def synthesize(self) -> Dict[str, torch.Tensor]:
14 | """ take several steps to synthesize new images and return an image dict for visualization.
15 | Returned images should be normalized to [0, 1].
16 | """
17 | pass
18 |
19 |     @abstractmethod
20 | def sample(self, n):
21 | """ fetch a batch of training data.
22 | """
23 | pass
--------------------------------------------------------------------------------
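
A toy subclass showing the contract defined above (a sketch, not part of the repository): `synthesize()` returns a dict of images normalized to [0, 1] for visualization, and `sample()` returns a batch for student training.

    import torch
    from datafree.synthesis import BaseSynthesis

    class RandomNoiseSynthesizer(BaseSynthesis):
        def __init__(self, teacher, student, img_size=(3, 32, 32), batch_size=64):
            super().__init__(teacher, student)
            self.img_size, self.batch_size = img_size, batch_size
            self._last = None

        def synthesize(self):
            self._last = torch.rand(self.batch_size, *self.img_size)  # already in [0, 1]
            return {"synthetic": self._last}

        def sample(self, n=None):
            n = n if n is not None else self.batch_size
            return self._last[:n]
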
/datafree/synthesis/contrastive.py:
--------------------------------------------------------------------------------
1 | import datafree
2 | from typing import Generator
3 | import torch
4 | from torch import optim
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import random
8 |
9 | from .base import BaseSynthesis
10 | from datafree.hooks import DeepInversionHook, InstanceMeanHook
11 | from datafree.criterions import jsdiv, get_image_prior_losses, kldiv
12 | from datafree.utils import ImagePool, DataIter, clip_images
13 | import collections
14 | from torchvision import transforms
15 | from kornia import augmentation
16 | from tqdm import tqdm
17 |
18 | class MLPHead(nn.Module):
19 | def __init__(self, dim_in, dim_feat, dim_h=None):
20 | super(MLPHead, self).__init__()
21 | if dim_h is None:
22 | dim_h = dim_in
23 |
24 | self.head = nn.Sequential(
25 | nn.Linear(dim_in, dim_h),
26 | nn.ReLU(inplace=True),
27 | nn.Linear(dim_h, dim_feat),
28 | )
29 |
30 | def forward(self, x):
31 | x = self.head(x)
32 | return F.normalize(x, dim=1, p=2)
33 |
34 | class MultiTransform:
35 |     """Apply each transform in a list to the same image, returning one view per transform (here a global and a local crop)"""
36 | def __init__(self, transform):
37 | self.transform = transform
38 |
39 | def __call__(self, x):
40 | return [t(x) for t in self.transform]
41 |
42 | def __repr__(self):
43 | return str( self.transform )
44 |
45 |
46 | class ContrastLoss(nn.Module):
47 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
48 | It also supports the unsupervised contrastive loss in SimCLR.
49 | Adapted from https://github.com/HobbitLong/SupContrast/blob/master/losses.py"""
50 | def __init__(self, temperature=0.07, contrast_mode='all',
51 | base_temperature=0.07):
52 | super(ContrastLoss, self).__init__()
53 | self.temperature = temperature
54 | self.contrast_mode = contrast_mode
55 | self.base_temperature = base_temperature
56 |
57 | def forward(self, features, labels=None, mask=None, return_logits=False):
58 | """Compute loss for model. If both `labels` and `mask` are None,
59 | it degenerates to SimCLR unsupervised loss:
60 | https://arxiv.org/pdf/2002.05709.pdf
61 | Args:
62 | features: hidden vector of shape [bsz, n_views, ...].
63 | labels: ground truth of shape [bsz].
64 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
65 | has the same class as sample i. Can be asymmetric.
66 | Returns:
67 | A loss scalar.
68 | """
69 | device = (torch.device('cuda')
70 | if features.is_cuda
71 | else torch.device('cpu'))
72 |
73 | if len(features.shape) < 3:
74 | raise ValueError('`features` needs to be [bsz, n_views, ...],'
75 | 'at least 3 dimensions are required')
76 | if len(features.shape) > 3:
77 | features = features.view(features.shape[0], features.shape[1], -1)
78 |
79 | batch_size = features.shape[0]
80 | if labels is not None and mask is not None:
81 | raise ValueError('Cannot define both `labels` and `mask`')
82 | elif labels is None and mask is None:
83 | mask = torch.eye(batch_size, dtype=torch.float32).to(device)
84 | elif labels is not None:
85 | labels = labels.contiguous().view(-1, 1)
86 | if labels.shape[0] != batch_size:
87 | raise ValueError('Num of labels does not match num of features')
88 | mask = torch.eq(labels, labels.T).float().to(device)
89 | else:
90 | mask = mask.float().to(device)
91 |
92 | contrast_count = features.shape[1]
93 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
94 | if self.contrast_mode == 'one':
95 | anchor_feature = features[:, 0]
96 | anchor_count = 1
97 | elif self.contrast_mode == 'all':
98 | anchor_feature = contrast_feature
99 | anchor_count = contrast_count
100 | else:
101 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
102 |
103 | # compute logits
104 | anchor_dot_contrast = torch.div(
105 | torch.matmul(anchor_feature, contrast_feature.T),
106 | self.temperature)
107 | # for numerical stability
108 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
109 | logits = anchor_dot_contrast - logits_max.detach()
110 |
111 | # tile mask
112 | mask = mask.repeat(anchor_count, contrast_count)
113 | # mask-out self-contrast cases
114 | logits_mask = torch.scatter(
115 | torch.ones_like(mask),
116 | 1,
117 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
118 | 0
119 | )
120 | mask = mask * logits_mask
121 |
122 | # compute log_prob
123 | exp_logits = torch.exp(logits) * logits_mask
124 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
125 |
126 | # compute mean of log-likelihood over positive
127 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
128 | # loss
129 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
130 | loss = loss.view(anchor_count, batch_size)
131 |
132 | if return_logits:
133 | return loss, anchor_dot_contrast
134 | return loss
135 |
136 |
137 | class MemoryBank(object):
138 | def __init__(self, device, max_size=4096, dim_feat=512):
139 | self.device = device
140 | self.data = torch.randn( max_size, dim_feat ).to(device)
141 | self._ptr = 0
142 | self.n_updates = 0
143 |
144 | self.max_size = max_size
145 | self.dim_feat = dim_feat
146 |
147 | def add(self, feat):
148 | feat = feat.to(self.device)
149 | n, c = feat.shape
150 |         assert self.dim_feat==c and self.max_size % n==0, "%d, %d, %d, %d"%(self.dim_feat, c, self.max_size, n)
151 | self.data[self._ptr:self._ptr+n] = feat.detach()
152 | self._ptr = (self._ptr+n) % (self.max_size)
153 | self.n_updates+=n
154 |
155 | def get_data(self, k=None, index=None):
156 | if k is None:
157 | k = self.max_size
158 |
159 | if self.n_updates>self.max_size:
160 | if index is None:
161 | index = random.sample(list(range(self.max_size)), k=k)
162 | return self.data[index], index
163 | else:
164 | #return self.data[:self._ptr]
165 | if index is None:
166 | index = random.sample(list(range(self._ptr)), k=min(k, self._ptr))
167 | return self.data[index], index
168 |
169 | def reset_model(model):
170 | for m in model.modules():
171 | if isinstance(m, (nn.ConvTranspose2d, nn.Linear, nn.Conv2d)):
172 | nn.init.normal_(m.weight, 0.0, 0.02)
173 | if m.bias is not None:
174 | nn.init.constant_(m.bias, 0)
175 | if isinstance(m, (nn.BatchNorm2d)):
176 | nn.init.normal_(m.weight, 1.0, 0.02)
177 | nn.init.constant_(m.bias, 0)
178 |
179 | class CMISynthesizer(BaseSynthesis):
180 | def __init__(self, teacher, student, generator, nz, num_classes, img_size,
181 | feature_layers=None, bank_size=40960, n_neg=4096, head_dim=128, init_dataset=None,
182 | iterations=100, lr_g=0.1, progressive_scale=False,
183 | synthesis_batch_size=128, sample_batch_size=128,
184 | adv=0.0, bn=1, oh=1, cr=0.8, cr_T=0.1,
185 | save_dir='run/cmi', transform=None,
186 | autocast=None, use_fp16=False,
187 | normalizer=None, device='cpu', distributed=False):
188 | super(CMISynthesizer, self).__init__(teacher, student)
189 | self.save_dir = save_dir
190 | self.img_size = img_size
191 | self.iterations = iterations
192 | self.lr_g = lr_g
193 | self.progressive_scale = progressive_scale
194 | self.nz = nz
195 | self.n_neg = n_neg
196 | self.adv = adv
197 | self.bn = bn
198 | self.oh = oh
199 | self.num_classes = num_classes
200 | self.distributed = distributed
201 | self.synthesis_batch_size = synthesis_batch_size
202 | self.sample_batch_size = sample_batch_size
203 | self.bank_size = bank_size
204 | self.init_dataset = init_dataset
205 |
206 | self.use_fp16 = use_fp16
207 | self.autocast = autocast # for FP16
208 | self.normalizer = normalizer
209 | self.data_pool = ImagePool(root=self.save_dir)
210 | self.transform = transform
211 | self.data_iter = None
212 |
213 | self.cr = cr
214 | self.cr_T = cr_T
215 | self.cmi_hooks = []
216 | if feature_layers is not None:
217 | for layer in feature_layers:
218 | self.cmi_hooks.append( InstanceMeanHook(layer) )
219 | else:
220 | for m in teacher.modules():
221 | if isinstance(m, nn.BatchNorm2d):
222 | self.cmi_hooks.append( InstanceMeanHook(m) )
223 |
224 | with torch.no_grad():
225 | teacher.eval()
226 | fake_inputs = torch.randn(size=(1, *img_size), device=device)
227 | _ = teacher(fake_inputs)
228 | cmi_feature = torch.cat([ h.instance_mean for h in self.cmi_hooks ], dim=1)
229 | print("CMI dims: %d"%(cmi_feature.shape[1]))
230 | del fake_inputs
231 |
232 | self.generator = generator.to(device).train()
233 | # local and global bank
234 | self.mem_bank = MemoryBank('cpu', max_size=self.bank_size, dim_feat=2*cmi_feature.shape[1]) # local + global
235 |
236 | self.head = MLPHead(cmi_feature.shape[1], head_dim).to(device).train()
237 | self.optimizer_head = torch.optim.Adam(self.head.parameters(), lr=self.lr_g)
238 |
239 | self.device = device
240 | self.hooks = []
241 | for m in teacher.modules():
242 | if isinstance(m, nn.BatchNorm2d):
243 | self.hooks.append( DeepInversionHook(m) )
244 |
245 | self.aug = MultiTransform([
246 | # global view
247 | transforms.Compose([
248 | augmentation.RandomCrop(size=[self.img_size[-2], self.img_size[-1]], padding=4),
249 | augmentation.RandomHorizontalFlip(),
250 | normalizer,
251 | ]),
252 | # local view
253 | transforms.Compose([
254 | augmentation.RandomResizedCrop(size=[self.img_size[-2], self.img_size[-1]], scale=[0.25, 1.0]),
255 | augmentation.RandomHorizontalFlip(),
256 | normalizer,
257 | ]),
258 | ])
259 |
260 | #self.contrast_loss = ContrastLoss(temperature=self.cr_T, contrast_mode='one')
261 |
262 | def synthesize(self, targets=None):
263 | self.student.eval()
264 | self.teacher.eval()
265 | best_cost = 1e6
266 |
267 | #inputs = torch.randn( size=(self.synthesis_batch_size, *self.img_size), device=self.device ).requires_grad_()
268 | best_inputs = None
269 | z = torch.randn(size=(self.synthesis_batch_size, self.nz), device=self.device).requires_grad_()
270 | if targets is None:
271 | targets = torch.randint(low=0, high=self.num_classes, size=(self.synthesis_batch_size,))
272 | targets = targets.sort()[0] # sort for better visualization
273 | targets = targets.to(self.device)
274 |
275 | reset_model(self.generator)
276 | optimizer = torch.optim.Adam([{'params': self.generator.parameters()}, {'params': [z]}], self.lr_g, betas=[0.5, 0.999])
277 | #scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.iterations, eta_min=0.1*self.lr)
278 | for it in tqdm(range(self.iterations)):
279 | inputs = self.generator(z)
280 | global_view, local_view = self.aug(inputs) # crop and normalize
281 |
282 | #############################################
283 | # Inversion Loss
284 | #############################################
285 | t_out = self.teacher(global_view)
286 | loss_bn = sum([h.r_feature for h in self.hooks])
287 | loss_oh = F.cross_entropy( t_out, targets )
288 | if self.adv>0:
289 | s_out = self.student(global_view)
290 | mask = (s_out.max(1)[1]==t_out.max(1)[1]).float()
291 | loss_adv = -(kldiv(s_out, t_out, reduction='none').sum(1) * mask).mean() # decision adversarial distillation
292 | else:
293 | loss_adv = loss_oh.new_zeros(1)
294 | loss_inv = self.bn * loss_bn + self.oh * loss_oh + self.adv * loss_adv
295 |
296 | #############################################
297 | # Contrastive Loss
298 | #############################################
299 | global_feature = torch.cat([ h.instance_mean for h in self.cmi_hooks ], dim=1)
300 | _ = self.teacher(local_view)
301 | local_feature = torch.cat([ h.instance_mean for h in self.cmi_hooks ], dim=1)
302 | cached_feature, _ = self.mem_bank.get_data(self.n_neg)
303 | cached_local_feature, cached_global_feature = torch.chunk(cached_feature.to(self.device), chunks=2, dim=1)
304 |
305 | proj_feature = self.head( torch.cat([local_feature, cached_local_feature, global_feature, cached_global_feature], dim=0) )
306 | proj_local_feature, proj_global_feature = torch.chunk(proj_feature, chunks=2, dim=0)
307 |
308 | # https://github.com/HobbitLong/SupContrast/blob/master/losses.py
309 | #cr_feature = torch.cat( [proj_local_feature.unsqueeze(1), proj_global_feature.unsqueeze(1).detach()], dim=1 )
310 | #loss_cr = self.contrast_loss(cr_feature)
311 |
312 | # Note that the cross entropy loss will be divided by the total batch size (current batch + cached batch)
313 | # we split the cross entropy loss to avoid too small gradients w.r.t the generator
314 | #if self.mem_bank.n_updates>0:
315 | # 1. gradient from current batch + 2. gradient from cached data
316 | # loss_cr = loss_cr[:, :self.synthesis_batch_size].mean() + loss_cr[:, self.synthesis_batch_size:].mean()
317 | #else: # 1. gradients only come from current batch
318 | # loss_cr = loss_cr.mean()
319 |
320 | # A naive implementation of contrastive loss
321 | cr_logits = torch.mm(proj_local_feature, proj_global_feature.detach().T) / self.cr_T # (N + N') x (N + N')
322 | cr_labels = torch.arange(start=0, end=len(cr_logits), device=self.device)
323 | loss_cr = F.cross_entropy( cr_logits, cr_labels, reduction='none') #(N + N')
324 | if self.mem_bank.n_updates>0:
325 | loss_cr = loss_cr[:self.synthesis_batch_size].mean() + loss_cr[self.synthesis_batch_size:].mean()
326 | else:
327 | loss_cr = loss_cr.mean()
328 |
329 | loss = self.cr * loss_cr + loss_inv
330 | with torch.no_grad():
331 | if best_cost > loss.item() or best_inputs is None:
332 | best_cost = loss.item()
333 | best_inputs = inputs.data
334 | best_features = torch.cat([local_feature.data, global_feature.data], dim=1).data
335 | optimizer.zero_grad()
336 | self.optimizer_head.zero_grad()
337 | loss.backward()
338 | optimizer.step()
339 | self.optimizer_head.step()
340 |
341 | self.student.train()
342 | # save best inputs and reset data iter
343 | self.data_pool.add( best_inputs )
344 | self.mem_bank.add( best_features )
345 |
346 | dst = self.data_pool.get_dataset(transform=self.transform)
347 | if self.init_dataset is not None:
348 | init_dst = datafree.utils.UnlabeledImageDataset(self.init_dataset, transform=self.transform)
349 | dst = torch.utils.data.ConcatDataset([dst, init_dst])
350 | if self.distributed:
351 | train_sampler = torch.utils.data.distributed.DistributedSampler(dst)
352 | else:
353 | train_sampler = None
354 | loader = torch.utils.data.DataLoader(
355 | dst, batch_size=self.sample_batch_size, shuffle=(train_sampler is None),
356 | num_workers=4, pin_memory=True, sampler=train_sampler)
357 | self.data_iter = DataIter(loader)
358 | return {"synthetic": best_inputs}
359 |
360 | def sample(self):
361 | return self.data_iter.next()
--------------------------------------------------------------------------------
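
A sketch of the `MemoryBank` used by `CMISynthesizer` above (illustrative values). It is a fixed-size ring buffer over feature vectors, so `max_size` must be divisible by the batch size passed to `add()`; before the bank wraps around, `get_data()` only samples from the slots filled so far.

    import torch
    from datafree.synthesis.contrastive import MemoryBank

    bank = MemoryBank('cpu', max_size=4096, dim_feat=256)
    for _ in range(3):
        bank.add(torch.randn(128, 256))    # 4096 % 128 == 0, as required by the assert
    feats, idx = bank.get_data(k=64)       # 64 features sampled from the 384 filled slots
    print(feats.shape)                     # torch.Size([64, 256])
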
/datafree/synthesis/triplet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from losses import TriLoss
6 |
7 | from .base import BaseSynthesis
8 | from datafree.hooks import DeepInversionHook
9 | from datafree.criterions import kldiv
10 | from datafree.utils import ImagePool, DataIter
11 | from torchvision import transforms
12 | from kornia import augmentation
13 | from tqdm import tqdm
14 |
15 |
16 | def reset_model(model):
17 | for m in model.modules():
18 | if isinstance(m, (nn.ConvTranspose2d, nn.Linear, nn.Conv2d)):
19 | nn.init.normal_(m.weight, 0.0, 0.02)
20 | if m.bias is not None:
21 | nn.init.constant_(m.bias, 0)
22 | if isinstance(m, (nn.BatchNorm2d)):
23 | nn.init.normal_(m.weight, 1.0, 0.02)
24 | nn.init.constant_(m.bias, 0)
25 |
26 |
27 | class AdvTripletSynthesizer(BaseSynthesis):
28 | def __init__(self, teacher, student, generator, pair_sample, nz, num_classes, img_size,
29 | start_layer, end_layer, iterations=100, lr_g=0.1, progressive_scale=False,
30 | synthesis_batch_size=128, sample_batch_size=128,
31 | adv=0.0, bn=1, oh=1, triplet=0.0,
32 | save_dir='run/cmi', transform=None,
33 | normalizer=None, device='cpu', distributed=False,
34 | triplet_target='teacher', balanced_sampling=False):
35 | super(AdvTripletSynthesizer, self).__init__(teacher, student)
36 | self.save_dir = save_dir
37 | self.img_size = img_size
38 | self.start_layer = start_layer
39 | self.end_layer = end_layer
40 | self.iterations = iterations
41 | self.lr_g = lr_g
42 | self.progressive_scale = progressive_scale
43 | self.nz = nz
44 | self.adv = adv
45 | self.bn = bn
46 | self.oh = oh
47 | self.triplet = triplet
48 | self.compute_triplet_loss = TriLoss(balanced_sampling=balanced_sampling)
49 | self.num_classes = num_classes
50 | self.distributed = distributed
51 | self.synthesis_batch_size = synthesis_batch_size
52 | self.sample_batch_size = sample_batch_size
53 | self.triplet_target = triplet_target
54 |
55 | self.normalizer = normalizer
56 | self.data_pool = ImagePool(root=self.save_dir)
57 | self.transform = transform
58 | self.data_iter = None
59 | self.generator = generator.to(device).train()
60 | # local and global bank
61 |
62 | self.device = device
63 | self.hooks = []
64 | for m in teacher.modules():
65 | if isinstance(m, nn.BatchNorm2d):
66 | self.hooks.append(DeepInversionHook(m))
67 |
68 | self.aug = transforms.Compose([
69 | augmentation.RandomCrop(
70 | size=[self.img_size[-2], self.img_size[-1]], padding=4),
71 | augmentation.RandomHorizontalFlip(),
72 | normalizer,
73 | ])
74 | self.pair_sample = pair_sample
75 |
76 | def synthesize(self, targets=None):
77 | self.student.eval()
78 | self.teacher.eval()
79 | best_cost = 1e6
80 |
81 | #inputs = torch.randn( size=(self.synthesis_batch_size, *self.img_size), device=self.device ).requires_grad_()
82 | best_inputs = None
83 | z = torch.randn(size=(self.synthesis_batch_size, self.nz),
84 | device=self.device).requires_grad_()
85 | if targets is None:
86 | targets = torch.randint(
87 | low=0, high=self.num_classes, size=(self.synthesis_batch_size,))
88 | targets = targets.sort()[0] # sort for better visualization
89 | targets = targets.to(self.device)
90 |
91 | reset_model(self.generator)
92 | optimizer = torch.optim.Adam([{'params': self.generator.parameters()}, {
93 | 'params': [z]}], self.lr_g, betas=[0.5, 0.999])
94 | #scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.iterations, eta_min=0.1*self.lr)
95 | for _ in tqdm(range(self.iterations)):
96 | inputs = self.generator(z)
97 | global_view = self.aug(inputs) # crop and normalize
98 |
99 | #############################################
100 | # Inversion Loss
101 | #############################################
102 | t_out, _, t_layers = self.teacher(
103 | global_view, return_features=True)
104 | s_out, _, s_layers = self.student(
105 | global_view, return_features=True)
106 | loss_bn = sum([h.r_feature for h in self.hooks])
107 | loss_oh = F.cross_entropy(t_out, targets)
108 |
109 | if self.adv > 0:
110 | mask = (s_out.max(1)[1] == t_out.max(1)[1]).float()
111 | # decision adversarial distillation
112 | loss_adv = - \
113 | (kldiv(s_out, t_out, reduction='none').sum(1) * mask).mean()
114 | else:
115 | loss_adv = loss_oh.new_zeros(1)
116 |
117 | if self.triplet_target == 'teacher':
118 | triplet_layers = t_layers
119 | elif self.triplet_target == 'student':
120 | triplet_layers = s_layers
121 | else:
122 | raise NotImplementedError()
123 |
124 | if self.triplet > 0:
125 | loss_tri = self.compute_triplet_loss(
126 | triplet_layers[self.start_layer:self.end_layer], t_out, torch.argmax(t_out, dim=-1))
127 | else:
128 | loss_tri = loss_oh.new_zeros(1)
129 |
130 | loss = self.bn * loss_bn + self.oh * loss_oh + \
131 | self.adv * loss_adv + self.triplet * loss_tri
132 |
133 | with torch.no_grad():
134 | if best_cost > loss.item() or best_inputs is None:
135 | best_cost = loss.item()
136 | best_inputs = inputs.data
137 | optimizer.zero_grad()
138 | loss.backward()
139 | optimizer.step()
140 |
141 | self.student.train()
142 | self.data_pool.add(best_inputs)
143 |
144 | dst = self.data_pool.get_dataset(
145 | transform=self.transform, pair_sample=self.pair_sample)
146 | if self.distributed:
147 | train_sampler = torch.utils.data.distributed.DistributedSampler(
148 | dst)
149 | else:
150 | train_sampler = None
151 | loader = torch.utils.data.DataLoader(
152 | dst, batch_size=self.sample_batch_size, shuffle=(
153 | train_sampler is None),
154 | num_workers=4, pin_memory=True, sampler=train_sampler)
155 | self.data_iter = DataIter(loader)
156 | print("sample_batch_size:", self.sample_batch_size)
157 | return {"synthetic": best_inputs, "targets": targets}
158 |
159 | def sample(self):
160 | return self.data_iter.next()
161 |
--------------------------------------------------------------------------------
/datafree/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from ._utils import *
2 | from .logger import get_logger
3 |
4 | from . import sync_transforms, inception
--------------------------------------------------------------------------------
/datafree/utils/_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import ConcatDataset, Dataset
3 | import numpy as np
4 | from PIL import Image
5 | import os, random, math
6 | from copy import deepcopy
7 | from contextlib import contextmanager
8 |
9 | def get_pseudo_label(n_or_label, num_classes, device, onehot=False):
10 | if isinstance(n_or_label, int):
11 | label = torch.randint(0, num_classes, size=(n_or_label,), device=device)
12 | else:
13 | label = n_or_label.to(device)
14 | if onehot:
15 | label = torch.zeros(len(label), num_classes, device=device).scatter_(1, label.unsqueeze(1), 1.)
16 | return label
17 |
18 | def pdist(sample_1, sample_2, norm=2, eps=1e-5):
19 |     r"""Compute the matrix of all pairwise distances.
20 | Arguments
21 | ---------
22 | sample_1 : torch.Tensor or Variable
23 | The first sample, should be of shape ``(n_1, d)``.
24 | sample_2 : torch.Tensor or Variable
25 | The second sample, should be of shape ``(n_2, d)``.
26 | norm : float
27 | The l_p norm to be used.
28 | Returns
29 | -------
30 | torch.Tensor or Variable
31 | Matrix of shape (n_1, n_2). The [i, j]-th entry is equal to
32 | ``|| sample_1[i, :] - sample_2[j, :] ||_p``."""
33 | n_1, n_2 = sample_1.size(0), sample_2.size(0)
34 | norm = float(norm)
35 | if norm == 2.:
36 | norms_1 = torch.sum(sample_1**2, dim=1, keepdim=True)
37 | norms_2 = torch.sum(sample_2**2, dim=1, keepdim=True)
38 | norms = (norms_1.expand(n_1, n_2) +
39 | norms_2.transpose(0, 1).expand(n_1, n_2))
40 | distances_squared = norms - 2 * sample_1.mm(sample_2.t())
41 | return torch.sqrt(eps + torch.abs(distances_squared))
42 | else:
43 | dim = sample_1.size(1)
44 | expanded_1 = sample_1.unsqueeze(1).expand(n_1, n_2, dim)
45 | expanded_2 = sample_2.unsqueeze(0).expand(n_1, n_2, dim)
46 | differences = torch.abs(expanded_1 - expanded_2) ** norm
47 | inner = torch.sum(differences, dim=2, keepdim=False)
48 | return (eps + inner) ** (1. / norm)
49 |
50 | class MemoryBank(object):
51 | def __init__(self, device, max_size=4096, dim_feat=512):
52 | self.device = device
53 | self.data = torch.randn( max_size, dim_feat ).to(device)
54 | self._ptr = 0
55 | self.n_updates = 0
56 |
57 | self.max_size = max_size
58 | self.dim_feat = dim_feat
59 |
60 | def add(self, feat):
61 | n, c = feat.shape
62 |         assert self.dim_feat==c and self.max_size % n==0, "%d, %d, %d, %d"%(self.dim_feat, c, self.max_size, n)
63 | self.data[self._ptr:self._ptr+n] = feat.detach()
64 | self._ptr = (self._ptr+n) % (self.max_size)
65 | self.n_updates+=n
66 |
67 | def get_data(self, k=None, index=None):
68 | if k is None:
69 | k = self.max_size
70 | assert k <= self.max_size
71 |
72 | if self.n_updates>self.max_size:
73 | if index is None:
74 | index = random.sample(list(range(self.max_size)), k=k)
75 | return self.data[index], index
76 | else:
77 | if index is None:
78 | index = random.sample(list(range(self._ptr)), k=min(k, self._ptr))
79 | return self.data[index], index
80 |
81 | def clip_images(image_tensor, mean, std):
82 | mean = np.array(mean)
83 | std = np.array(std)
84 | for c in range(3):
85 | m, s = mean[c], std[c]
86 | image_tensor[:, c] = torch.clamp(image_tensor[:, c], -m / s, (1 - m) / s)
87 | return image_tensor
88 |
89 | def save_image_batch(imgs, output, col=None, size=None, pack=True):
90 | if isinstance(imgs, torch.Tensor):
91 | imgs = (imgs.detach().clamp(0, 1).cpu().numpy()*255).astype('uint8')
92 | base_dir = os.path.dirname(output)
93 | if base_dir!='':
94 | os.makedirs(base_dir, exist_ok=True)
95 | if pack:
96 | imgs = pack_images( imgs, col=col ).transpose( 1, 2, 0 ).squeeze()
97 | imgs = Image.fromarray( imgs )
98 | if size is not None:
99 | if isinstance(size, (list,tuple)):
100 | imgs = imgs.resize(size)
101 | else:
102 | w, h = imgs.size
103 | max_side = max( h, w )
104 | scale = float(size) / float(max_side)
105 | _w, _h = int(w*scale), int(h*scale)
106 | imgs = imgs.resize([_w, _h])
107 | imgs.save(output)
108 | else:
109 |         output_filename = os.path.splitext(output)[0]  # strip('.png') removes characters, not the suffix
110 | for idx, img in enumerate(imgs):
111 | img = Image.fromarray( img.transpose(1, 2, 0) )
112 | img.save(output_filename+'-%d.png'%(idx))
113 |
114 | def pack_images(images, col=None, channel_last=False, padding=1):
115 | # N, C, H, W
116 | if isinstance(images, (list, tuple) ):
117 | images = np.stack(images, 0)
118 | if channel_last:
119 | images = images.transpose(0,3,1,2) # make it channel first
120 | assert len(images.shape)==4
121 | assert isinstance(images, np.ndarray)
122 |
123 | N,C,H,W = images.shape
124 | if col is None:
125 | col = int(math.ceil(math.sqrt(N)))
126 | row = int(math.ceil(N / col))
127 |
128 | pack = np.zeros( (C, H*row+padding*(row-1), W*col+padding*(col-1)), dtype=images.dtype )
129 | for idx, img in enumerate(images):
130 | h = (idx // col) * (H+padding)
131 | w = (idx % col) * (W+padding)
132 | pack[:, h:h+H, w:w+W] = img
133 | return pack
134 |
135 | def flatten_dict(dic):
136 | flattned = dict()
137 | def _flatten(prefix, d):
138 | for k, v in d.items():
139 | if isinstance(v, dict):
140 | if prefix is None:
141 | _flatten( k, v )
142 | else:
143 | _flatten( prefix+'/%s'%k, v )
144 | else:
145 | if prefix is None:
146 | flattned[k] = v
147 | else:
148 | flattned[ prefix+'/%s'%k ] = v
149 |
150 | _flatten(None, dic)
151 | return flattned
152 |
153 | def normalize(tensor, mean, std, reverse=False):
154 | if reverse:
155 | _mean = [ -m / s for m, s in zip(mean, std) ]
156 | _std = [ 1/s for s in std ]
157 | else:
158 | _mean = mean
159 | _std = std
160 |
161 | _mean = torch.as_tensor(_mean, dtype=tensor.dtype, device=tensor.device)
162 | _std = torch.as_tensor(_std, dtype=tensor.dtype, device=tensor.device)
163 | tensor = (tensor - _mean[None, :, None, None]) / (_std[None, :, None, None])
164 | return tensor
165 |
166 | class Normalizer(object):
167 | def __init__(self, mean, std):
168 | self.mean = mean
169 | self.std = std
170 |
171 | def __call__(self, x, reverse=False):
172 | return normalize(x, self.mean, self.std, reverse=reverse)
173 |
174 | def load_yaml(filepath):
175 |     from ruamel.yaml import YAML; yaml = YAML()  # import added: YAML was undefined here (ruamel.yaml assumed)
176 | with open(filepath, 'r') as f:
177 | return yaml.load(f)
178 |
179 | def _collect_all_images(root, postfix=['png', 'jpg', 'jpeg', 'JPEG']):
180 | images = []
181 | if isinstance( postfix, str):
182 | postfix = [ postfix ]
183 | for dirpath, dirnames, files in os.walk(root):
184 | for pos in postfix:
185 | for f in files:
186 | if f.endswith( pos ):
187 | images.append( os.path.join( dirpath, f ) )
188 | return images
189 |
190 | class UnlabeledImageDataset(torch.utils.data.Dataset):
191 | def __init__(self, root, transform=None, pair_sample=False):
192 | self.root = os.path.abspath(root)
193 | self.images = _collect_all_images(self.root) #[ os.path.join(self.root, f) for f in os.listdir( root ) ]
194 | self.transform = transform
195 | self.pair_sample = pair_sample
196 |
197 | def __getitem__(self, idx):
198 | img = Image.open( self.images[idx] )
199 | if self.transform:
200 | img1 = self.transform(img)
201 | if not self.pair_sample:
202 | return img1
203 | img2 = self.transform(img)
204 | return img1, img2
205 |
206 | def __len__(self):
207 | return len(self.images)
208 |
209 | def __repr__(self):
210 |         return 'Unlabeled data:\n\troot: %s\n\tdata amount: %d\n\ttransforms: %s'%(self.root, len(self), self.transform)
211 |
212 | class LabeledImageDataset(torch.utils.data.Dataset):
213 | def __init__(self, root, transform=None):
214 | self.root = os.path.abspath(root)
215 | self.categories = [ int(f) for f in os.listdir( root ) ]
216 | images = []
217 | targets = []
218 | for c in self.categories:
219 | category_dir = os.path.join( self.root, str(c))
220 | _images = [ os.path.join( category_dir, f ) for f in os.listdir(category_dir) ]
221 | images.extend(_images)
222 | targets.extend([c for _ in range(len(_images))])
223 | self.images = images
224 | self.targets = targets
225 | self.transform = transform
226 | def __getitem__(self, idx):
227 | img, target = Image.open( self.images[idx] ), self.targets[idx]
228 | if self.transform:
229 | img = self.transform(img)
230 | return img, target
231 | def __len__(self):
232 | return len(self.images)
233 |
234 | class ImagePool(object):
235 | def __init__(self, root):
236 | self.root = os.path.abspath(root)
237 | os.makedirs(self.root, exist_ok=True)
238 | self._idx = 0
239 |
240 | def add(self, imgs, targets=None):
241 | save_image_batch(imgs, os.path.join( self.root, "%d.png"%(self._idx) ), pack=False)
242 | self._idx+=1
243 |
244 | def get_dataset(self, transform=None, pair_sample=False):
245 | return UnlabeledImageDataset(self.root, transform=transform, pair_sample=pair_sample)
246 |
247 | class DataIter(object):
248 | def __init__(self, dataloader):
249 | self.dataloader = dataloader
250 | self._iter = iter(self.dataloader)
251 |
252 | def next(self):
253 | try:
254 | data = next( self._iter )
255 | except StopIteration:
256 | self._iter = iter(self.dataloader)
257 | data = next( self._iter )
258 | return data
259 |
260 | @contextmanager
261 | def dummy_ctx(*args, **kwds):
262 | try:
263 | yield None
264 | finally:
265 | pass
--------------------------------------------------------------------------------
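
A round-trip sketch for `Normalizer` above (the mean/std values are CIFAR-10-style and purely illustrative): calling it with `reverse=True` undoes the standardization exactly.

    import torch
    from datafree.utils import Normalizer

    normalizer = Normalizer(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
    x = torch.rand(4, 3, 32, 32)                   # images in [0, 1]
    x_norm = normalizer(x)                         # standardize
    x_back = normalizer(x_norm, reverse=True)      # invert the standardization
    print(torch.allclose(x, x_back, atol=1e-5))    # True
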
/datafree/utils/fmix.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | import numpy as np
5 | from scipy.stats import beta
6 |
7 |
8 | def fftfreqnd(h, w=None, z=None):
9 | """ Get bin values for discrete fourier transform of size (h, w, z)
10 | :param h: Required, first dimension size
11 | :param w: Optional, second dimension size
12 | :param z: Optional, third dimension size
13 | """
14 | fz = fx = 0
15 | fy = np.fft.fftfreq(h)
16 |
17 | if w is not None:
18 | fy = np.expand_dims(fy, -1)
19 |
20 | if w % 2 == 1:
21 | fx = np.fft.fftfreq(w)[: w // 2 + 2]
22 | else:
23 | fx = np.fft.fftfreq(w)[: w // 2 + 1]
24 |
25 | if z is not None:
26 | fy = np.expand_dims(fy, -1)
27 | if z % 2 == 1:
28 | fz = np.fft.fftfreq(z)[:, None]
29 | else:
30 | fz = np.fft.fftfreq(z)[:, None]
31 |
32 | return np.sqrt(fx * fx + fy * fy + fz * fz)
33 |
34 |
35 | def get_spectrum(freqs, decay_power, ch, h, w=0, z=0):
36 | """ Samples a fourier image with given size and frequencies decayed by decay power
37 | :param freqs: Bin values for the discrete fourier transform
38 | :param decay_power: Decay power for frequency decay prop 1/f**d
39 | :param ch: Number of channels for the resulting mask
40 | :param h: Required, first dimension size
41 | :param w: Optional, second dimension size
42 | :param z: Optional, third dimension size
43 | """
44 | scale = np.ones(1) / (np.maximum(freqs, np.array([1. / max(w, h, z)])) ** decay_power)
45 |
46 | param_size = [ch] + list(freqs.shape) + [2]
47 | param = np.random.randn(*param_size)
48 |
49 | scale = np.expand_dims(scale, -1)[None, :]
50 |
51 | return scale * param
52 |
53 |
54 | def make_low_freq_image(decay, shape, ch=1):
55 | """ Sample a low frequency image from fourier space
56 | :param decay_power: Decay power for frequency decay prop 1/f**d
57 | :param shape: Shape of desired mask, list up to 3 dims
58 | :param ch: Number of channels for desired mask
59 | """
60 | freqs = fftfreqnd(*shape)
61 | spectrum = get_spectrum(freqs, decay, ch, *shape)#.reshape((1, *shape[:-1], -1))
62 | spectrum = spectrum[:, 0] + 1j * spectrum[:, 1]
63 | mask = np.real(np.fft.irfftn(spectrum, shape))
64 |
65 | if len(shape) == 1:
66 | mask = mask[:1, :shape[0]]
67 | if len(shape) == 2:
68 | mask = mask[:1, :shape[0], :shape[1]]
69 | if len(shape) == 3:
70 | mask = mask[:1, :shape[0], :shape[1], :shape[2]]
71 |
72 | mask = mask
73 | mask = (mask - mask.min())
74 | mask = mask / mask.max()
75 | return mask
76 |
77 |
78 | def sample_lam(alpha, reformulate=False):
79 | """ Sample a lambda from symmetric beta distribution with given alpha
80 | :param alpha: Alpha value for beta distribution
81 | :param reformulate: If True, uses the reformulation of [1].
82 | """
83 | if reformulate:
84 | lam = beta.rvs(alpha+1, alpha)
85 | else:
86 | lam = beta.rvs(alpha, alpha)
87 |
88 | return lam
89 |
90 |
91 | def binarise_mask(mask, lam, in_shape, max_soft=0.0):
92 | """ Binarises a given low frequency image such that it has mean lambda.
93 | :param mask: Low frequency image, usually the result of `make_low_freq_image`
94 | :param lam: Mean value of final mask
95 | :param in_shape: Shape of inputs
96 | :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
97 | :return:
98 | """
99 | idx = mask.reshape(-1).argsort()[::-1]
100 | mask = mask.reshape(-1)
101 | num = math.ceil(lam * mask.size) if random.random() > 0.5 else math.floor(lam * mask.size)
102 |
103 | eff_soft = max_soft
104 | if max_soft > lam or max_soft > (1-lam):
105 | eff_soft = min(lam, 1-lam)
106 |
107 | soft = int(mask.size * eff_soft)
108 | num_low = num - soft
109 | num_high = num + soft
110 |
111 | mask[idx[:num_high]] = 1
112 | mask[idx[num_low:]] = 0
113 | mask[idx[num_low:num_high]] = np.linspace(1, 0, (num_high - num_low))
114 |
115 | mask = mask.reshape((1, *in_shape))
116 | return mask
117 |
118 |
119 | def sample_mask(alpha, decay_power, shape, max_soft=0.0, reformulate=False):
120 | """ Samples a mean lambda from beta distribution parametrised by alpha, creates a low frequency image and binarises
121 | it based on this lambda
122 | :param alpha: Alpha value for beta distribution from which to sample mean of mask
123 | :param decay_power: Decay power for frequency decay prop 1/f**d
124 | :param shape: Shape of desired mask, list up to 3 dims
125 | :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
126 | :param reformulate: If True, uses the reformulation of [1].
127 | """
128 | if isinstance(shape, int):
129 | shape = (shape,)
130 |
131 | # Choose lambda
132 | lam = sample_lam(alpha, reformulate)
133 |
134 | # Make mask, get mean / std
135 | mask = make_low_freq_image(decay_power, shape)
136 | mask = binarise_mask(mask, lam, shape, max_soft)
137 |
138 | return lam, mask
139 |
140 |
141 | def sample_and_apply(x, alpha, decay_power, shape, max_soft=0.0, reformulate=False):
142 | """
143 | :param x: Image batch on which to apply fmix of shape [b, c, shape*]
144 | :param alpha: Alpha value for beta distribution from which to sample mean of mask
145 | :param decay_power: Decay power for frequency decay prop 1/f**d
146 | :param shape: Shape of desired mask, list up to 3 dims
147 | :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
148 | :param reformulate: If True, uses the reformulation of [1].
149 | :return: mixed input, permutation indices, lambda value of mix,
150 | """
151 | lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate)
152 | index = np.random.permutation(x.shape[0])
153 |
154 | x1, x2 = x * mask, x[index] * (1-mask)
155 | return x1+x2, index, lam
156 |
157 |
158 | class FMixBase:
159 | r""" FMix augmentation
160 | Args:
161 | decay_power (float): Decay power for frequency decay prop 1/f**d
162 | alpha (float): Alpha value for beta distribution from which to sample mean of mask
163 | size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims
164 | max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask.
165 | reformulate (bool): If True, uses the reformulation of [1].
166 | """
167 |
168 | def __init__(self, decay_power=3, alpha=1, size=(32, 32), max_soft=0.0, reformulate=False):
169 | super().__init__()
170 | self.decay_power = decay_power
171 | self.reformulate = reformulate
172 | self.size = size
173 | self.alpha = alpha
174 | self.max_soft = max_soft
175 | self.index = None
176 | self.lam = None
177 |
178 | def __call__(self, x):
179 | raise NotImplementedError
180 |
181 | def loss(self, *args, **kwargs):
182 | raise NotImplementedError
--------------------------------------------------------------------------------
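
An end-to-end sketch of the FMix helpers above on a NumPy batch (all values illustrative): `sample_and_apply` draws a mixing ratio lambda, builds a low-frequency binary mask, and blends each image with a randomly permuted partner.

    import numpy as np
    from datafree.utils.fmix import sample_and_apply

    x = np.random.rand(8, 3, 32, 32)   # batch of images in [0, 1]
    mixed, index, lam = sample_and_apply(x, alpha=1.0, decay_power=3.0, shape=(32, 32))
    print(mixed.shape, round(float(lam), 3))   # (8, 3, 32, 32) and a lambda in (0, 1)
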
/datafree/utils/inception.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision
5 |
6 | try:
7 | from torchvision.models.utils import load_state_dict_from_url
8 | except ImportError:
9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
10 |
11 | # Inception weights ported to Pytorch from
12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
14 |
15 |
16 | class InceptionV3(nn.Module):
17 | """Pretrained InceptionV3 network returning feature maps"""
18 |
19 | # Index of default block of inception to return,
20 | # corresponds to output of final average pooling
21 | DEFAULT_BLOCK_INDEX = 3
22 |
23 | # Maps feature dimensionality to their output blocks indices
24 | BLOCK_INDEX_BY_DIM = {
25 | 64: 0, # First max pooling features
26 |         192: 1,   # Second max pooling features
27 | 768: 2, # Pre-aux classifier features
28 | 2048: 3 # Final average pooling features
29 | }
30 |
31 | def __init__(self,
32 | output_blocks=(DEFAULT_BLOCK_INDEX,),
33 | resize_input=True,
34 | normalize_input=True,
35 | requires_grad=False,
36 | use_fid_inception=True):
37 | """Build pretrained InceptionV3
38 | Parameters
39 | ----------
40 | output_blocks : list of int
41 | Indices of blocks to return features of. Possible values are:
42 | - 0: corresponds to output of first max pooling
43 | - 1: corresponds to output of second max pooling
44 | - 2: corresponds to output which is fed to aux classifier
45 | - 3: corresponds to output of final average pooling
46 | resize_input : bool
47 | If true, bilinearly resizes input to width and height 299 before
48 | feeding input to model. As the network without fully connected
49 | layers is fully convolutional, it should be able to handle inputs
50 | of arbitrary size, so resizing might not be strictly needed
51 | normalize_input : bool
52 | If true, scales the input from range (0, 1) to the range the
53 | pretrained Inception network expects, namely (-1, 1)
54 | requires_grad : bool
55 | If true, parameters of the model require gradients. Possibly useful
56 | for finetuning the network
57 | use_fid_inception : bool
58 | If true, uses the pretrained Inception model used in Tensorflow's
59 | FID implementation. If false, uses the pretrained Inception model
60 | available in torchvision. The FID Inception model has different
61 | weights and a slightly different structure from torchvision's
62 | Inception model. If you want to compute FID scores, you are
63 | strongly advised to set this parameter to true to get comparable
64 | results.
65 | """
66 | super(InceptionV3, self).__init__()
67 |
68 | self.resize_input = resize_input
69 | self.normalize_input = normalize_input
70 | self.output_blocks = sorted(output_blocks)
71 | self.last_needed_block = max(output_blocks)
72 |
73 | assert self.last_needed_block <= 3, \
74 | 'Last possible output block index is 3'
75 |
76 | self.blocks = nn.ModuleList()
77 |
78 | if use_fid_inception:
79 | inception = fid_inception_v3()
80 | else:
81 | inception = _inception_v3(pretrained=True)
82 |
83 | # Block 0: input to maxpool1
84 | block0 = [
85 | inception.Conv2d_1a_3x3,
86 | inception.Conv2d_2a_3x3,
87 | inception.Conv2d_2b_3x3,
88 | nn.MaxPool2d(kernel_size=3, stride=2)
89 | ]
90 | self.blocks.append(nn.Sequential(*block0))
91 |
92 | # Block 1: maxpool1 to maxpool2
93 | if self.last_needed_block >= 1:
94 | block1 = [
95 | inception.Conv2d_3b_1x1,
96 | inception.Conv2d_4a_3x3,
97 | nn.MaxPool2d(kernel_size=3, stride=2)
98 | ]
99 | self.blocks.append(nn.Sequential(*block1))
100 |
101 | # Block 2: maxpool2 to aux classifier
102 | if self.last_needed_block >= 2:
103 | block2 = [
104 | inception.Mixed_5b,
105 | inception.Mixed_5c,
106 | inception.Mixed_5d,
107 | inception.Mixed_6a,
108 | inception.Mixed_6b,
109 | inception.Mixed_6c,
110 | inception.Mixed_6d,
111 | inception.Mixed_6e,
112 | ]
113 | self.blocks.append(nn.Sequential(*block2))
114 |
115 | # Block 3: aux classifier to final avgpool
116 | if self.last_needed_block >= 3:
117 | block3 = [
118 | inception.Mixed_7a,
119 | inception.Mixed_7b,
120 | inception.Mixed_7c,
121 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
122 | ]
123 | self.blocks.append(nn.Sequential(*block3))
124 |
125 | for param in self.parameters():
126 | param.requires_grad = requires_grad
127 |
128 | def forward(self, inp):
129 | """Get Inception feature maps
130 | Parameters
131 | ----------
132 | inp : torch.autograd.Variable
133 | Input tensor of shape Bx3xHxW. Values are expected to be in
134 | range (0, 1)
135 | Returns
136 | -------
137 | List of torch.autograd.Variable, corresponding to the selected output
138 | block, sorted ascending by index
139 | """
140 | outp = []
141 | x = inp
142 |
143 | if self.resize_input:
144 | x = F.interpolate(x,
145 | size=(299, 299),
146 | mode='bilinear',
147 | align_corners=False)
148 |
149 | if self.normalize_input:
150 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
151 |
152 | for idx, block in enumerate(self.blocks):
153 | x = block(x)
154 | if idx in self.output_blocks:
155 | outp.append(x)
156 |
157 | if idx == self.last_needed_block:
158 | break
159 |
160 | return outp
161 |
162 |
163 | def _inception_v3(*args, **kwargs):
164 | """Wraps `torchvision.models.inception_v3`
165 | Skips default weight initialization if supported by torchvision version.
166 | See https://github.com/mseitzer/pytorch-fid/issues/28.
167 | """
168 | try:
169 | version = tuple(map(int, torchvision.__version__.split('.')[:2]))
170 | except ValueError:
171 | # Just a caution against weird version strings
172 | version = (0,)
173 |
174 | if version >= (0, 6):
175 | kwargs['init_weights'] = False
176 |
177 | return torchvision.models.inception_v3(*args, **kwargs)
178 |
179 |
180 | def fid_inception_v3():
181 | """Build pretrained Inception model for FID computation
182 | The Inception model for FID computation uses a different set of weights
183 | and has a slightly different structure than torchvision's Inception.
184 | This method first constructs torchvision's Inception and then patches the
185 | necessary parts that are different in the FID Inception model.
186 | """
187 | inception = _inception_v3(num_classes=1008,
188 | aux_logits=False,
189 | pretrained=False)
190 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
191 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
192 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
193 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
194 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
195 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
196 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
197 | inception.Mixed_7b = FIDInceptionE_1(1280)
198 | inception.Mixed_7c = FIDInceptionE_2(2048)
199 |
200 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
201 | inception.load_state_dict(state_dict)
202 | return inception
203 |
204 |
205 | class FIDInceptionA(torchvision.models.inception.InceptionA):
206 | """InceptionA block patched for FID computation"""
207 | def __init__(self, in_channels, pool_features):
208 | super(FIDInceptionA, self).__init__(in_channels, pool_features)
209 |
210 | def forward(self, x):
211 | branch1x1 = self.branch1x1(x)
212 |
213 | branch5x5 = self.branch5x5_1(x)
214 | branch5x5 = self.branch5x5_2(branch5x5)
215 |
216 | branch3x3dbl = self.branch3x3dbl_1(x)
217 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
218 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
219 |
220 | # Patch: Tensorflow's average pool does not use the padded zeros in
221 | # its average calculation
222 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
223 | count_include_pad=False)
224 | branch_pool = self.branch_pool(branch_pool)
225 |
226 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
227 | return torch.cat(outputs, 1)
228 |
229 |
230 | class FIDInceptionC(torchvision.models.inception.InceptionC):
231 | """InceptionC block patched for FID computation"""
232 | def __init__(self, in_channels, channels_7x7):
233 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
234 |
235 | def forward(self, x):
236 | branch1x1 = self.branch1x1(x)
237 |
238 | branch7x7 = self.branch7x7_1(x)
239 | branch7x7 = self.branch7x7_2(branch7x7)
240 | branch7x7 = self.branch7x7_3(branch7x7)
241 |
242 | branch7x7dbl = self.branch7x7dbl_1(x)
243 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
244 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
245 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
246 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
247 |
248 | # Patch: Tensorflow's average pool does not use the padded zeros in
249 | # its average calculation
250 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
251 | count_include_pad=False)
252 | branch_pool = self.branch_pool(branch_pool)
253 |
254 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
255 | return torch.cat(outputs, 1)
256 |
257 |
258 | class FIDInceptionE_1(torchvision.models.inception.InceptionE):
259 | """First InceptionE block patched for FID computation"""
260 | def __init__(self, in_channels):
261 | super(FIDInceptionE_1, self).__init__(in_channels)
262 |
263 | def forward(self, x):
264 | branch1x1 = self.branch1x1(x)
265 |
266 | branch3x3 = self.branch3x3_1(x)
267 | branch3x3 = [
268 | self.branch3x3_2a(branch3x3),
269 | self.branch3x3_2b(branch3x3),
270 | ]
271 | branch3x3 = torch.cat(branch3x3, 1)
272 |
273 | branch3x3dbl = self.branch3x3dbl_1(x)
274 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
275 | branch3x3dbl = [
276 | self.branch3x3dbl_3a(branch3x3dbl),
277 | self.branch3x3dbl_3b(branch3x3dbl),
278 | ]
279 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
280 |
281 | # Patch: Tensorflow's average pool does not use the padded zeros in
282 | # its average calculation
283 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
284 | count_include_pad=False)
285 | branch_pool = self.branch_pool(branch_pool)
286 |
287 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
288 | return torch.cat(outputs, 1)
289 |
290 |
291 | class FIDInceptionE_2(torchvision.models.inception.InceptionE):
292 | """Second InceptionE block patched for FID computation"""
293 | def __init__(self, in_channels):
294 | super(FIDInceptionE_2, self).__init__(in_channels)
295 |
296 | def forward(self, x):
297 | branch1x1 = self.branch1x1(x)
298 |
299 | branch3x3 = self.branch3x3_1(x)
300 | branch3x3 = [
301 | self.branch3x3_2a(branch3x3),
302 | self.branch3x3_2b(branch3x3),
303 | ]
304 | branch3x3 = torch.cat(branch3x3, 1)
305 |
306 | branch3x3dbl = self.branch3x3dbl_1(x)
307 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
308 | branch3x3dbl = [
309 | self.branch3x3dbl_3a(branch3x3dbl),
310 | self.branch3x3dbl_3b(branch3x3dbl),
311 | ]
312 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
313 |
314 | # Patch: The FID Inception model uses max pooling instead of average
315 | # pooling. This is likely an error in this specific Inception
316 | # implementation, as other Inception models use average pooling here
317 | # (which matches the description in the paper).
318 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
319 | branch_pool = self.branch_pool(branch_pool)
320 |
321 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
322 | return torch.cat(outputs, 1)
--------------------------------------------------------------------------------
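A minimal sketch of extracting final-pool features with the wrapper above, e.g. for FID statistics; the dummy batch and its value range are assumptions (constructing the model downloads the FID Inception weights):

import torch
from datafree.utils.inception import InceptionV3

# Select the 2048-d final-average-pool block via the dimension-to-index map.
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
model = InceptionV3(output_blocks=(block_idx,)).eval()

images = torch.rand(8, 3, 32, 32)  # values in (0, 1); resized to 299x299 internally
with torch.no_grad():
    features = model(images)[0]    # (8, 2048, 1, 1)
features = features.squeeze(-1).squeeze(-1)  # (8, 2048), ready for mean/covariance stats
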
/datafree/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os, sys
3 | from termcolor import colored
4 |
5 |
6 | class _ColorfulFormatter(logging.Formatter):
7 | def __init__(self, *args, **kwargs):
8 | super(_ColorfulFormatter, self).__init__(*args, **kwargs)
9 |
10 | def formatMessage(self, record):
11 | log = super(_ColorfulFormatter, self).formatMessage(record)
12 |
13 | if record.levelno == logging.WARNING:
14 | prefix = colored("WARNING", "yellow", attrs=["blink"])
15 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
16 | prefix = colored("ERROR", "red", attrs=["blink", "underline"])
17 | else:
18 | return log
19 |
20 | return prefix + " " + log
21 |
22 | def get_logger(name='train', output=None, color=True):
23 | logger = logging.getLogger(name)
24 | logger.setLevel(logging.DEBUG)
25 | logger.propagate = False
26 |
27 | # STDOUT
28 | stdout_handler = logging.StreamHandler( stream=sys.stdout )
29 | stdout_handler.setLevel( logging.DEBUG )
30 |
31 | plain_formatter = logging.Formatter(
32 | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" )
33 | if color:
34 | formatter = _ColorfulFormatter(
35 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
36 | datefmt="%m/%d %H:%M:%S")
37 | else:
38 | formatter = plain_formatter
39 | stdout_handler.setFormatter(formatter)
40 |
41 | logger.addHandler(stdout_handler)
42 |
43 | # FILE
44 | if output is not None:
45 | if output.endswith('.txt') or output.endswith('.log'):
46 | os.makedirs(os.path.dirname(output), exist_ok=True)
47 | filename = output
48 | else:
49 | os.makedirs(output, exist_ok=True)
50 | filename = os.path.join(output, "log.txt")
51 | file_handler = logging.FileHandler(filename)
52 | file_handler.setFormatter(plain_formatter)
53 | file_handler.setLevel(logging.DEBUG)
54 | logger.addHandler(file_handler)
55 | return logger
56 |
57 |
58 |
--------------------------------------------------------------------------------
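A minimal sketch of the logger in use; the output path is an assumption:

from datafree.utils.logger import get_logger

# Colored stdout handler plus a plain-text file handler.
logger = get_logger(name='train', output='checkpoints/example/log.txt')
logger.info('starting run')
logger.warning('this message gets a colored WARNING prefix on stdout')
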
/datafree/utils/pair.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torchvision
3 | from torchvision import datasets
4 | from collections import defaultdict
5 | from torch.utils.data import Sampler, Dataset
6 |
7 | class DatasetWrapper(Dataset):
8 | # Additional attributes
9 | # - indices
10 | # - classwise_indices
11 | # - num_classes
12 | # - get_class
13 |
14 | def __init__(self, dataset, indices=None):
15 | self.base_dataset = dataset
16 | if indices is None:
17 | self.indices = list(range(len(dataset)))
18 | else:
19 | self.indices = indices
20 |
21 | # torchvision 0.2.0 compatibility
22 | if torchvision.__version__.startswith('0.2'):
23 | if isinstance(self.base_dataset, datasets.ImageFolder):
24 | self.base_dataset.targets = [s[1] for s in self.base_dataset.imgs]
25 | else:
26 | if self.base_dataset.train:
27 | self.base_dataset.targets = self.base_dataset.train_labels
28 | else:
29 | self.base_dataset.targets = self.base_dataset.test_labels
30 |
31 | self.classwise_indices = defaultdict(list)
32 | for i in range(len(self)):
33 | y = self.base_dataset.targets[self.indices[i]]
34 | self.classwise_indices[y].append(i)
35 | self.num_classes = max(self.classwise_indices.keys())+1
36 |
37 | def __getitem__(self, i):
38 | return self.base_dataset[self.indices[i]]
39 |
40 | def __len__(self):
41 | return len(self.indices)
42 |
43 | def get_class(self, i):
44 | return self.base_dataset.targets[self.indices[i]]
45 |
46 |
47 | class PairBatchSampler(Sampler):
48 | def __init__(self, dataset, batch_size, num_iterations=None):
49 | self.dataset = dataset
50 | self.batch_size = batch_size
51 | self.num_iterations = num_iterations
52 |
53 | def __iter__(self):
54 | indices = list(range(len(self.dataset)))
55 | random.shuffle(indices)
56 | for k in range(len(self)):
57 | if self.num_iterations is None:
58 | offset = k*self.batch_size
59 | batch_indices = indices[offset:offset+self.batch_size]
60 | else:
61 | batch_indices = random.sample(range(len(self.dataset)),
62 | self.batch_size)
63 |
64 | pair_indices = []
65 | for idx in batch_indices:
66 | y = self.dataset.get_class(idx)
67 | pair_indices.append(random.choice(self.dataset.classwise_indices[y]))
68 |
69 | yield batch_indices + pair_indices
70 |
71 | def __len__(self):
72 | if self.num_iterations is None:
73 | return (len(self.dataset)+self.batch_size-1) // self.batch_size
74 | else:
75 | return self.num_iterations
--------------------------------------------------------------------------------
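A minimal sketch of combining the wrapper and sampler above so that the second half of every batch pairs the first half class-wise; the CIFAR-10 root and transform are assumptions:

import torchvision.transforms as T
from torchvision import datasets
from torch.utils.data import DataLoader
from datafree.utils.pair import DatasetWrapper, PairBatchSampler

base = datasets.CIFAR10(root='data', train=True, download=True, transform=T.ToTensor())
dataset = DatasetWrapper(base)                      # adds classwise_indices / get_class
sampler = PairBatchSampler(dataset, batch_size=64)  # yields 64 anchors + 64 same-class pairs
loader = DataLoader(dataset, batch_sampler=sampler, num_workers=2)

images, targets = next(iter(loader))                # 128 samples per batch
# targets[i] == targets[64 + i] for every anchor i, since each pair shares its anchor's class
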
/datafree/utils/sync_transforms/__init__.py:
--------------------------------------------------------------------------------
1 | from .transforms import *
--------------------------------------------------------------------------------
/datafree/utils/vis.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('agg')
3 | import matplotlib.pyplot as plt
4 | from matplotlib import cm
5 |
6 | from sklearn.manifold import TSNE
7 | import seaborn as sns
8 | import numpy as np
9 | import pandas as pd
10 | import os
11 | import io
12 |
13 | def tsne_features( real_features, real_labels, fake_features, fake_labels, output_file ):
14 | fig = plt.figure(figsize=(10, 10))
15 | features = np.concatenate( [real_features, fake_features ], axis=0 )
16 | labels = np.concatenate( [ real_labels, fake_labels ], axis=0 ).reshape(-1, 1)
17 | tsne = TSNE( n_components=2, perplexity=10 ).fit_transform( features )
18 | df = np.concatenate( [tsne, labels], axis=1 )
19 | df = pd.DataFrame(df, columns=["x", "y", "label"])
20 | style = [ 'real' for _ in range(len(real_features)) ] + [ 'fake' for _ in range(len(fake_features))]
21 | sns.scatterplot( x="x", y="y", data=df, hue="label", palette=sns.color_palette("dark", 10), s=50, style=style)
22 | if output_file is not None:
23 | dirname = os.path.dirname( output_file )
24 | if dirname!='':
25 | os.makedirs( dirname, exist_ok=True )
26 | plt.savefig( output_file )
27 | else:
28 | fig.canvas.draw()
29 | img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # np.fromstring is deprecated
30 | img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
31 | plt.close()
32 | return img
33 |
34 |
35 |
--------------------------------------------------------------------------------
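A minimal sketch of the t-SNE visualization above on random features; the feature dimensionality, class count, and output path are assumptions:

import numpy as np
from datafree.utils.vis import tsne_features

rng = np.random.RandomState(0)
real_feat = rng.randn(100, 64).astype(np.float32)
fake_feat = rng.randn(100, 64).astype(np.float32) + 1.0
real_lbl = rng.randint(0, 10, size=100)
fake_lbl = rng.randint(0, 10, size=100)

# Writes a scatter plot colored by label and styled by real/fake.
tsne_features(real_feat, real_lbl, fake_feat, fake_lbl, output_file='checkpoints/vis/tsne.png')
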
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | BIG_NUMBER = 1e12
6 |
7 |
8 | def pdist(e, squared=False, eps=1e-12):
9 | e_square = e.pow(2).sum(dim=1)
10 | prod = e @ e.t()
11 | res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)
12 |
13 | if not squared:
14 | res = res.sqrt()
15 |
16 | res = res.clone()
17 | res[range(len(e)), range(len(e))] = 0
18 | return res
19 |
20 |
21 | def pos_neg_mask(labels):
22 | pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) * \
23 | (1 - torch.eye(labels.size(0), dtype=torch.uint8, device=labels.device))
24 | neg_mask = (labels.unsqueeze(0) != labels.unsqueeze(1)) * \
25 | (1 - torch.eye(labels.size(0), dtype=torch.uint8, device=labels.device))
26 |
27 | return pos_mask, neg_mask
28 |
29 | class BalancedWeighted(nn.Module):
30 | """
31 | Balanced distance-weighted triplet sampling within a bounded distance window;
32 | assumes that embeddings are normalized by 2-norm.
33 | """
34 | cut_off = 0.5
35 | nonzero_loss_cutoff = 1.0
35 |
36 | def __init__(self, dist_func=pdist):
37 | self.dist_func = dist_func
38 | super().__init__()
39 |
40 | def forward(self, embeddings, labels):
41 | with torch.no_grad():
42 | embeddings = F.normalize(embeddings, dim=1, p=2)
43 | pos_mask, neg_mask = pos_neg_mask(labels)
44 | pos_pair_idx = pos_mask.nonzero()
45 | anchor_idx = pos_pair_idx[:, 0]
46 | pos_idx = pos_pair_idx[:, 1]
47 |
48 | d = embeddings.size(1)
49 | dist = (pdist(embeddings, squared=True) + torch.eye(embeddings.size(0),
50 | device=embeddings.device, dtype=torch.float32)).sqrt()
51 | log_weight = ((2.0 - d) * dist.log() - ((d - 3.0)/2.0)
52 | * (1.0 - 0.25 * (dist * dist)).log())
53 | weight = (log_weight - log_weight.max(dim=1,
54 | keepdim=True)[0]).exp()
55 | weight = weight * \
56 | (neg_mask * (dist < self.nonzero_loss_cutoff) * (dist > self.cut_off)).float()
57 |
58 | weight = weight + \
59 | ((weight.sum(dim=1, keepdim=True) == 0) * neg_mask).float()
60 | weight = weight / (weight.sum(dim=1, keepdim=True))
61 | weight = weight[anchor_idx]
62 | if torch.max(weight) == 0:
63 | return None, None, None
64 | weight[torch.isnan(weight)] = 0
65 | neg_idx = torch.multinomial(weight, 1).squeeze(1)
66 |
67 | return anchor_idx, pos_idx, neg_idx
68 |
69 |
70 | class DistanceWeighted(nn.Module):
71 | """
72 | Distance-weighted triplet sampling; assumes that embeddings are normalized by 2-norm.
73 | """
74 | cut_off = 0.5
75 | nonzero_loss_cutoff = 1.4
76 |
77 | def __init__(self, dist_func=pdist):
78 | self.dist_func = dist_func
79 | super().__init__()
80 |
81 | def forward(self, embeddings, labels):
82 | with torch.no_grad():
83 | embeddings = F.normalize(embeddings, dim=1, p=2)
84 | pos_mask, neg_mask = pos_neg_mask(labels)
85 | pos_pair_idx = pos_mask.nonzero()
86 | anchor_idx = pos_pair_idx[:, 0]
87 | pos_idx = pos_pair_idx[:, 1]
88 |
89 | d = embeddings.size(1)
90 | dist = (pdist(embeddings, squared=True) + torch.eye(embeddings.size(0),
91 | device=embeddings.device, dtype=torch.float32)).sqrt()
92 | dist = dist.clamp(min=self.cut_off)
93 | log_weight = ((2.0 - d) * dist.log() - ((d - 3.0)/2.0)
94 | * (1.0 - 0.25 * (dist * dist)).log())
95 | weight = (log_weight - log_weight.max(dim=1,
96 | keepdim=True)[0]).exp()
97 | weight = weight * \
98 | (neg_mask * (dist < self.nonzero_loss_cutoff)).float()
99 |
100 | weight = weight + \
101 | ((weight.sum(dim=1, keepdim=True) == 0) * neg_mask).float()
102 | weight = weight / (weight.sum(dim=1, keepdim=True))
103 | weight = weight[anchor_idx]
104 | if torch.max(weight) == 0:
105 | return None, None, None
106 | weight[torch.isnan(weight)] = 0
107 | neg_idx = torch.multinomial(weight, 1).squeeze(1)
108 |
109 | return anchor_idx, pos_idx, neg_idx
110 |
111 |
112 | class TriLoss(nn.Module):
113 | def __init__(self, p=2, margin=0.2, balanced_sampling=False):
114 | super().__init__()
115 | self.p = p
116 | self.margin = margin
117 |
118 | # update distance function accordingly
119 | if balanced_sampling:
120 | self.sampler = BalancedWeighted()
121 | else:
122 | self.sampler = DistanceWeighted()
123 | self.sampler.dist_func = lambda e: pdist(e, squared=(p == 2))
124 | self.count = 0
125 |
126 | def forward(self, stu_features, logits, labels, negative=False):
127 | if negative:
128 | anchor_idx, neg_idx, pos_idx = self.sampler(logits, labels)
129 | else:
130 | anchor_idx, pos_idx, neg_idx = self.sampler(logits, labels)
131 |
132 | loss = 0.
133 | if anchor_idx is None:
134 | print('warning: no negative samples found.')
135 | return torch.zeros(1)
136 | self.count += 1
137 | for embeddings in stu_features:
138 | if len(embeddings.shape) > 2:
139 | embeddings = embeddings.mean(dim=(2, 3), keepdim=False)
140 | embeddings = F.normalize(embeddings, p=2, dim=-1)
141 | anchor_embed = embeddings[anchor_idx]
142 | positive_embed = embeddings[pos_idx]
143 | negative_embed = embeddings[neg_idx]
144 |
145 | triloss = F.triplet_margin_loss(anchor_embed, positive_embed, negative_embed,
146 | margin=self.margin, p=self.p, reduction='none')
147 | loss += triloss
148 |
149 | return loss.mean()
150 |
151 |
152 | def prune_fpgm(layers):
153 |
154 | pruned_activations_mask = []
155 | with torch.no_grad():
156 |
157 | for layer in layers:
158 |
159 | b, c, h, w = layer.shape
160 |
161 | P = layer.view((b, c, h * w))
162 |
163 | A = P @ P.transpose(1, 2)
164 | A = torch.sum(A, dim=-1)
165 | max_ = torch.max(A)
166 | min_ = torch.min(A)
167 | A = 1.0 - (A - min_) / (max_ - min_) + 1e-3
168 |
169 | pruned_activations_mask.append(A.to(layer.device))
170 |
171 | return pruned_activations_mask
172 |
173 |
174 | class CDLoss(nn.Module):
175 | """Channel Distillation Loss"""
176 |
177 | def __init__(self, linears=[]):
178 | super().__init__()
179 | self.linears = linears
180 |
181 | def forward(self, stu_features: list, tea_features: list):
182 | loss = 0.
183 | for i, (s, t) in enumerate(zip(stu_features, tea_features)):
184 | if not self.linears[i] is None:
185 | s = self.linears[i](s)
186 | s = s.mean(dim=(2, 3), keepdim=False)
187 | t = t.mean(dim=(2, 3), keepdim=False)
188 | # loss += F.mse_loss(F.normalize(s, p=2, dim=-1), F.normalize(t, p=2, dim=-1))
189 | loss += F.mse_loss(s, t)
190 | return loss
191 |
192 |
193 | class GRAMLoss(nn.Module):
194 | """GRAM Loss"""
195 |
196 | def __init__(self, linears=[]):
197 | super().__init__()
198 | self.linears = linears
199 |
200 | def forward(self, stu_features: list, tea_features: list):
201 | loss = 0.
202 | masks = prune_fpgm(tea_features)
203 | for i, s in enumerate(stu_features):
204 | t = tea_features[i]
205 | if not self.linears[i] is None:
206 | s = self.linears[i](s)
207 | b, c = masks[i].shape
208 | m = masks[i].view((b, c, 1, 1)).detach()
209 | loss += torch.mean(torch.pow(s - t, 2) * m)
210 | return loss
211 |
--------------------------------------------------------------------------------
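A minimal sketch of the triplet sampler and loss above on random tensors; the shapes are assumptions:

import torch
from losses import TriLoss, pdist

torch.manual_seed(0)
logits = torch.randn(32, 10)                  # sampling space, e.g. teacher logits
labels = torch.randint(0, 10, (32,))
features = [torch.randn(32, 64), torch.randn(32, 128, 8, 8)]  # student feature maps

criterion = TriLoss(p=2, margin=0.2)
loss = criterion(features, logits, labels)                 # pull same class, push different classes
adv = criterion(features, logits, labels, negative=True)   # reversed positive/negative roles
print(loss.item(), adv.item())

# pdist gives the pairwise Euclidean distance matrix used inside the samplers.
d = pdist(torch.randn(4, 16))                 # (4, 4) with zero diagonal
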
/misc/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/RGAL/e94cf7b19bff1c1a517592a9d9bcaf521c768e43/misc/framework.png
--------------------------------------------------------------------------------
/train_scratch.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import time
5 | import warnings
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.parallel
10 | import torch.backends.cudnn as cudnn
11 | import torch.distributed as dist
12 | import torch.optim
13 | import torch.multiprocessing as mp
14 | import torch.utils.data
15 | import torch.utils.data.distributed
16 |
17 | import registry
18 | import datafree
19 |
20 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
21 | # Basic Settings
22 | parser.add_argument('--data_root', default='data')
23 | parser.add_argument('--model', default='resnet34_imagenet')
24 | parser.add_argument('--dataset', default='cifar10')
25 | parser.add_argument('--epochs', default=100, type=int, metavar='N',
26 | help='number of total epochs to run')
27 | parser.add_argument('-b', '--batch-size', default=128, type=int,
28 | metavar='N',
29 | help='mini-batch size (default: 128), this is the total '
30 | 'batch size of all GPUs on the current node when '
31 | 'using Data Parallel or Distributed Data Parallel')
32 | parser.add_argument('--warm_up_epoches', default=10, type=int,
33 | metavar='WPI', help='number of warm-up epochs')
34 | parser.add_argument('--warm_up_lr', default=0.01, type=float,
35 | metavar='WLR', help='warm-up learning rate')
36 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
37 | metavar='LR', help='initial learning rate', dest='lr')
38 | parser.add_argument('--lr_decay_milestones', default="50,75", type=str,
39 | help='milestones for learning rate decay')
40 | parser.add_argument('--evaluate_only', action='store_true',
41 | help='evaluate model on validation set')
42 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
43 | help='path to latest checkpoint (default: none)')
44 | parser.add_argument('--gpu', default=0, type=int,
45 | help='GPU id to use.')
46 |
47 | # Device & FP16
48 | parser.add_argument('--fp16', action='store_true',
49 | help='use fp16')
50 | parser.add_argument('--world-size', default=-1, type=int,
51 | help='number of nodes for distributed training')
52 | parser.add_argument('--rank', default=-1, type=int,
53 | help='node rank for distributed training')
54 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
55 | help='url used to set up distributed training')
56 | parser.add_argument('--dist-backend', default='nccl', type=str,
57 | help='distributed backend')
58 | parser.add_argument('--multiprocessing-distributed', action='store_true',
59 | help='Use multi-processing distributed training to launch '
60 | 'N processes per node, which has N GPUs. This is the '
61 | 'fastest way to use PyTorch for either single node or '
62 | 'multi node data parallel training')
63 |
64 | # Misc
65 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
66 | help='number of data loading workers (default: 4)')
67 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
68 | help='manual epoch number (useful on restarts)')
69 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
70 | help='momentum')
71 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
72 | help='use pre-trained model')
73 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
74 | metavar='W', help='weight decay (default: 1e-4)',
75 | dest='weight_decay')
76 | parser.add_argument('-p', '--print-freq', default=0, type=int,
77 | metavar='N', help='print frequency (default: 0)')
78 | parser.add_argument('--seed', default=None, type=int,
79 | help='seed for initializing training.')
80 |
81 | best_acc1 = 0
82 |
83 |
84 | def main():
85 | args = parser.parse_args()
86 | if args.seed is not None:
87 | random.seed(args.seed)
88 | torch.manual_seed(args.seed)
89 | cudnn.deterministic = True
90 | warnings.warn('You have chosen to seed training. '
91 | 'This will turn on the CUDNN deterministic setting, '
92 | 'which can slow down your training considerably! '
93 | 'You may see unexpected behavior when restarting '
94 | 'from checkpoints.')
95 | if args.gpu is not None:
96 | warnings.warn('You have chosen a specific GPU. This will completely '
97 | 'disable data parallelism.')
98 | if args.dist_url == "env://" and args.world_size == -1:
99 | args.world_size = int(os.environ["WORLD_SIZE"])
100 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
101 | args.ngpus_per_node = ngpus_per_node = torch.cuda.device_count()
102 | if args.multiprocessing_distributed:
103 | # Since we have ngpus_per_node processes per node, the total world_size
104 | # needs to be adjusted accordingly
105 | args.world_size = ngpus_per_node * args.world_size
106 | # Use torch.multiprocessing.spawn to launch distributed processes: the
107 | # main_worker process function
108 | mp.spawn(main_worker, nprocs=ngpus_per_node,
109 | args=(ngpus_per_node, args))
110 | else:
111 | # Simply call main_worker function
112 | main_worker(args.gpu, ngpus_per_node, args)
113 |
114 |
115 | def main_worker(gpu, ngpus_per_node, args):
116 | global best_acc1
117 | args.gpu = gpu
118 |
119 | ############################################
120 | # GPU and FP16
121 | ############################################
122 | if args.gpu is not None:
123 | print("Use GPU: {} for training".format(args.gpu))
124 | if args.distributed:
125 | if args.dist_url == "env://" and args.rank == -1:
126 | args.rank = int(os.environ["RANK"])
127 | if args.multiprocessing_distributed:
128 | # For multiprocessing distributed training, rank needs to be the
129 | # global rank among all the processes
130 | args.rank = args.rank * ngpus_per_node + gpu
131 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
132 | world_size=args.world_size, rank=args.rank)
133 | if args.fp16:
134 | from torch.cuda.amp import autocast, GradScaler
135 | args.scaler = GradScaler() if args.fp16 else None
136 | args.autocast = autocast
137 | else:
138 | args.autocast = datafree.utils.dummy_ctx
139 |
140 | ############################################
141 | # Logger
142 | ############################################
143 | log_name = 'R%d-%s-%s' % (args.rank, args.dataset,
144 | args.model) if args.multiprocessing_distributed else '%s-%s' % (args.dataset, args.model)
145 | args.logger = datafree.utils.logger.get_logger(
146 | log_name, output='checkpoints/scratch/log-%s-%s.txt' % (args.dataset, args.model))
147 | if args.rank <= 0:
148 | # print args
149 | for k, v in datafree.utils.flatten_dict(vars(args)).items():
150 | args.logger.info("%s: %s" % (k, v))
151 |
152 | ############################################
153 | # Setup dataset
154 | ############################################
155 | num_classes, train_dataset, val_dataset = registry.get_dataset(
156 | name=args.dataset, data_root=args.data_root)
157 | cudnn.benchmark = True
158 | if args.distributed:
159 | train_sampler = torch.utils.data.distributed.DistributedSampler(
160 | train_dataset)
161 | else:
162 | train_sampler = None
163 | train_loader = torch.utils.data.DataLoader(
164 | train_dataset, batch_size=args.batch_size, shuffle=(
165 | train_sampler is None),
166 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
167 | val_loader = torch.utils.data.DataLoader(
168 | val_dataset,
169 | batch_size=args.batch_size, shuffle=False,
170 | num_workers=args.workers, pin_memory=True)
171 | evaluator = datafree.evaluators.classification_evaluator(val_loader)
172 | args.current_epoch = 0
173 |
174 | ############################################
175 | # Setup model
176 | ############################################
177 | model = registry.get_model(
178 | args.model, num_classes=num_classes, pretrained=args.pretrained)
179 | if not torch.cuda.is_available():
180 | print('using CPU, this will be slow')
181 | elif args.distributed:
182 | # For multiprocessing distributed, DistributedDataParallel constructor
183 | # should always set the single device scope, otherwise,
184 | # DistributedDataParallel will use all available devices.
185 | if args.gpu is not None:
186 | torch.cuda.set_device(args.gpu)
187 | model.cuda(args.gpu)
188 | # When using a single GPU per process and per
189 | # DistributedDataParallel, we need to divide the batch size
190 | # ourselves based on the total number of GPUs we have
191 | args.batch_size = int(args.batch_size / ngpus_per_node)
192 | args.workers = int(
193 | (args.workers + ngpus_per_node - 1) / ngpus_per_node)
194 | model = torch.nn.parallel.DistributedDataParallel(
195 | model, device_ids=[args.gpu])
196 | else:
197 | model.cuda()
198 | # DistributedDataParallel will divide and allocate batch_size to all
199 | # available GPUs if device_ids are not set
200 | model = torch.nn.parallel.DistributedDataParallel(model)
201 | elif args.gpu is not None:
202 | torch.cuda.set_device(args.gpu)
203 | model = model.cuda(args.gpu)
204 | else:
205 | # DataParallel will divide and allocate batch_size to all available GPUs
206 | model = torch.nn.DataParallel(model).cuda()
207 |
208 | ############################################
209 | # Setup optimizer
210 | ############################################
211 | criterion = nn.CrossEntropyLoss().cuda(args.gpu)
212 | optimizer = torch.optim.SGD(model.parameters(
213 | ), args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
214 | milestones = [int(ms) for ms in args.lr_decay_milestones.split(',')]
215 | # Linear warm-up for the first warm_up_epoches epochs, then a 10x decay at each milestone
216 | def warm_up_with_multistep_lr(epoch):
217 |     return ((epoch + 1) / args.warm_up_epoches if epoch < args.warm_up_epoches
218 |             else 0.1 ** len([m for m in milestones if m <= epoch]))
219 |
220 | scheduler = torch.optim.lr_scheduler.LambdaLR(
221 |     optimizer, lr_lambda=warm_up_with_multistep_lr)
222 |
223 | ############################################
224 | # Resume
225 | ############################################
226 | if args.resume:
227 | if os.path.isfile(args.resume):
228 | print("=> loading checkpoint '{}'".format(args.resume))
229 | if args.gpu is None:
230 | checkpoint = torch.load(args.resume)
231 | else:
232 | # Map model to be loaded to specified single gpu.
233 | loc = 'cuda:{}'.format(args.gpu)
234 | checkpoint = torch.load(args.resume, map_location=loc)
235 |
236 | if isinstance(model, nn.Module):
237 | model.load_state_dict(checkpoint['state_dict'])
238 | else:
239 | model.module.load_state_dict(checkpoint['state_dict'])
240 |
241 | try:
242 | if 'best_acc1' in checkpoint:
243 | best_acc1 = checkpoint['best_acc1']
244 | args.start_epoch = checkpoint['epoch']
245 | optimizer.load_state_dict(checkpoint['optimizer'])
246 | scheduler.load_state_dict(checkpoint['scheduler'])
247 | except Exception:
248 | print("Failed to load additional information")
249 | print("[!] loaded checkpoint '{}' (epoch {} acc {})"
250 | .format(args.resume, checkpoint['epoch'], best_acc1))
251 | else:
252 | print("[!] no checkpoint found at '{}'".format(args.resume))
253 |
254 | ############################################
255 | # Evaluate
256 | ############################################
257 | if args.evaluate_only:
258 | model.eval()
259 | eval_results = evaluator(model, device=args.gpu)
260 | (acc1, acc5), val_loss = eval_results['Acc'], eval_results['Loss']
261 | print('[Eval] Acc@1={acc1:.4f} Acc@5={acc5:.4f} Loss={loss:.4f}'.format(
262 | acc1=acc1, acc5=acc5, loss=val_loss))
263 | return
264 |
265 | ############################################
266 | # Train Loop
267 | ############################################
268 | for epoch in range(args.start_epoch, args.epochs):
269 | if args.distributed:
270 | train_sampler.set_epoch(epoch)
271 | args.current_epoch = epoch
272 | train(train_loader, model, criterion, optimizer, args)
273 | model.eval()
274 | eval_results = evaluator(model, device=args.gpu)
275 | (acc1, acc5), val_loss = eval_results['Acc'], eval_results['Loss']
276 | args.logger.info('[Eval] Epoch={current_epoch} Acc@1={acc1:.4f} Acc@5={acc5:.4f} Loss={loss:.4f} Lr={lr:.4f}'
277 | .format(current_epoch=args.current_epoch, acc1=acc1, acc5=acc5, loss=val_loss, lr=optimizer.param_groups[0]['lr']))
278 | scheduler.step()
279 | is_best = acc1 > best_acc1
280 | best_acc1 = max(acc1, best_acc1)
281 | _best_ckpt = 'checkpoints/scratch/%s_%s.pth' % (
282 | args.dataset, args.model)
283 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
284 | and args.rank % ngpus_per_node == 0):
285 | save_checkpoint({
286 | 'epoch': epoch + 1,
287 | 'arch': args.model,
288 | 'state_dict': model.state_dict(),
289 | 'best_acc1': float(best_acc1),
290 | 'optimizer': optimizer.state_dict(),
291 | 'scheduler': scheduler.state_dict()
292 | }, is_best, _best_ckpt)
293 | if args.rank <= 0:
294 | args.logger.info("Best: %.4f" % best_acc1)
295 |
296 |
297 | def train(train_loader, model, criterion, optimizer, args):
298 | global best_acc1
299 | loss_metric = datafree.metrics.RunningLoss(
300 | nn.CrossEntropyLoss(reduction='sum'))
301 | acc_metric = datafree.metrics.TopkAccuracy(topk=(1, 5))
302 | model.train()
303 | for i, (images, target) in enumerate(train_loader):
304 | if args.gpu is not None:
305 | images = images.cuda(args.gpu, non_blocking=True)
306 | if torch.cuda.is_available():
307 | target = target.cuda(args.gpu, non_blocking=True)
308 | with args.autocast(enabled=args.fp16):
309 | output = model(images)
310 | loss = criterion(output, target)
311 | # measure accuracy and record loss
312 | acc_metric.update(output, target)
313 | loss_metric.update(output, target)
314 | optimizer.zero_grad()
315 | if args.fp16:
316 | scaler = args.scaler
317 | scaler.scale(loss).backward()
318 | scaler.step(optimizer)
319 | scaler.update()
320 | else:
321 | loss.backward()
322 | optimizer.step()
323 | if args.print_freq > 0 and i % args.print_freq == 0:
324 | (train_acc1, train_acc5), train_loss = acc_metric.get_results(
325 | ), loss_metric.get_results()
326 | args.logger.info('[Train] Epoch={current_epoch} Iter={i}/{total_iters}, train_acc@1={train_acc1:.4f}, train_acc@5={train_acc5:.4f}, train_Loss={train_loss:.4f}, Lr={lr:.4f}'
327 | .format(current_epoch=args.current_epoch, i=i, total_iters=len(train_loader), train_acc1=train_acc1, train_acc5=train_acc5, train_loss=train_loss, lr=optimizer.param_groups[0]['lr']))
328 | loss_metric.reset(), acc_metric.reset()
329 |
330 |
331 | def save_checkpoint(state, is_best, filename='checkpoint.pth'):
332 | if is_best:
333 | torch.save(state, filename)
334 |
335 |
336 | if __name__ == '__main__':
337 | main()
338 |
--------------------------------------------------------------------------------
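The warm-up-then-multistep schedule configured in train_scratch.py can be sanity-checked in isolation; a minimal sketch using the parser defaults above (10 warm-up epochs, milestones 50 and 75, base lr 0.1):

import torch

warm_up_epochs, milestones, base_lr = 10, [50, 75], 0.1

def warm_up_with_multistep_lr(epoch):
    # Linear warm-up, then a 10x decay at each passed milestone.
    if epoch < warm_up_epochs:
        return (epoch + 1) / warm_up_epochs
    return 0.1 ** len([m for m in milestones if m <= epoch])

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=base_lr)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_up_with_multistep_lr)

for epoch in range(100):
    lr = optimizer.param_groups[0]['lr']  # 0.01 -> 0.1 over warm-up, 0.01 after 50, 0.001 after 75
    optimizer.step()
    scheduler.step()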