├── images └── Approach_Diagram.png ├── models ├── __init__.py ├── lenet.py ├── vgg.py ├── mobilenet.py ├── alexnet_half_wo_BN.py ├── alexnet.py ├── alexnet_half.py ├── mobilenetv2.py ├── googlenet.py ├── resnext.py ├── dpn.py ├── densenet.py ├── shufflenet.py ├── senet.py ├── preact_resnet.py ├── dla_simple.py └── pnasnet.py ├── code ├── train_student │ ├── models │ │ ├── __init__.py │ │ ├── lenet.py │ │ ├── vgg.py │ │ ├── mobilenet.py │ │ ├── alexnet_half_wo_BN.py │ │ ├── alexnet.py │ │ ├── alexnet_half.py │ │ ├── mobilenetv2.py │ │ ├── googlenet.py │ │ ├── resnext.py │ │ ├── dpn.py │ │ ├── densenet.py │ │ ├── shufflenet.py │ │ ├── senet.py │ │ ├── preact_resnet.py │ │ ├── dla_simple.py │ │ ├── pnasnet.py │ │ └── resnet.py │ ├── utils.py │ └── dcgan_model.py └── train_generator │ ├── models │ ├── __init__.py │ ├── lenet.py │ ├── vgg.py │ ├── mobilenet.py │ ├── alexnet_half_wo_BN.py │ ├── alexnet.py │ ├── alexnet_half.py │ ├── mobilenetv2.py │ ├── googlenet.py │ ├── resnext.py │ ├── dpn.py │ ├── densenet.py │ ├── shufflenet.py │ ├── senet.py │ ├── preact_resnet.py │ ├── dla_simple.py │ ├── pnasnet.py │ └── resnet.py │ └── dcgan_model.py ├── README.md ├── generate_synthetic_data.py ├── run_cifar40_classes_resnet.sh ├── run_cifar40_classes_alexnet.sh ├── run_cifar10_rand_class_resnet.sh ├── utils.py ├── dcgan_model.py ├── run_cifar10_rand_class_alexnet.sh ├── run_synthetic_alexnet.sh └── run_synthetic_resnet.sh /images/Approach_Diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/Hard-Label-Model-Stealing/HEAD/images/Approach_Diagram.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import * 2 | from .dpn import * 3 | from .lenet import * 4 | from .senet import * 5 | from .pnasnet import * 6 | from .densenet import * 7 | from 
.googlenet import * 8 | from .shufflenet import * 9 | from .shufflenetv2 import * 10 | from .resnet import * 11 | from .resnet_wo_bn import * 12 | from .resnext import * 13 | from .preact_resnet import * 14 | from .mobilenet import * 15 | from .mobilenetv2 import * 16 | from .efficientnet import * 17 | from .regnet import * 18 | from .dla_simple import * 19 | from .dla import * 20 | from .alexnet import * 21 | from .alexnet_half import * 22 | -------------------------------------------------------------------------------- /code/train_student/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import * 2 | from .dpn import * 3 | from .lenet import * 4 | from .senet import * 5 | from .pnasnet import * 6 | from .densenet import * 7 | from .googlenet import * 8 | from .shufflenet import * 9 | from .shufflenetv2 import * 10 | from .resnet import * 11 | from .resnet_wo_bn import * 12 | from .resnext import * 13 | from .preact_resnet import * 14 | from .mobilenet import * 15 | from .mobilenetv2 import * 16 | from .efficientnet import * 17 | from .regnet import * 18 | from .dla_simple import * 19 | from .dla import * 20 | from .alexnet import * 21 | from .alexnet_half import * 22 | from .alexnet_half_wo_BN import * -------------------------------------------------------------------------------- /code/train_generator/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import * 2 | from .dpn import * 3 | from .lenet import * 4 | from .senet import * 5 | from .pnasnet import * 6 | from .densenet import * 7 | from .googlenet import * 8 | from .shufflenet import * 9 | from .shufflenetv2 import * 10 | from .resnet import * 11 | from .resnet_wo_bn import * 12 | from .resnext import * 13 | from .preact_resnet import * 14 | from .mobilenet import * 15 | from .mobilenetv2 import * 16 | from .efficientnet import * 17 | from .regnet import * 18 | from 
.dla_simple import * 19 | from .dla import * 20 | from .alexnet import * 21 | from .alexnet_half import * 22 | from .alexnet_half_wo_BN import * -------------------------------------------------------------------------------- /models/lenet.py: -------------------------------------------------------------------------------- 1 | '''LeNet in PyTorch.''' 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class LeNet(nn.Module): 6 | def __init__(self): 7 | super(LeNet, self).__init__() 8 | self.conv1 = nn.Conv2d(3, 6, 5) 9 | self.conv2 = nn.Conv2d(6, 16, 5) 10 | self.fc1 = nn.Linear(16*5*5, 120) 11 | self.fc2 = nn.Linear(120, 84) 12 | self.fc3 = nn.Linear(84, 10) 13 | 14 | def forward(self, x): 15 | out = F.relu(self.conv1(x)) 16 | out = F.max_pool2d(out, 2) 17 | out = F.relu(self.conv2(out)) 18 | out = F.max_pool2d(out, 2) 19 | out = out.view(out.size(0), -1) 20 | out = F.relu(self.fc1(out)) 21 | out = F.relu(self.fc2(out)) 22 | out = self.fc3(out) 23 | return out 24 | -------------------------------------------------------------------------------- /code/train_generator/models/lenet.py: -------------------------------------------------------------------------------- 1 | '''LeNet in PyTorch.''' 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class LeNet(nn.Module): 6 | def __init__(self): 7 | super(LeNet, self).__init__() 8 | self.conv1 = nn.Conv2d(3, 6, 5) 9 | self.conv2 = nn.Conv2d(6, 16, 5) 10 | self.fc1 = nn.Linear(16*5*5, 120) 11 | self.fc2 = nn.Linear(120, 84) 12 | self.fc3 = nn.Linear(84, 10) 13 | 14 | def forward(self, x): 15 | out = F.relu(self.conv1(x)) 16 | out = F.max_pool2d(out, 2) 17 | out = F.relu(self.conv2(out)) 18 | out = F.max_pool2d(out, 2) 19 | out = out.view(out.size(0), -1) 20 | out = F.relu(self.fc1(out)) 21 | out = F.relu(self.fc2(out)) 22 | out = self.fc3(out) 23 | return out 24 | -------------------------------------------------------------------------------- /code/train_student/models/lenet.py: 
-------------------------------------------------------------------------------- 1 | '''LeNet in PyTorch.''' 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class LeNet(nn.Module): 6 | def __init__(self): 7 | super(LeNet, self).__init__() 8 | self.conv1 = nn.Conv2d(3, 6, 5) 9 | self.conv2 = nn.Conv2d(6, 16, 5) 10 | self.fc1 = nn.Linear(16*5*5, 120) 11 | self.fc2 = nn.Linear(120, 84) 12 | self.fc3 = nn.Linear(84, 10) 13 | 14 | def forward(self, x): 15 | out = F.relu(self.conv1(x)) 16 | out = F.max_pool2d(out, 2) 17 | out = F.relu(self.conv2(out)) 18 | out = F.max_pool2d(out, 2) 19 | out = out.view(out.size(0), -1) 20 | out = F.relu(self.fc1(out)) 21 | out = F.relu(self.fc2(out)) 22 | out = self.fc3(out) 23 | return out 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Towards Data-Free Model Stealing in a Hard Label Setting 2 | ## CVPR 2022 3 | Sunandini Sanyal, Sravanti Addepalli, R. Venkatesh Babu 4 | 5 | Video Analytics Lab, Indian Institute of Science, Bengaluru 6 | 7 | 8 | ### [[Project Page]](https://sites.google.com/view/dfms-hl) [[Paper]]() 9 | 10 | 11 | ## Approach 12 | 13 | ![Approach_Diagram](https://user-images.githubusercontent.com/19433656/160283244-183fa0f6-a00b-45ed-925e-9d3ae33ec605.png) 14 | 15 | ## Setup the requirements 16 | 17 | The following versions of Pytorch and Tensorflow are needed to run the code. 18 | 19 | Pytorch 1.9.1 20 | 21 | Tensorflow 2.6.0 22 | 23 | ## Run the Model Stealing Attack 24 | The folder contains the code and the script files to run the code with different settings of proxy data. 
Command to run 10 random classes of CIFAR-100 with AlexNet as victim model and AlexNet-half as clone model: 25 | ``` 26 | ./run_cifar10_rand_class_alexnet.sh 27 | ``` 28 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG11/13/16/19 in Pytorch.''' 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | cfg = { 7 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 8 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 9 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 10 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 11 | } 12 | 13 | 14 | class VGG(nn.Module): 15 | def __init__(self, vgg_name): 16 | super(VGG, self).__init__() 17 | self.features = self._make_layers(cfg[vgg_name]) 18 | self.classifier = nn.Linear(512, 10) 19 | 20 | def forward(self, x): 21 | out = self.features(x) 22 | out = out.view(out.size(0), -1) 23 | out = self.classifier(out) 24 | return out 25 | 26 | def _make_layers(self, cfg): 27 | layers = [] 28 | in_channels = 3 29 | for x in cfg: 30 | if x == 'M': 31 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 32 | else: 33 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 34 | nn.BatchNorm2d(x), 35 | nn.ReLU(inplace=True)] 36 | in_channels = x 37 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 38 | return nn.Sequential(*layers) 39 | 40 | 41 | def test(): 42 | net = VGG('VGG11') 43 | x = torch.randn(2,3,32,32) 44 | y = net(x) 45 | print(y.size()) 46 | 47 | # test() 48 | -------------------------------------------------------------------------------- /code/train_generator/models/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG11/13/16/19 in Pytorch.''' 2 | import 
torch 3 | import torch.nn as nn 4 | 5 | 6 | cfg = { 7 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 8 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 9 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 10 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 11 | } 12 | 13 | 14 | class VGG(nn.Module): 15 | def __init__(self, vgg_name): 16 | super(VGG, self).__init__() 17 | self.features = self._make_layers(cfg[vgg_name]) 18 | self.classifier = nn.Linear(512, 10) 19 | 20 | def forward(self, x): 21 | out = self.features(x) 22 | out = out.view(out.size(0), -1) 23 | out = self.classifier(out) 24 | return out 25 | 26 | def _make_layers(self, cfg): 27 | layers = [] 28 | in_channels = 3 29 | for x in cfg: 30 | if x == 'M': 31 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 32 | else: 33 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 34 | nn.BatchNorm2d(x), 35 | nn.ReLU(inplace=True)] 36 | in_channels = x 37 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 38 | return nn.Sequential(*layers) 39 | 40 | 41 | def test(): 42 | net = VGG('VGG11') 43 | x = torch.randn(2,3,32,32) 44 | y = net(x) 45 | print(y.size()) 46 | 47 | # test() 48 | -------------------------------------------------------------------------------- /code/train_student/models/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG11/13/16/19 in Pytorch.''' 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | cfg = { 7 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 8 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 9 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 10 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 
512, 'M', 512, 512, 512, 512, 'M'], 11 | } 12 | 13 | 14 | class VGG(nn.Module): 15 | def __init__(self, vgg_name): 16 | super(VGG, self).__init__() 17 | self.features = self._make_layers(cfg[vgg_name]) 18 | self.classifier = nn.Linear(512, 10) 19 | 20 | def forward(self, x): 21 | out = self.features(x) 22 | out = out.view(out.size(0), -1) 23 | out = self.classifier(out) 24 | return out 25 | 26 | def _make_layers(self, cfg): 27 | layers = [] 28 | in_channels = 3 29 | for x in cfg: 30 | if x == 'M': 31 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 32 | else: 33 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 34 | nn.BatchNorm2d(x), 35 | nn.ReLU(inplace=True)] 36 | in_channels = x 37 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 38 | return nn.Sequential(*layers) 39 | 40 | 41 | def test(): 42 | net = VGG('VGG11') 43 | x = torch.randn(2,3,32,32) 44 | y = net(x) 45 | print(y.size()) 46 | 47 | # test() 48 | -------------------------------------------------------------------------------- /generate_synthetic_data.py: -------------------------------------------------------------------------------- 1 | import skimage.draw 2 | from PIL import Image 3 | import numpy as np 4 | import random 5 | import os 6 | import cv2 7 | 8 | def _generate_random_colors(num_colors, num_channels, intensity_range): 9 | if num_channels == 1: 10 | intensity_range = (intensity_range, ) 11 | elif len(intensity_range) == 1: 12 | intensity_range = intensity_range * num_channels 13 | colors = [np.random.randint(r[0], r[1] + 1, size=num_colors) 14 | for r in intensity_range] 15 | return np.transpose(colors) 16 | 17 | min_size = 5 #5 18 | max_size = 10 #6 19 | #arr = [130000, 150000, 220000] 20 | arr = [50] 21 | count = 0 22 | #arr = [200000, 300000, 500000] 23 | #for num_shapes in [200, 250, 300]: 24 | for num_shapes in [50]: 25 | len_ = arr[count] 26 | count+=1 27 | for image_num in range(len_): 28 | #image, labels = skimage.draw.random_shapes((32, 32), 
max_shapes=num_shapes, min_size=12, max_size=22) 29 | image, labels = skimage.draw.random_shapes((100, 100), max_shapes=num_shapes, min_shapes=num_shapes, min_size=min_size, max_size=max_size, allow_overlap=True) 30 | num = _generate_random_colors(1, 3, ((0, 254),)) 31 | #print(num[0]) 32 | for i in range(image.shape[0]): 33 | for j in range(image.shape[1]): 34 | for k in range(3): 35 | if image[i,j,k]==255: 36 | image[i,j,k] = num[0][k] 37 | image = cv2.blur(image, (4,4)) 38 | image = cv2.resize(image, (32,32), interpolation = cv2.INTER_NEAREST) 39 | im = Image.fromarray(image) 40 | im.save("./synthetic_data/50k_samples/file_name_" + str(num_shapes) + "_" + str(image_num)+".png") 41 | 42 | 43 | -------------------------------------------------------------------------------- /models/mobilenet.py: -------------------------------------------------------------------------------- 1 | '''MobileNet in PyTorch. 2 | 3 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" 4 | for more details. 
5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class Block(nn.Module): 12 | '''Depthwise conv + Pointwise conv''' 13 | def __init__(self, in_planes, out_planes, stride=1): 14 | super(Block, self).__init__() 15 | self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False) 16 | self.bn1 = nn.BatchNorm2d(in_planes) 17 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 18 | self.bn2 = nn.BatchNorm2d(out_planes) 19 | 20 | def forward(self, x): 21 | out = F.relu(self.bn1(self.conv1(x))) 22 | out = F.relu(self.bn2(self.conv2(out))) 23 | return out 24 | 25 | 26 | class MobileNet(nn.Module): 27 | # (128,2) means conv planes=128, conv stride=2, by default conv stride=1 28 | cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024] 29 | 30 | def __init__(self, num_classes=10): 31 | super(MobileNet, self).__init__() 32 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 33 | self.bn1 = nn.BatchNorm2d(32) 34 | self.layers = self._make_layers(in_planes=32) 35 | self.linear = nn.Linear(1024, num_classes) 36 | 37 | def _make_layers(self, in_planes): 38 | layers = [] 39 | for x in self.cfg: 40 | out_planes = x if isinstance(x, int) else x[0] 41 | stride = 1 if isinstance(x, int) else x[1] 42 | layers.append(Block(in_planes, out_planes, stride)) 43 | in_planes = out_planes 44 | return nn.Sequential(*layers) 45 | 46 | def forward(self, x): 47 | out = F.relu(self.bn1(self.conv1(x))) 48 | out = self.layers(out) 49 | out = F.avg_pool2d(out, 2) 50 | out = out.view(out.size(0), -1) 51 | out = self.linear(out) 52 | return out 53 | 54 | 55 | def test(): 56 | net = MobileNet() 57 | x = torch.randn(1,3,32,32) 58 | y = net(x) 59 | print(y.size()) 60 | 61 | # test() 62 | -------------------------------------------------------------------------------- 
/code/train_generator/models/mobilenet.py: -------------------------------------------------------------------------------- 1 | '''MobileNet in PyTorch. 2 | 3 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" 4 | for more details. 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class Block(nn.Module): 12 | '''Depthwise conv + Pointwise conv''' 13 | def __init__(self, in_planes, out_planes, stride=1): 14 | super(Block, self).__init__() 15 | self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False) 16 | self.bn1 = nn.BatchNorm2d(in_planes) 17 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 18 | self.bn2 = nn.BatchNorm2d(out_planes) 19 | 20 | def forward(self, x): 21 | out = F.relu(self.bn1(self.conv1(x))) 22 | out = F.relu(self.bn2(self.conv2(out))) 23 | return out 24 | 25 | 26 | class MobileNet(nn.Module): 27 | # (128,2) means conv planes=128, conv stride=2, by default conv stride=1 28 | cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024] 29 | 30 | def __init__(self, num_classes=10): 31 | super(MobileNet, self).__init__() 32 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 33 | self.bn1 = nn.BatchNorm2d(32) 34 | self.layers = self._make_layers(in_planes=32) 35 | self.linear = nn.Linear(1024, num_classes) 36 | 37 | def _make_layers(self, in_planes): 38 | layers = [] 39 | for x in self.cfg: 40 | out_planes = x if isinstance(x, int) else x[0] 41 | stride = 1 if isinstance(x, int) else x[1] 42 | layers.append(Block(in_planes, out_planes, stride)) 43 | in_planes = out_planes 44 | return nn.Sequential(*layers) 45 | 46 | def forward(self, x): 47 | out = F.relu(self.bn1(self.conv1(x))) 48 | out = self.layers(out) 49 | out = F.avg_pool2d(out, 2) 50 | out = out.view(out.size(0), -1) 51 | out = 
self.linear(out) 52 | return out 53 | 54 | 55 | def test(): 56 | net = MobileNet() 57 | x = torch.randn(1,3,32,32) 58 | y = net(x) 59 | print(y.size()) 60 | 61 | # test() 62 | -------------------------------------------------------------------------------- /code/train_student/models/mobilenet.py: -------------------------------------------------------------------------------- 1 | '''MobileNet in PyTorch. 2 | 3 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" 4 | for more details. 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class Block(nn.Module): 12 | '''Depthwise conv + Pointwise conv''' 13 | def __init__(self, in_planes, out_planes, stride=1): 14 | super(Block, self).__init__() 15 | self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False) 16 | self.bn1 = nn.BatchNorm2d(in_planes) 17 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 18 | self.bn2 = nn.BatchNorm2d(out_planes) 19 | 20 | def forward(self, x): 21 | out = F.relu(self.bn1(self.conv1(x))) 22 | out = F.relu(self.bn2(self.conv2(out))) 23 | return out 24 | 25 | 26 | class MobileNet(nn.Module): 27 | # (128,2) means conv planes=128, conv stride=2, by default conv stride=1 28 | cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024] 29 | 30 | def __init__(self, num_classes=10): 31 | super(MobileNet, self).__init__() 32 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 33 | self.bn1 = nn.BatchNorm2d(32) 34 | self.layers = self._make_layers(in_planes=32) 35 | self.linear = nn.Linear(1024, num_classes) 36 | 37 | def _make_layers(self, in_planes): 38 | layers = [] 39 | for x in self.cfg: 40 | out_planes = x if isinstance(x, int) else x[0] 41 | stride = 1 if isinstance(x, int) else x[1] 42 | layers.append(Block(in_planes, out_planes, stride)) 
43 | in_planes = out_planes 44 | return nn.Sequential(*layers) 45 | 46 | def forward(self, x): 47 | out = F.relu(self.bn1(self.conv1(x))) 48 | out = self.layers(out) 49 | out = F.avg_pool2d(out, 2) 50 | out = out.view(out.size(0), -1) 51 | out = self.linear(out) 52 | return out 53 | 54 | 55 | def test(): 56 | net = MobileNet() 57 | x = torch.randn(1,3,32,32) 58 | y = net(x) 59 | print(y.size()) 60 | 61 | # test() 62 | -------------------------------------------------------------------------------- /models/alexnet_half_wo_BN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class AlexNet_half(nn.Module): 5 | def __init__(self, num_classes=10): 6 | super(AlexNet_half, self).__init__() 7 | 8 | self.conv1 = nn.Conv2d(3, 24, 5, stride=1, padding=2) 9 | self.conv1.bias.data.normal_(0, 0.01) 10 | self.conv1.bias.data.fill_(0) 11 | 12 | self.relu = nn.ReLU() 13 | self.lrn = nn.LocalResponseNorm(2) 14 | self.pad = nn.MaxPool2d(3, stride=2) 15 | 16 | #self.batch_norm1 = nn.BatchNorm2d(24, eps=0.001) 17 | 18 | self.conv2 = nn.Conv2d(24, 64, 5, stride=1, padding=2) 19 | self.conv2.bias.data.normal_(0, 0.01) 20 | self.conv2.bias.data.fill_(1.0) 21 | 22 | #self.batch_norm2 = nn.BatchNorm2d(64, eps=0.001) 23 | 24 | self.conv3 = nn.Conv2d(64, 96, 3, stride=1, padding=1) 25 | self.conv3.bias.data.normal_(0, 0.01) 26 | self.conv3.bias.data.fill_(0) 27 | 28 | #self.batch_norm3 = nn.BatchNorm2d(96, eps=0.001) 29 | 30 | self.conv4 = nn.Conv2d(96, 96, 3, stride=1, padding=1) 31 | self.conv4.bias.data.normal_(0, 0.01) 32 | self.conv4.bias.data.fill_(1.0) 33 | 34 | #self.batch_norm4 = nn.BatchNorm2d(96, eps=0.001) 35 | 36 | self.conv5 = nn.Conv2d(96, 64, 3, stride=1, padding=1) 37 | self.conv5.bias.data.normal_(0, 0.01) 38 | self.conv5.bias.data.fill_(1.0) 39 | 40 | #self.batch_norm5 = nn.BatchNorm2d(64, eps=0.001) 41 | 42 | self.fc1 = nn.Linear(576,256) 43 | self.fc1.bias.data.normal_(0, 0.01) 44 | 
self.fc1.bias.data.fill_(0) 45 | 46 | self.drop = nn.Dropout(p=0.5) 47 | 48 | #self.batch_norm6 = nn.BatchNorm1d(256, eps=0.001) 49 | 50 | self.fc2 = nn.Linear(256,128) 51 | self.fc2.bias.data.normal_(0, 0.01) 52 | self.fc2.bias.data.fill_(0) 53 | 54 | #self.batch_norm7 = nn.BatchNorm1d(128, eps=0.001) 55 | 56 | self.fc3 = nn.Linear(128,10) 57 | self.fc3.bias.data.normal_(0, 0.01) 58 | self.fc3.bias.data.fill_(0) 59 | 60 | self.soft = nn.Softmax() 61 | 62 | def forward(self, x): 63 | layer1 = self.pad(self.lrn(self.relu(self.conv1(x)))) 64 | layer2 = self.pad(self.lrn(self.relu(self.conv2(layer1)))) 65 | layer3 = self.relu(self.conv3(layer2)) 66 | layer4 = self.relu(self.conv4(layer3)) 67 | layer5 = self.pad(self.relu(self.conv5(layer4))) 68 | flatten = layer5.view(-1, 64*3*3) 69 | fully1 = self.relu(self.fc1(flatten)) 70 | fully1 = self.drop(fully1) 71 | fully2 = self.relu(self.fc2(fully1)) 72 | fully2 = self.drop(fully2) 73 | logits = self.fc3(fully2) 74 | #softmax_val = self.soft(logits) 75 | 76 | return logits 77 | 78 | model = AlexNet_half(10) 79 | -------------------------------------------------------------------------------- /code/train_student/models/alexnet_half_wo_BN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class AlexNet_half_wo_BN(nn.Module): 5 | def __init__(self, num_classes=10): 6 | super(AlexNet_half_wo_BN, self).__init__() 7 | 8 | self.conv1 = nn.Conv2d(3, 24, 5, stride=1, padding=2) 9 | self.conv1.bias.data.normal_(0, 0.01) 10 | self.conv1.bias.data.fill_(0) 11 | 12 | self.relu = nn.ReLU() 13 | self.lrn = nn.LocalResponseNorm(2) 14 | self.pad = nn.MaxPool2d(3, stride=2) 15 | 16 | #self.batch_norm1 = nn.BatchNorm2d(24, eps=0.001) 17 | 18 | self.conv2 = nn.Conv2d(24, 64, 5, stride=1, padding=2) 19 | self.conv2.bias.data.normal_(0, 0.01) 20 | self.conv2.bias.data.fill_(1.0) 21 | 22 | #self.batch_norm2 = nn.BatchNorm2d(64, eps=0.001) 23 | 24 | self.conv3 = 
nn.Conv2d(64, 96, 3, stride=1, padding=1) 25 | self.conv3.bias.data.normal_(0, 0.01) 26 | self.conv3.bias.data.fill_(0) 27 | 28 | #self.batch_norm3 = nn.BatchNorm2d(96, eps=0.001) 29 | 30 | self.conv4 = nn.Conv2d(96, 96, 3, stride=1, padding=1) 31 | self.conv4.bias.data.normal_(0, 0.01) 32 | self.conv4.bias.data.fill_(1.0) 33 | 34 | #self.batch_norm4 = nn.BatchNorm2d(96, eps=0.001) 35 | 36 | self.conv5 = nn.Conv2d(96, 64, 3, stride=1, padding=1) 37 | self.conv5.bias.data.normal_(0, 0.01) 38 | self.conv5.bias.data.fill_(1.0) 39 | 40 | #self.batch_norm5 = nn.BatchNorm2d(64, eps=0.001) 41 | 42 | self.fc1 = nn.Linear(576,256) 43 | self.fc1.bias.data.normal_(0, 0.01) 44 | self.fc1.bias.data.fill_(0) 45 | 46 | self.drop = nn.Dropout(p=0.5) 47 | 48 | #self.batch_norm6 = nn.BatchNorm1d(256, eps=0.001) 49 | 50 | self.fc2 = nn.Linear(256,128) 51 | self.fc2.bias.data.normal_(0, 0.01) 52 | self.fc2.bias.data.fill_(0) 53 | 54 | #self.batch_norm7 = nn.BatchNorm1d(128, eps=0.001) 55 | 56 | self.fc3 = nn.Linear(128,10) 57 | self.fc3.bias.data.normal_(0, 0.01) 58 | self.fc3.bias.data.fill_(0) 59 | 60 | self.soft = nn.Softmax() 61 | 62 | def forward(self, x): 63 | layer1 = self.pad(self.lrn(self.relu(self.conv1(x)))) 64 | layer2 = self.pad(self.lrn(self.relu(self.conv2(layer1)))) 65 | layer3 = self.relu(self.conv3(layer2)) 66 | layer4 = self.relu(self.conv4(layer3)) 67 | layer5 = self.pad(self.relu(self.conv5(layer4))) 68 | flatten = layer5.view(-1, 64*3*3) 69 | fully1 = self.relu(self.fc1(flatten)) 70 | fully1 = self.drop(fully1) 71 | fully2 = self.relu(self.fc2(fully1)) 72 | fully2 = self.drop(fully2) 73 | logits = self.fc3(fully2) 74 | #softmax_val = self.soft(logits) 75 | 76 | return logits 77 | 78 | model = AlexNet_half_wo_BN(10) 79 | -------------------------------------------------------------------------------- /code/train_generator/models/alexnet_half_wo_BN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn as nn 3 | 4 | class AlexNet_half_wo_BN(nn.Module): 5 | def __init__(self, num_classes=10): 6 | super(AlexNet_half_wo_BN, self).__init__() 7 | 8 | self.conv1 = nn.Conv2d(3, 24, 5, stride=1, padding=2) 9 | self.conv1.bias.data.normal_(0, 0.01) 10 | self.conv1.bias.data.fill_(0) 11 | 12 | self.relu = nn.ReLU() 13 | self.lrn = nn.LocalResponseNorm(2) 14 | self.pad = nn.MaxPool2d(3, stride=2) 15 | 16 | #self.batch_norm1 = nn.BatchNorm2d(24, eps=0.001) 17 | 18 | self.conv2 = nn.Conv2d(24, 64, 5, stride=1, padding=2) 19 | self.conv2.bias.data.normal_(0, 0.01) 20 | self.conv2.bias.data.fill_(1.0) 21 | 22 | #self.batch_norm2 = nn.BatchNorm2d(64, eps=0.001) 23 | 24 | self.conv3 = nn.Conv2d(64, 96, 3, stride=1, padding=1) 25 | self.conv3.bias.data.normal_(0, 0.01) 26 | self.conv3.bias.data.fill_(0) 27 | 28 | #self.batch_norm3 = nn.BatchNorm2d(96, eps=0.001) 29 | 30 | self.conv4 = nn.Conv2d(96, 96, 3, stride=1, padding=1) 31 | self.conv4.bias.data.normal_(0, 0.01) 32 | self.conv4.bias.data.fill_(1.0) 33 | 34 | #self.batch_norm4 = nn.BatchNorm2d(96, eps=0.001) 35 | 36 | self.conv5 = nn.Conv2d(96, 64, 3, stride=1, padding=1) 37 | self.conv5.bias.data.normal_(0, 0.01) 38 | self.conv5.bias.data.fill_(1.0) 39 | 40 | #self.batch_norm5 = nn.BatchNorm2d(64, eps=0.001) 41 | 42 | self.fc1 = nn.Linear(576,256) 43 | self.fc1.bias.data.normal_(0, 0.01) 44 | self.fc1.bias.data.fill_(0) 45 | 46 | self.drop = nn.Dropout(p=0.5) 47 | 48 | #self.batch_norm6 = nn.BatchNorm1d(256, eps=0.001) 49 | 50 | self.fc2 = nn.Linear(256,128) 51 | self.fc2.bias.data.normal_(0, 0.01) 52 | self.fc2.bias.data.fill_(0) 53 | 54 | #self.batch_norm7 = nn.BatchNorm1d(128, eps=0.001) 55 | 56 | self.fc3 = nn.Linear(128,10) 57 | self.fc3.bias.data.normal_(0, 0.01) 58 | self.fc3.bias.data.fill_(0) 59 | 60 | self.soft = nn.Softmax() 61 | 62 | def forward(self, x): 63 | layer1 = self.pad(self.lrn(self.relu(self.conv1(x)))) 64 | layer2 = self.pad(self.lrn(self.relu(self.conv2(layer1)))) 65 | layer3 = 
self.relu(self.conv3(layer2)) 66 | layer4 = self.relu(self.conv4(layer3)) 67 | layer5 = self.pad(self.relu(self.conv5(layer4))) 68 | flatten = layer5.view(-1, 64*3*3) 69 | fully1 = self.relu(self.fc1(flatten)) 70 | fully1 = self.drop(fully1) 71 | fully2 = self.relu(self.fc2(fully1)) 72 | fully2 = self.drop(fully2) 73 | logits = self.fc3(fully2) 74 | #softmax_val = self.soft(logits) 75 | 76 | return logits 77 | 78 | model = AlexNet_half_wo_BN(10) 79 | -------------------------------------------------------------------------------- /code/train_generator/models/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class AlexNet(nn.Module): 5 | def __init__(self, num_classes=10): 6 | super(AlexNet, self).__init__() 7 | 8 | self.conv1 = nn.Conv2d(3, 48, 5, stride=1, padding=2) 9 | self.conv1.bias.data.normal_(0, 0.01) 10 | self.conv1.bias.data.fill_(0) 11 | 12 | self.relu = nn.ReLU() 13 | self.lrn = nn.LocalResponseNorm(2) 14 | self.pad = nn.MaxPool2d(3, stride=2) 15 | 16 | self.batch_norm1 = nn.BatchNorm2d(48, eps=0.001) 17 | 18 | self.conv2 = nn.Conv2d(48, 128, 5, stride=1, padding=2) 19 | self.conv2.bias.data.normal_(0, 0.01) 20 | self.conv2.bias.data.fill_(1.0) 21 | 22 | self.batch_norm2 = nn.BatchNorm2d(128, eps=0.001) 23 | 24 | self.conv3 = nn.Conv2d(128, 192, 3, stride=1, padding=1) 25 | self.conv3.bias.data.normal_(0, 0.01) 26 | self.conv3.bias.data.fill_(0) 27 | 28 | self.batch_norm3 = nn.BatchNorm2d(192, eps=0.001) 29 | 30 | self.conv4 = nn.Conv2d(192, 192, 3, stride=1, padding=1) 31 | self.conv4.bias.data.normal_(0, 0.01) 32 | self.conv4.bias.data.fill_(1.0) 33 | 34 | self.batch_norm4 = nn.BatchNorm2d(192, eps=0.001) 35 | 36 | self.conv5 = nn.Conv2d(192, 128, 3, stride=1, padding=1) 37 | self.conv5.bias.data.normal_(0, 0.01) 38 | self.conv5.bias.data.fill_(1.0) 39 | 40 | self.batch_norm5 = nn.BatchNorm2d(128, eps=0.001) 41 | 42 | self.fc1 = nn.Linear(1152,512) 43 | 
self.fc1.bias.data.normal_(0, 0.01) 44 | self.fc1.bias.data.fill_(0) 45 | 46 | self.drop = nn.Dropout(p=0.5) 47 | 48 | self.batch_norm6 = nn.BatchNorm1d(512, eps=0.001) 49 | 50 | self.fc2 = nn.Linear(512,256) 51 | self.fc2.bias.data.normal_(0, 0.01) 52 | self.fc2.bias.data.fill_(0) 53 | 54 | self.batch_norm7 = nn.BatchNorm1d(256, eps=0.001) 55 | 56 | self.fc3 = nn.Linear(256,10) 57 | self.fc3.bias.data.normal_(0, 0.01) 58 | self.fc3.bias.data.fill_(0) 59 | 60 | self.soft = nn.Softmax() 61 | 62 | def forward(self, x): 63 | layer1 = self.batch_norm1(self.pad(self.lrn(self.relu(self.conv1(x))))) 64 | layer2 = self.batch_norm2(self.pad(self.lrn(self.relu(self.conv2(layer1))))) 65 | layer3 = self.batch_norm3(self.relu(self.conv3(layer2))) 66 | layer4 = self.batch_norm4(self.relu(self.conv4(layer3))) 67 | layer5 = self.batch_norm5(self.pad(self.relu(self.conv5(layer4)))) 68 | flatten = layer5.view(-1, 128*3*3) 69 | fully1 = self.relu(self.fc1(flatten)) 70 | fully1 = self.batch_norm6(self.drop(fully1)) 71 | fully2 = self.relu(self.fc2(fully1)) 72 | fully2 = self.batch_norm7(self.drop(fully2)) 73 | logits = self.fc3(fully2) 74 | #softmax_val = self.soft(logits) 75 | 76 | return logits 77 | 78 | model = AlexNet(10) 79 | -------------------------------------------------------------------------------- /code/train_student/models/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class AlexNet(nn.Module): 5 | def __init__(self, num_classes=10): 6 | super(AlexNet, self).__init__() 7 | 8 | self.conv1 = nn.Conv2d(3, 48, 5, stride=1, padding=2) 9 | self.conv1.bias.data.normal_(0, 0.01) 10 | self.conv1.bias.data.fill_(0) 11 | 12 | self.relu = nn.ReLU() 13 | self.lrn = nn.LocalResponseNorm(2) 14 | self.pad = nn.MaxPool2d(3, stride=2) 15 | 16 | self.batch_norm1 = nn.BatchNorm2d(48, eps=0.001) 17 | 18 | self.conv2 = nn.Conv2d(48, 128, 5, stride=1, padding=2) 19 | self.conv2.bias.data.normal_(0, 
0.01) 20 | self.conv2.bias.data.fill_(1.0) 21 | 22 | self.batch_norm2 = nn.BatchNorm2d(128, eps=0.001) 23 | 24 | self.conv3 = nn.Conv2d(128, 192, 3, stride=1, padding=1) 25 | self.conv3.bias.data.normal_(0, 0.01) 26 | self.conv3.bias.data.fill_(0) 27 | 28 | self.batch_norm3 = nn.BatchNorm2d(192, eps=0.001) 29 | 30 | self.conv4 = nn.Conv2d(192, 192, 3, stride=1, padding=1) 31 | self.conv4.bias.data.normal_(0, 0.01) 32 | self.conv4.bias.data.fill_(1.0) 33 | 34 | self.batch_norm4 = nn.BatchNorm2d(192, eps=0.001) 35 | 36 | self.conv5 = nn.Conv2d(192, 128, 3, stride=1, padding=1) 37 | self.conv5.bias.data.normal_(0, 0.01) 38 | self.conv5.bias.data.fill_(1.0) 39 | 40 | self.batch_norm5 = nn.BatchNorm2d(128, eps=0.001) 41 | 42 | self.fc1 = nn.Linear(1152,512) 43 | self.fc1.bias.data.normal_(0, 0.01) 44 | self.fc1.bias.data.fill_(0) 45 | 46 | self.drop = nn.Dropout(p=0.5) 47 | 48 | self.batch_norm6 = nn.BatchNorm1d(512, eps=0.001) 49 | 50 | self.fc2 = nn.Linear(512,256) 51 | self.fc2.bias.data.normal_(0, 0.01) 52 | self.fc2.bias.data.fill_(0) 53 | 54 | self.batch_norm7 = nn.BatchNorm1d(256, eps=0.001) 55 | 56 | self.fc3 = nn.Linear(256,10) 57 | self.fc3.bias.data.normal_(0, 0.01) 58 | self.fc3.bias.data.fill_(0) 59 | 60 | self.soft = nn.Softmax() 61 | 62 | def forward(self, x): 63 | layer1 = self.batch_norm1(self.pad(self.lrn(self.relu(self.conv1(x))))) 64 | layer2 = self.batch_norm2(self.pad(self.lrn(self.relu(self.conv2(layer1))))) 65 | layer3 = self.batch_norm3(self.relu(self.conv3(layer2))) 66 | layer4 = self.batch_norm4(self.relu(self.conv4(layer3))) 67 | layer5 = self.batch_norm5(self.pad(self.relu(self.conv5(layer4)))) 68 | flatten = layer5.view(-1, 128*3*3) 69 | fully1 = self.relu(self.fc1(flatten)) 70 | fully1 = self.batch_norm6(self.drop(fully1)) 71 | fully2 = self.relu(self.fc2(fully1)) 72 | fully2 = self.batch_norm7(self.drop(fully2)) 73 | logits = self.fc3(fully2) 74 | #softmax_val = self.soft(logits) 75 | 76 | return logits 77 | 78 | model = AlexNet(10) 79 | 
-------------------------------------------------------------------------------- /code/train_student/models/alexnet_half.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class AlexNet_half(nn.Module): 5 | def __init__(self, num_classes=10): 6 | super(AlexNet_half, self).__init__() 7 | 8 | self.conv1 = nn.Conv2d(3, 24, 5, stride=1, padding=2) 9 | self.conv1.bias.data.normal_(0, 0.01) 10 | self.conv1.bias.data.fill_(0) 11 | 12 | self.relu = nn.ReLU() 13 | self.lrn = nn.LocalResponseNorm(2) 14 | self.pad = nn.MaxPool2d(3, stride=2) 15 | 16 | self.batch_norm1 = nn.BatchNorm2d(24, eps=0.001) 17 | 18 | self.conv2 = nn.Conv2d(24, 64, 5, stride=1, padding=2) 19 | self.conv2.bias.data.normal_(0, 0.01) 20 | self.conv2.bias.data.fill_(1.0) 21 | 22 | self.batch_norm2 = nn.BatchNorm2d(64, eps=0.001) 23 | 24 | self.conv3 = nn.Conv2d(64, 96, 3, stride=1, padding=1) 25 | self.conv3.bias.data.normal_(0, 0.01) 26 | self.conv3.bias.data.fill_(0) 27 | 28 | self.batch_norm3 = nn.BatchNorm2d(96, eps=0.001) 29 | 30 | self.conv4 = nn.Conv2d(96, 96, 3, stride=1, padding=1) 31 | self.conv4.bias.data.normal_(0, 0.01) 32 | self.conv4.bias.data.fill_(1.0) 33 | 34 | self.batch_norm4 = nn.BatchNorm2d(96, eps=0.001) 35 | 36 | self.conv5 = nn.Conv2d(96, 64, 3, stride=1, padding=1) 37 | self.conv5.bias.data.normal_(0, 0.01) 38 | self.conv5.bias.data.fill_(1.0) 39 | 40 | self.batch_norm5 = nn.BatchNorm2d(64, eps=0.001) 41 | 42 | self.fc1 = nn.Linear(576,256) 43 | self.fc1.bias.data.normal_(0, 0.01) 44 | self.fc1.bias.data.fill_(0) 45 | 46 | self.drop = nn.Dropout(p=0.5) 47 | 48 | self.batch_norm6 = nn.BatchNorm1d(256, eps=0.001) 49 | 50 | self.fc2 = nn.Linear(256,128) 51 | self.fc2.bias.data.normal_(0, 0.01) 52 | self.fc2.bias.data.fill_(0) 53 | 54 | self.batch_norm7 = nn.BatchNorm1d(128, eps=0.001) 55 | 56 | self.fc3 = nn.Linear(128,10) 57 | self.fc3.bias.data.normal_(0, 0.01) 58 | self.fc3.bias.data.fill_(0) 59 
| 60 | self.soft = nn.Softmax() 61 | 62 | def forward(self, x): 63 | layer1 = self.batch_norm1(self.pad(self.lrn(self.relu(self.conv1(x))))) 64 | layer2 = self.batch_norm2(self.pad(self.lrn(self.relu(self.conv2(layer1))))) 65 | layer3 = self.batch_norm3(self.relu(self.conv3(layer2))) 66 | layer4 = self.batch_norm4(self.relu(self.conv4(layer3))) 67 | layer5 = self.batch_norm5(self.pad(self.relu(self.conv5(layer4)))) 68 | flatten = layer5.view(-1, 64*3*3) 69 | fully1 = self.relu(self.fc1(flatten)) 70 | fully1 = self.batch_norm6(self.drop(fully1)) 71 | fully2 = self.relu(self.fc2(fully1)) 72 | fully2 = self.batch_norm7(self.drop(fully2)) 73 | logits = self.fc3(fully2) 74 | #softmax_val = self.soft(logits) 75 | 76 | return logits 77 | 78 | model = AlexNet_half(10) 79 | -------------------------------------------------------------------------------- /code/train_generator/models/alexnet_half.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class AlexNet_half(nn.Module): 5 | def __init__(self, num_classes=10): 6 | super(AlexNet_half, self).__init__() 7 | 8 | self.conv1 = nn.Conv2d(3, 24, 5, stride=1, padding=2) 9 | self.conv1.bias.data.normal_(0, 0.01) 10 | self.conv1.bias.data.fill_(0) 11 | 12 | self.relu = nn.ReLU() 13 | self.lrn = nn.LocalResponseNorm(2) 14 | self.pad = nn.MaxPool2d(3, stride=2) 15 | 16 | self.batch_norm1 = nn.BatchNorm2d(24, eps=0.001) 17 | 18 | self.conv2 = nn.Conv2d(24, 64, 5, stride=1, padding=2) 19 | self.conv2.bias.data.normal_(0, 0.01) 20 | self.conv2.bias.data.fill_(1.0) 21 | 22 | self.batch_norm2 = nn.BatchNorm2d(64, eps=0.001) 23 | 24 | self.conv3 = nn.Conv2d(64, 96, 3, stride=1, padding=1) 25 | self.conv3.bias.data.normal_(0, 0.01) 26 | self.conv3.bias.data.fill_(0) 27 | 28 | self.batch_norm3 = nn.BatchNorm2d(96, eps=0.001) 29 | 30 | self.conv4 = nn.Conv2d(96, 96, 3, stride=1, padding=1) 31 | self.conv4.bias.data.normal_(0, 0.01) 32 | 
self.conv4.bias.data.fill_(1.0) 33 | 34 | self.batch_norm4 = nn.BatchNorm2d(96, eps=0.001) 35 | 36 | self.conv5 = nn.Conv2d(96, 64, 3, stride=1, padding=1) 37 | self.conv5.bias.data.normal_(0, 0.01) 38 | self.conv5.bias.data.fill_(1.0) 39 | 40 | self.batch_norm5 = nn.BatchNorm2d(64, eps=0.001) 41 | 42 | self.fc1 = nn.Linear(576,256) 43 | self.fc1.bias.data.normal_(0, 0.01) 44 | self.fc1.bias.data.fill_(0) 45 | 46 | self.drop = nn.Dropout(p=0.5) 47 | 48 | self.batch_norm6 = nn.BatchNorm1d(256, eps=0.001) 49 | 50 | self.fc2 = nn.Linear(256,128) 51 | self.fc2.bias.data.normal_(0, 0.01) 52 | self.fc2.bias.data.fill_(0) 53 | 54 | self.batch_norm7 = nn.BatchNorm1d(128, eps=0.001) 55 | 56 | self.fc3 = nn.Linear(128,10) 57 | self.fc3.bias.data.normal_(0, 0.01) 58 | self.fc3.bias.data.fill_(0) 59 | 60 | self.soft = nn.Softmax() 61 | 62 | def forward(self, x): 63 | layer1 = self.batch_norm1(self.pad(self.lrn(self.relu(self.conv1(x))))) 64 | layer2 = self.batch_norm2(self.pad(self.lrn(self.relu(self.conv2(layer1))))) 65 | layer3 = self.batch_norm3(self.relu(self.conv3(layer2))) 66 | layer4 = self.batch_norm4(self.relu(self.conv4(layer3))) 67 | layer5 = self.batch_norm5(self.pad(self.relu(self.conv5(layer4)))) 68 | flatten = layer5.view(-1, 64*3*3) 69 | fully1 = self.relu(self.fc1(flatten)) 70 | fully1 = self.batch_norm6(self.drop(fully1)) 71 | fully2 = self.relu(self.fc2(fully1)) 72 | fully2 = self.batch_norm7(self.drop(fully2)) 73 | logits = self.fc3(fully2) 74 | #softmax_val = self.soft(logits) 75 | 76 | return logits 77 | 78 | model = AlexNet_half(10) 79 | -------------------------------------------------------------------------------- /models/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class AlexNet(nn.Module): 5 | def __init__(self, num_classes=10): 6 | super(AlexNet, self).__init__() 7 | 8 | self.conv1 = nn.Conv2d(3, 48, 5, stride=1, padding=2) 9 | 
self.conv1.bias.data.normal_(0, 0.01) 10 | self.conv1.bias.data.fill_(0) 11 | 12 | self.relu = nn.ReLU() 13 | self.lrn = nn.LocalResponseNorm(2) 14 | self.pad = nn.MaxPool2d(3, stride=2) 15 | 16 | self.batch_norm1 = nn.BatchNorm2d(48, eps=0.001) 17 | 18 | self.conv2 = nn.Conv2d(48, 128, 5, stride=1, padding=2) 19 | self.conv2.bias.data.normal_(0, 0.01) 20 | self.conv2.bias.data.fill_(1.0) 21 | 22 | self.batch_norm2 = nn.BatchNorm2d(128, eps=0.001) 23 | 24 | self.conv3 = nn.Conv2d(128, 192, 3, stride=1, padding=1) 25 | self.conv3.bias.data.normal_(0, 0.01) 26 | self.conv3.bias.data.fill_(0) 27 | 28 | self.batch_norm3 = nn.BatchNorm2d(192, eps=0.001) 29 | 30 | self.conv4 = nn.Conv2d(192, 192, 3, stride=1, padding=1) 31 | self.conv4.bias.data.normal_(0, 0.01) 32 | self.conv4.bias.data.fill_(1.0) 33 | 34 | self.batch_norm4 = nn.BatchNorm2d(192, eps=0.001) 35 | 36 | self.conv5 = nn.Conv2d(192, 128, 3, stride=1, padding=1) 37 | self.conv5.bias.data.normal_(0, 0.01) 38 | self.conv5.bias.data.fill_(1.0) 39 | 40 | self.batch_norm5 = nn.BatchNorm2d(128, eps=0.001) 41 | 42 | self.fc1 = nn.Linear(1152,512) 43 | self.fc1.bias.data.normal_(0, 0.01) 44 | self.fc1.bias.data.fill_(0) 45 | 46 | self.drop = nn.Dropout(p=0.5) 47 | 48 | self.batch_norm6 = nn.BatchNorm1d(512, eps=0.001) 49 | 50 | self.fc2 = nn.Linear(512,256) 51 | self.fc2.bias.data.normal_(0, 0.01) 52 | self.fc2.bias.data.fill_(0) 53 | 54 | self.batch_norm7 = nn.BatchNorm1d(256, eps=0.001) 55 | 56 | self.fc3 = nn.Linear(256,10) 57 | self.fc3.bias.data.normal_(0, 0.01) 58 | self.fc3.bias.data.fill_(0) 59 | 60 | self.soft = nn.Softmax() 61 | 62 | def forward(self, x, penu=False): 63 | layer1 = self.batch_norm1(self.pad(self.lrn(self.relu(self.conv1(x))))) 64 | layer2 = self.batch_norm2(self.pad(self.lrn(self.relu(self.conv2(layer1))))) 65 | layer3 = self.batch_norm3(self.relu(self.conv3(layer2))) 66 | layer4 = self.batch_norm4(self.relu(self.conv4(layer3))) 67 | layer5 = 
self.batch_norm5(self.pad(self.relu(self.conv5(layer4)))) 68 | flatten = layer5.view(-1, 128*3*3) 69 | if penu==True: 70 | return flatten 71 | fully1 = self.relu(self.fc1(flatten)) 72 | fully1 = self.batch_norm6(self.drop(fully1)) 73 | fully2 = self.relu(self.fc2(fully1)) 74 | fully2 = self.batch_norm7(self.drop(fully2)) 75 | logits = self.fc3(fully2) 76 | #softmax_val = self.soft(logits) 77 | 78 | return logits 79 | 80 | model = AlexNet(10) 81 | -------------------------------------------------------------------------------- /models/alexnet_half.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class AlexNet_half(nn.Module): 5 | def __init__(self, num_classes=10): 6 | super(AlexNet_half, self).__init__() 7 | 8 | self.conv1 = nn.Conv2d(3, 24, 5, stride=1, padding=2) 9 | self.conv1.bias.data.normal_(0, 0.01) 10 | self.conv1.bias.data.fill_(0) 11 | 12 | self.relu = nn.ReLU() 13 | self.lrn = nn.LocalResponseNorm(2) 14 | self.pad = nn.MaxPool2d(3, stride=2) 15 | 16 | self.batch_norm1 = nn.BatchNorm2d(24, eps=0.001) 17 | 18 | self.conv2 = nn.Conv2d(24, 64, 5, stride=1, padding=2) 19 | self.conv2.bias.data.normal_(0, 0.01) 20 | self.conv2.bias.data.fill_(1.0) 21 | 22 | self.batch_norm2 = nn.BatchNorm2d(64, eps=0.001) 23 | 24 | self.conv3 = nn.Conv2d(64, 96, 3, stride=1, padding=1) 25 | self.conv3.bias.data.normal_(0, 0.01) 26 | self.conv3.bias.data.fill_(0) 27 | 28 | self.batch_norm3 = nn.BatchNorm2d(96, eps=0.001) 29 | 30 | self.conv4 = nn.Conv2d(96, 96, 3, stride=1, padding=1) 31 | self.conv4.bias.data.normal_(0, 0.01) 32 | self.conv4.bias.data.fill_(1.0) 33 | 34 | self.batch_norm4 = nn.BatchNorm2d(96, eps=0.001) 35 | 36 | self.conv5 = nn.Conv2d(96, 64, 3, stride=1, padding=1) 37 | self.conv5.bias.data.normal_(0, 0.01) 38 | self.conv5.bias.data.fill_(1.0) 39 | 40 | self.batch_norm5 = nn.BatchNorm2d(64, eps=0.001) 41 | 42 | self.fc1 = nn.Linear(576,256) 43 | 
self.fc1.bias.data.normal_(0, 0.01) 44 | self.fc1.bias.data.fill_(0) 45 | 46 | self.drop = nn.Dropout(p=0.5) 47 | 48 | self.batch_norm6 = nn.BatchNorm1d(256, eps=0.001) 49 | 50 | self.fc2 = nn.Linear(256,128) 51 | self.fc2.bias.data.normal_(0, 0.01) 52 | self.fc2.bias.data.fill_(0) 53 | 54 | self.batch_norm7 = nn.BatchNorm1d(128, eps=0.001) 55 | 56 | self.fc3 = nn.Linear(128,10) 57 | self.fc3.bias.data.normal_(0, 0.01) 58 | self.fc3.bias.data.fill_(0) 59 | 60 | self.soft = nn.Softmax() 61 | 62 | def forward(self, x, penu=False): 63 | layer1 = self.batch_norm1(self.pad(self.lrn(self.relu(self.conv1(x))))) 64 | layer2 = self.batch_norm2(self.pad(self.lrn(self.relu(self.conv2(layer1))))) 65 | layer3 = self.batch_norm3(self.relu(self.conv3(layer2))) 66 | layer4 = self.batch_norm4(self.relu(self.conv4(layer3))) 67 | layer5 = self.batch_norm5(self.pad(self.relu(self.conv5(layer4)))) 68 | flatten = layer5.view(-1, 64*3*3) 69 | if penu==True: 70 | return flatten 71 | fully1 = self.relu(self.fc1(flatten)) 72 | fully1 = self.batch_norm6(self.drop(fully1)) 73 | fully2 = self.relu(self.fc2(fully1)) 74 | fully2 = self.batch_norm7(self.drop(fully2)) 75 | logits = self.fc3(fully2) 76 | #softmax_val = self.soft(logits) 77 | 78 | return logits 79 | 80 | 81 | model = AlexNet_half(10) 82 | -------------------------------------------------------------------------------- /models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | '''MobileNetV2 in PyTorch. 2 | 3 | See the paper "Inverted Residuals and Linear Bottlenecks: 4 | Mobile Networks for Classification, Detection and Segmentation" for more details. 
5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class Block(nn.Module): 12 | '''expand + depthwise + pointwise''' 13 | def __init__(self, in_planes, out_planes, expansion, stride): 14 | super(Block, self).__init__() 15 | self.stride = stride 16 | 17 | planes = expansion * in_planes 18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 23 | self.bn3 = nn.BatchNorm2d(out_planes) 24 | 25 | self.shortcut = nn.Sequential() 26 | if stride == 1 and in_planes != out_planes: 27 | self.shortcut = nn.Sequential( 28 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False), 29 | nn.BatchNorm2d(out_planes), 30 | ) 31 | 32 | def forward(self, x): 33 | out = F.relu(self.bn1(self.conv1(x))) 34 | out = F.relu(self.bn2(self.conv2(out))) 35 | out = self.bn3(self.conv3(out)) 36 | out = out + self.shortcut(x) if self.stride==1 else out 37 | return out 38 | 39 | 40 | class MobileNetV2(nn.Module): 41 | # (expansion, out_planes, num_blocks, stride) 42 | cfg = [(1, 16, 1, 1), 43 | (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10 44 | (6, 32, 3, 2), 45 | (6, 64, 4, 2), 46 | (6, 96, 3, 1), 47 | (6, 160, 3, 2), 48 | (6, 320, 1, 1)] 49 | 50 | def __init__(self, num_classes=10): 51 | super(MobileNetV2, self).__init__() 52 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10 53 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 54 | self.bn1 = nn.BatchNorm2d(32) 55 | self.layers = self._make_layers(in_planes=32) 56 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False) 57 | self.bn2 = nn.BatchNorm2d(1280) 58 | self.linear = 
nn.Linear(1280, num_classes) 59 | 60 | def _make_layers(self, in_planes): 61 | layers = [] 62 | for expansion, out_planes, num_blocks, stride in self.cfg: 63 | strides = [stride] + [1]*(num_blocks-1) 64 | for stride in strides: 65 | layers.append(Block(in_planes, out_planes, expansion, stride)) 66 | in_planes = out_planes 67 | return nn.Sequential(*layers) 68 | 69 | def forward(self, x): 70 | out = F.relu(self.bn1(self.conv1(x))) 71 | out = self.layers(out) 72 | out = F.relu(self.bn2(self.conv2(out))) 73 | # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10 74 | out = F.avg_pool2d(out, 4) 75 | out = out.view(out.size(0), -1) 76 | out = self.linear(out) 77 | return out 78 | 79 | 80 | def test(): 81 | net = MobileNetV2() 82 | x = torch.randn(2,3,32,32) 83 | y = net(x) 84 | print(y.size()) 85 | 86 | # test() 87 | -------------------------------------------------------------------------------- /code/train_student/models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | '''MobileNetV2 in PyTorch. 2 | 3 | See the paper "Inverted Residuals and Linear Bottlenecks: 4 | Mobile Networks for Classification, Detection and Segmentation" for more details. 
5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class Block(nn.Module): 12 | '''expand + depthwise + pointwise''' 13 | def __init__(self, in_planes, out_planes, expansion, stride): 14 | super(Block, self).__init__() 15 | self.stride = stride 16 | 17 | planes = expansion * in_planes 18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 23 | self.bn3 = nn.BatchNorm2d(out_planes) 24 | 25 | self.shortcut = nn.Sequential() 26 | if stride == 1 and in_planes != out_planes: 27 | self.shortcut = nn.Sequential( 28 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False), 29 | nn.BatchNorm2d(out_planes), 30 | ) 31 | 32 | def forward(self, x): 33 | out = F.relu(self.bn1(self.conv1(x))) 34 | out = F.relu(self.bn2(self.conv2(out))) 35 | out = self.bn3(self.conv3(out)) 36 | out = out + self.shortcut(x) if self.stride==1 else out 37 | return out 38 | 39 | 40 | class MobileNetV2(nn.Module): 41 | # (expansion, out_planes, num_blocks, stride) 42 | cfg = [(1, 16, 1, 1), 43 | (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10 44 | (6, 32, 3, 2), 45 | (6, 64, 4, 2), 46 | (6, 96, 3, 1), 47 | (6, 160, 3, 2), 48 | (6, 320, 1, 1)] 49 | 50 | def __init__(self, num_classes=10): 51 | super(MobileNetV2, self).__init__() 52 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10 53 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 54 | self.bn1 = nn.BatchNorm2d(32) 55 | self.layers = self._make_layers(in_planes=32) 56 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False) 57 | self.bn2 = nn.BatchNorm2d(1280) 58 | self.linear = 
nn.Linear(1280, num_classes) 59 | 60 | def _make_layers(self, in_planes): 61 | layers = [] 62 | for expansion, out_planes, num_blocks, stride in self.cfg: 63 | strides = [stride] + [1]*(num_blocks-1) 64 | for stride in strides: 65 | layers.append(Block(in_planes, out_planes, expansion, stride)) 66 | in_planes = out_planes 67 | return nn.Sequential(*layers) 68 | 69 | def forward(self, x): 70 | out = F.relu(self.bn1(self.conv1(x))) 71 | out = self.layers(out) 72 | out = F.relu(self.bn2(self.conv2(out))) 73 | # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10 74 | out = F.avg_pool2d(out, 4) 75 | out = out.view(out.size(0), -1) 76 | out = self.linear(out) 77 | return out 78 | 79 | 80 | def test(): 81 | net = MobileNetV2() 82 | x = torch.randn(2,3,32,32) 83 | y = net(x) 84 | print(y.size()) 85 | 86 | # test() 87 | -------------------------------------------------------------------------------- /code/train_generator/models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | '''MobileNetV2 in PyTorch. 2 | 3 | See the paper "Inverted Residuals and Linear Bottlenecks: 4 | Mobile Networks for Classification, Detection and Segmentation" for more details. 
5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class Block(nn.Module): 12 | '''expand + depthwise + pointwise''' 13 | def __init__(self, in_planes, out_planes, expansion, stride): 14 | super(Block, self).__init__() 15 | self.stride = stride 16 | 17 | planes = expansion * in_planes 18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 23 | self.bn3 = nn.BatchNorm2d(out_planes) 24 | 25 | self.shortcut = nn.Sequential() 26 | if stride == 1 and in_planes != out_planes: 27 | self.shortcut = nn.Sequential( 28 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False), 29 | nn.BatchNorm2d(out_planes), 30 | ) 31 | 32 | def forward(self, x): 33 | out = F.relu(self.bn1(self.conv1(x))) 34 | out = F.relu(self.bn2(self.conv2(out))) 35 | out = self.bn3(self.conv3(out)) 36 | out = out + self.shortcut(x) if self.stride==1 else out 37 | return out 38 | 39 | 40 | class MobileNetV2(nn.Module): 41 | # (expansion, out_planes, num_blocks, stride) 42 | cfg = [(1, 16, 1, 1), 43 | (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10 44 | (6, 32, 3, 2), 45 | (6, 64, 4, 2), 46 | (6, 96, 3, 1), 47 | (6, 160, 3, 2), 48 | (6, 320, 1, 1)] 49 | 50 | def __init__(self, num_classes=10): 51 | super(MobileNetV2, self).__init__() 52 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10 53 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 54 | self.bn1 = nn.BatchNorm2d(32) 55 | self.layers = self._make_layers(in_planes=32) 56 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False) 57 | self.bn2 = nn.BatchNorm2d(1280) 58 | self.linear = 
nn.Linear(1280, num_classes) 59 | 60 | def _make_layers(self, in_planes): 61 | layers = [] 62 | for expansion, out_planes, num_blocks, stride in self.cfg: 63 | strides = [stride] + [1]*(num_blocks-1) 64 | for stride in strides: 65 | layers.append(Block(in_planes, out_planes, expansion, stride)) 66 | in_planes = out_planes 67 | return nn.Sequential(*layers) 68 | 69 | def forward(self, x): 70 | out = F.relu(self.bn1(self.conv1(x))) 71 | out = self.layers(out) 72 | out = F.relu(self.bn2(self.conv2(out))) 73 | # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10 74 | out = F.avg_pool2d(out, 4) 75 | out = out.view(out.size(0), -1) 76 | out = self.linear(out) 77 | return out 78 | 79 | 80 | def test(): 81 | net = MobileNetV2() 82 | x = torch.randn(2,3,32,32) 83 | y = net(x) 84 | print(y.size()) 85 | 86 | # test() 87 | -------------------------------------------------------------------------------- /models/googlenet.py: -------------------------------------------------------------------------------- 1 | '''GoogLeNet with PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Inception(nn.Module): 8 | def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes): 9 | super(Inception, self).__init__() 10 | # 1x1 conv branch 11 | self.b1 = nn.Sequential( 12 | nn.Conv2d(in_planes, n1x1, kernel_size=1), 13 | nn.BatchNorm2d(n1x1), 14 | nn.ReLU(True), 15 | ) 16 | 17 | # 1x1 conv -> 3x3 conv branch 18 | self.b2 = nn.Sequential( 19 | nn.Conv2d(in_planes, n3x3red, kernel_size=1), 20 | nn.BatchNorm2d(n3x3red), 21 | nn.ReLU(True), 22 | nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1), 23 | nn.BatchNorm2d(n3x3), 24 | nn.ReLU(True), 25 | ) 26 | 27 | # 1x1 conv -> 5x5 conv branch 28 | self.b3 = nn.Sequential( 29 | nn.Conv2d(in_planes, n5x5red, kernel_size=1), 30 | nn.BatchNorm2d(n5x5red), 31 | nn.ReLU(True), 32 | nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1), 33 | nn.BatchNorm2d(n5x5), 34 | nn.ReLU(True), 35 | 
nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1), 36 | nn.BatchNorm2d(n5x5), 37 | nn.ReLU(True), 38 | ) 39 | 40 | # 3x3 pool -> 1x1 conv branch 41 | self.b4 = nn.Sequential( 42 | nn.MaxPool2d(3, stride=1, padding=1), 43 | nn.Conv2d(in_planes, pool_planes, kernel_size=1), 44 | nn.BatchNorm2d(pool_planes), 45 | nn.ReLU(True), 46 | ) 47 | 48 | def forward(self, x): 49 | y1 = self.b1(x) 50 | y2 = self.b2(x) 51 | y3 = self.b3(x) 52 | y4 = self.b4(x) 53 | return torch.cat([y1,y2,y3,y4], 1) 54 | 55 | 56 | class GoogLeNet(nn.Module): 57 | def __init__(self): 58 | super(GoogLeNet, self).__init__() 59 | self.pre_layers = nn.Sequential( 60 | nn.Conv2d(3, 192, kernel_size=3, padding=1), 61 | nn.BatchNorm2d(192), 62 | nn.ReLU(True), 63 | ) 64 | 65 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32) 66 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) 67 | 68 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) 69 | 70 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64) 71 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64) 72 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64) 73 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64) 74 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) 75 | 76 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128) 77 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) 78 | 79 | self.avgpool = nn.AvgPool2d(8, stride=1) 80 | self.linear = nn.Linear(1024, 10) 81 | 82 | def forward(self, x): 83 | out = self.pre_layers(x) 84 | out = self.a3(out) 85 | out = self.b3(out) 86 | out = self.maxpool(out) 87 | out = self.a4(out) 88 | out = self.b4(out) 89 | out = self.c4(out) 90 | out = self.d4(out) 91 | out = self.e4(out) 92 | out = self.maxpool(out) 93 | out = self.a5(out) 94 | out = self.b5(out) 95 | out = self.avgpool(out) 96 | out = out.view(out.size(0), -1) 97 | out = self.linear(out) 98 | return out 99 | 100 | 101 | def test(): 102 | net = GoogLeNet() 103 | x = torch.randn(1,3,32,32) 104 | y = net(x) 105 | print(y.size()) 106 | 107 
| # test() 108 | -------------------------------------------------------------------------------- /code/train_student/models/googlenet.py: -------------------------------------------------------------------------------- 1 | '''GoogLeNet with PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Inception(nn.Module): 8 | def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes): 9 | super(Inception, self).__init__() 10 | # 1x1 conv branch 11 | self.b1 = nn.Sequential( 12 | nn.Conv2d(in_planes, n1x1, kernel_size=1), 13 | nn.BatchNorm2d(n1x1), 14 | nn.ReLU(True), 15 | ) 16 | 17 | # 1x1 conv -> 3x3 conv branch 18 | self.b2 = nn.Sequential( 19 | nn.Conv2d(in_planes, n3x3red, kernel_size=1), 20 | nn.BatchNorm2d(n3x3red), 21 | nn.ReLU(True), 22 | nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1), 23 | nn.BatchNorm2d(n3x3), 24 | nn.ReLU(True), 25 | ) 26 | 27 | # 1x1 conv -> 5x5 conv branch 28 | self.b3 = nn.Sequential( 29 | nn.Conv2d(in_planes, n5x5red, kernel_size=1), 30 | nn.BatchNorm2d(n5x5red), 31 | nn.ReLU(True), 32 | nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1), 33 | nn.BatchNorm2d(n5x5), 34 | nn.ReLU(True), 35 | nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1), 36 | nn.BatchNorm2d(n5x5), 37 | nn.ReLU(True), 38 | ) 39 | 40 | # 3x3 pool -> 1x1 conv branch 41 | self.b4 = nn.Sequential( 42 | nn.MaxPool2d(3, stride=1, padding=1), 43 | nn.Conv2d(in_planes, pool_planes, kernel_size=1), 44 | nn.BatchNorm2d(pool_planes), 45 | nn.ReLU(True), 46 | ) 47 | 48 | def forward(self, x): 49 | y1 = self.b1(x) 50 | y2 = self.b2(x) 51 | y3 = self.b3(x) 52 | y4 = self.b4(x) 53 | return torch.cat([y1,y2,y3,y4], 1) 54 | 55 | 56 | class GoogLeNet(nn.Module): 57 | def __init__(self): 58 | super(GoogLeNet, self).__init__() 59 | self.pre_layers = nn.Sequential( 60 | nn.Conv2d(3, 192, kernel_size=3, padding=1), 61 | nn.BatchNorm2d(192), 62 | nn.ReLU(True), 63 | ) 64 | 65 | self.a3 = Inception(192, 64, 96, 128, 16, 
32, 32) 66 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) 67 | 68 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) 69 | 70 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64) 71 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64) 72 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64) 73 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64) 74 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) 75 | 76 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128) 77 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) 78 | 79 | self.avgpool = nn.AvgPool2d(8, stride=1) 80 | self.linear = nn.Linear(1024, 10) 81 | 82 | def forward(self, x): 83 | out = self.pre_layers(x) 84 | out = self.a3(out) 85 | out = self.b3(out) 86 | out = self.maxpool(out) 87 | out = self.a4(out) 88 | out = self.b4(out) 89 | out = self.c4(out) 90 | out = self.d4(out) 91 | out = self.e4(out) 92 | out = self.maxpool(out) 93 | out = self.a5(out) 94 | out = self.b5(out) 95 | out = self.avgpool(out) 96 | out = out.view(out.size(0), -1) 97 | out = self.linear(out) 98 | return out 99 | 100 | 101 | def test(): 102 | net = GoogLeNet() 103 | x = torch.randn(1,3,32,32) 104 | y = net(x) 105 | print(y.size()) 106 | 107 | # test() 108 | -------------------------------------------------------------------------------- /code/train_generator/models/googlenet.py: -------------------------------------------------------------------------------- 1 | '''GoogLeNet with PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Inception(nn.Module): 8 | def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes): 9 | super(Inception, self).__init__() 10 | # 1x1 conv branch 11 | self.b1 = nn.Sequential( 12 | nn.Conv2d(in_planes, n1x1, kernel_size=1), 13 | nn.BatchNorm2d(n1x1), 14 | nn.ReLU(True), 15 | ) 16 | 17 | # 1x1 conv -> 3x3 conv branch 18 | self.b2 = nn.Sequential( 19 | nn.Conv2d(in_planes, n3x3red, kernel_size=1), 20 | 
nn.BatchNorm2d(n3x3red), 21 | nn.ReLU(True), 22 | nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1), 23 | nn.BatchNorm2d(n3x3), 24 | nn.ReLU(True), 25 | ) 26 | 27 | # 1x1 conv -> 5x5 conv branch 28 | self.b3 = nn.Sequential( 29 | nn.Conv2d(in_planes, n5x5red, kernel_size=1), 30 | nn.BatchNorm2d(n5x5red), 31 | nn.ReLU(True), 32 | nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1), 33 | nn.BatchNorm2d(n5x5), 34 | nn.ReLU(True), 35 | nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1), 36 | nn.BatchNorm2d(n5x5), 37 | nn.ReLU(True), 38 | ) 39 | 40 | # 3x3 pool -> 1x1 conv branch 41 | self.b4 = nn.Sequential( 42 | nn.MaxPool2d(3, stride=1, padding=1), 43 | nn.Conv2d(in_planes, pool_planes, kernel_size=1), 44 | nn.BatchNorm2d(pool_planes), 45 | nn.ReLU(True), 46 | ) 47 | 48 | def forward(self, x): 49 | y1 = self.b1(x) 50 | y2 = self.b2(x) 51 | y3 = self.b3(x) 52 | y4 = self.b4(x) 53 | return torch.cat([y1,y2,y3,y4], 1) 54 | 55 | 56 | class GoogLeNet(nn.Module): 57 | def __init__(self): 58 | super(GoogLeNet, self).__init__() 59 | self.pre_layers = nn.Sequential( 60 | nn.Conv2d(3, 192, kernel_size=3, padding=1), 61 | nn.BatchNorm2d(192), 62 | nn.ReLU(True), 63 | ) 64 | 65 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32) 66 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) 67 | 68 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) 69 | 70 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64) 71 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64) 72 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64) 73 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64) 74 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) 75 | 76 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128) 77 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) 78 | 79 | self.avgpool = nn.AvgPool2d(8, stride=1) 80 | self.linear = nn.Linear(1024, 10) 81 | 82 | def forward(self, x): 83 | out = self.pre_layers(x) 84 | out = self.a3(out) 85 | out = self.b3(out) 86 | out = self.maxpool(out) 87 | 
out = self.a4(out) 88 | out = self.b4(out) 89 | out = self.c4(out) 90 | out = self.d4(out) 91 | out = self.e4(out) 92 | out = self.maxpool(out) 93 | out = self.a5(out) 94 | out = self.b5(out) 95 | out = self.avgpool(out) 96 | out = out.view(out.size(0), -1) 97 | out = self.linear(out) 98 | return out 99 | 100 | 101 | def test(): 102 | net = GoogLeNet() 103 | x = torch.randn(1,3,32,32) 104 | y = net(x) 105 | print(y.size()) 106 | 107 | # test() 108 | -------------------------------------------------------------------------------- /models/resnext.py: -------------------------------------------------------------------------------- 1 | '''ResNeXt in PyTorch. 2 | 3 | See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class Block(nn.Module): 11 | '''Grouped convolution block.''' 12 | expansion = 2 13 | 14 | def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1): 15 | super(Block, self).__init__() 16 | group_width = cardinality * bottleneck_width 17 | self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(group_width) 19 | self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False) 20 | self.bn2 = nn.BatchNorm2d(group_width) 21 | self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False) 22 | self.bn3 = nn.BatchNorm2d(self.expansion*group_width) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*group_width: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*group_width) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = F.relu(self.bn2(self.conv2(out))) 34 | out = 
self.bn3(self.conv3(out)) 35 | out += self.shortcut(x) 36 | out = F.relu(out) 37 | return out 38 | 39 | 40 | class ResNeXt(nn.Module): 41 | def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10): 42 | super(ResNeXt, self).__init__() 43 | self.cardinality = cardinality 44 | self.bottleneck_width = bottleneck_width 45 | self.in_planes = 64 46 | 47 | self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(64) 49 | self.layer1 = self._make_layer(num_blocks[0], 1) 50 | self.layer2 = self._make_layer(num_blocks[1], 2) 51 | self.layer3 = self._make_layer(num_blocks[2], 2) 52 | # self.layer4 = self._make_layer(num_blocks[3], 2) 53 | self.linear = nn.Linear(cardinality*bottleneck_width*8, num_classes) 54 | 55 | def _make_layer(self, num_blocks, stride): 56 | strides = [stride] + [1]*(num_blocks-1) 57 | layers = [] 58 | for stride in strides: 59 | layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride)) 60 | self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width 61 | # Increase bottleneck_width by 2 after each stage. 
62 | self.bottleneck_width *= 2 63 | return nn.Sequential(*layers) 64 | 65 | def forward(self, x): 66 | out = F.relu(self.bn1(self.conv1(x))) 67 | out = self.layer1(out) 68 | out = self.layer2(out) 69 | out = self.layer3(out) 70 | # out = self.layer4(out) 71 | out = F.avg_pool2d(out, 8) 72 | out = out.view(out.size(0), -1) 73 | out = self.linear(out) 74 | return out 75 | 76 | 77 | def ResNeXt29_2x64d(): 78 | return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64) 79 | 80 | def ResNeXt29_4x64d(): 81 | return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64) 82 | 83 | def ResNeXt29_8x64d(): 84 | return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64) 85 | 86 | def ResNeXt29_32x4d(): 87 | return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4) 88 | 89 | def test_resnext(): 90 | net = ResNeXt29_2x64d() 91 | x = torch.randn(1,3,32,32) 92 | y = net(x) 93 | print(y.size()) 94 | 95 | # test_resnext() 96 | -------------------------------------------------------------------------------- /code/train_generator/models/resnext.py: -------------------------------------------------------------------------------- 1 | '''ResNeXt in PyTorch. 2 | 3 | See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details. 
4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class Block(nn.Module): 11 | '''Grouped convolution block.''' 12 | expansion = 2 13 | 14 | def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1): 15 | super(Block, self).__init__() 16 | group_width = cardinality * bottleneck_width 17 | self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(group_width) 19 | self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False) 20 | self.bn2 = nn.BatchNorm2d(group_width) 21 | self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False) 22 | self.bn3 = nn.BatchNorm2d(self.expansion*group_width) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*group_width: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*group_width) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = F.relu(self.bn2(self.conv2(out))) 34 | out = self.bn3(self.conv3(out)) 35 | out += self.shortcut(x) 36 | out = F.relu(out) 37 | return out 38 | 39 | 40 | class ResNeXt(nn.Module): 41 | def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10): 42 | super(ResNeXt, self).__init__() 43 | self.cardinality = cardinality 44 | self.bottleneck_width = bottleneck_width 45 | self.in_planes = 64 46 | 47 | self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(64) 49 | self.layer1 = self._make_layer(num_blocks[0], 1) 50 | self.layer2 = self._make_layer(num_blocks[1], 2) 51 | self.layer3 = self._make_layer(num_blocks[2], 2) 52 | # self.layer4 = self._make_layer(num_blocks[3], 2) 53 | self.linear = nn.Linear(cardinality*bottleneck_width*8, num_classes) 54 | 55 | def 
_make_layer(self, num_blocks, stride): 56 | strides = [stride] + [1]*(num_blocks-1) 57 | layers = [] 58 | for stride in strides: 59 | layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride)) 60 | self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width 61 | # Increase bottleneck_width by 2 after each stage. 62 | self.bottleneck_width *= 2 63 | return nn.Sequential(*layers) 64 | 65 | def forward(self, x): 66 | out = F.relu(self.bn1(self.conv1(x))) 67 | out = self.layer1(out) 68 | out = self.layer2(out) 69 | out = self.layer3(out) 70 | # out = self.layer4(out) 71 | out = F.avg_pool2d(out, 8) 72 | out = out.view(out.size(0), -1) 73 | out = self.linear(out) 74 | return out 75 | 76 | 77 | def ResNeXt29_2x64d(): 78 | return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64) 79 | 80 | def ResNeXt29_4x64d(): 81 | return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64) 82 | 83 | def ResNeXt29_8x64d(): 84 | return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64) 85 | 86 | def ResNeXt29_32x4d(): 87 | return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4) 88 | 89 | def test_resnext(): 90 | net = ResNeXt29_2x64d() 91 | x = torch.randn(1,3,32,32) 92 | y = net(x) 93 | print(y.size()) 94 | 95 | # test_resnext() 96 | -------------------------------------------------------------------------------- /code/train_student/models/resnext.py: -------------------------------------------------------------------------------- 1 | '''ResNeXt in PyTorch. 2 | 3 | See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details. 
4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class Block(nn.Module): 11 | '''Grouped convolution block.''' 12 | expansion = 2 13 | 14 | def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1): 15 | super(Block, self).__init__() 16 | group_width = cardinality * bottleneck_width 17 | self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(group_width) 19 | self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False) 20 | self.bn2 = nn.BatchNorm2d(group_width) 21 | self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False) 22 | self.bn3 = nn.BatchNorm2d(self.expansion*group_width) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*group_width: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*group_width) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = F.relu(self.bn2(self.conv2(out))) 34 | out = self.bn3(self.conv3(out)) 35 | out += self.shortcut(x) 36 | out = F.relu(out) 37 | return out 38 | 39 | 40 | class ResNeXt(nn.Module): 41 | def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10): 42 | super(ResNeXt, self).__init__() 43 | self.cardinality = cardinality 44 | self.bottleneck_width = bottleneck_width 45 | self.in_planes = 64 46 | 47 | self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(64) 49 | self.layer1 = self._make_layer(num_blocks[0], 1) 50 | self.layer2 = self._make_layer(num_blocks[1], 2) 51 | self.layer3 = self._make_layer(num_blocks[2], 2) 52 | # self.layer4 = self._make_layer(num_blocks[3], 2) 53 | self.linear = nn.Linear(cardinality*bottleneck_width*8, num_classes) 54 | 55 | def 
_make_layer(self, num_blocks, stride): 56 | strides = [stride] + [1]*(num_blocks-1) 57 | layers = [] 58 | for stride in strides: 59 | layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride)) 60 | self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width 61 | # Increase bottleneck_width by 2 after each stage. 62 | self.bottleneck_width *= 2 63 | return nn.Sequential(*layers) 64 | 65 | def forward(self, x): 66 | out = F.relu(self.bn1(self.conv1(x))) 67 | out = self.layer1(out) 68 | out = self.layer2(out) 69 | out = self.layer3(out) 70 | # out = self.layer4(out) 71 | out = F.avg_pool2d(out, 8) 72 | out = out.view(out.size(0), -1) 73 | out = self.linear(out) 74 | return out 75 | 76 | 77 | def ResNeXt29_2x64d(): 78 | return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64) 79 | 80 | def ResNeXt29_4x64d(): 81 | return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64) 82 | 83 | def ResNeXt29_8x64d(): 84 | return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64) 85 | 86 | def ResNeXt29_32x4d(): 87 | return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4) 88 | 89 | def test_resnext(): 90 | net = ResNeXt29_2x64d() 91 | x = torch.randn(1,3,32,32) 92 | y = net(x) 93 | print(y.size()) 94 | 95 | # test_resnext() 96 | -------------------------------------------------------------------------------- /models/dpn.py: -------------------------------------------------------------------------------- 1 | '''Dual Path Networks in PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Bottleneck(nn.Module): 8 | def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer): 9 | super(Bottleneck, self).__init__() 10 | self.out_planes = out_planes 11 | self.dense_depth = dense_depth 12 | 13 | self.conv1 = nn.Conv2d(last_planes, in_planes, kernel_size=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(in_planes) 15 | 
self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=32, bias=False) 16 | self.bn2 = nn.BatchNorm2d(in_planes) 17 | self.conv3 = nn.Conv2d(in_planes, out_planes+dense_depth, kernel_size=1, bias=False) 18 | self.bn3 = nn.BatchNorm2d(out_planes+dense_depth) 19 | 20 | self.shortcut = nn.Sequential() 21 | if first_layer: 22 | self.shortcut = nn.Sequential( 23 | nn.Conv2d(last_planes, out_planes+dense_depth, kernel_size=1, stride=stride, bias=False), 24 | nn.BatchNorm2d(out_planes+dense_depth) 25 | ) 26 | 27 | def forward(self, x): 28 | out = F.relu(self.bn1(self.conv1(x))) 29 | out = F.relu(self.bn2(self.conv2(out))) 30 | out = self.bn3(self.conv3(out)) 31 | x = self.shortcut(x) 32 | d = self.out_planes 33 | out = torch.cat([x[:,:d,:,:]+out[:,:d,:,:], x[:,d:,:,:], out[:,d:,:,:]], 1) 34 | out = F.relu(out) 35 | return out 36 | 37 | 38 | class DPN(nn.Module): 39 | def __init__(self, cfg): 40 | super(DPN, self).__init__() 41 | in_planes, out_planes = cfg['in_planes'], cfg['out_planes'] 42 | num_blocks, dense_depth = cfg['num_blocks'], cfg['dense_depth'] 43 | 44 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(64) 46 | self.last_planes = 64 47 | self.layer1 = self._make_layer(in_planes[0], out_planes[0], num_blocks[0], dense_depth[0], stride=1) 48 | self.layer2 = self._make_layer(in_planes[1], out_planes[1], num_blocks[1], dense_depth[1], stride=2) 49 | self.layer3 = self._make_layer(in_planes[2], out_planes[2], num_blocks[2], dense_depth[2], stride=2) 50 | self.layer4 = self._make_layer(in_planes[3], out_planes[3], num_blocks[3], dense_depth[3], stride=2) 51 | self.linear = nn.Linear(out_planes[3]+(num_blocks[3]+1)*dense_depth[3], 10) 52 | 53 | def _make_layer(self, in_planes, out_planes, num_blocks, dense_depth, stride): 54 | strides = [stride] + [1]*(num_blocks-1) 55 | layers = [] 56 | for i,stride in enumerate(strides): 57 | layers.append(Bottleneck(self.last_planes, 
in_planes, out_planes, dense_depth, stride, i==0)) 58 | self.last_planes = out_planes + (i+2) * dense_depth 59 | return nn.Sequential(*layers) 60 | 61 | def forward(self, x): 62 | out = F.relu(self.bn1(self.conv1(x))) 63 | out = self.layer1(out) 64 | out = self.layer2(out) 65 | out = self.layer3(out) 66 | out = self.layer4(out) 67 | out = F.avg_pool2d(out, 4) 68 | out = out.view(out.size(0), -1) 69 | out = self.linear(out) 70 | return out 71 | 72 | 73 | def DPN26(): 74 | cfg = { 75 | 'in_planes': (96,192,384,768), 76 | 'out_planes': (256,512,1024,2048), 77 | 'num_blocks': (2,2,2,2), 78 | 'dense_depth': (16,32,24,128) 79 | } 80 | return DPN(cfg) 81 | 82 | def DPN92(): 83 | cfg = { 84 | 'in_planes': (96,192,384,768), 85 | 'out_planes': (256,512,1024,2048), 86 | 'num_blocks': (3,4,20,3), 87 | 'dense_depth': (16,32,24,128) 88 | } 89 | return DPN(cfg) 90 | 91 | 92 | def test(): 93 | net = DPN92() 94 | x = torch.randn(1,3,32,32) 95 | y = net(x) 96 | print(y) 97 | 98 | # test() 99 | -------------------------------------------------------------------------------- /run_cifar40_classes_resnet.sh: -------------------------------------------------------------------------------- 1 | 2 | # train dcgan for cifar-100 40 classes 3 | 4 | CUDA_VISIBLE_DEVICES=0 python ./code/train_generator/dcgan.py --dataroot ./data/ --dataset cifar100 --imageSize 32 --cuda --outf ./cifar100_run_models/alexnet/cifar_40/dcgan/ --manualSeed 108 --niter 200 --batchSize 64 --network resnet --proxy_ds_name 40_class --teacher_path ./teacher_models/cifar10_resnet18_teacher_93.pth --num_classes 10 5 | 6 | CUDA_VISIBLE_DEVICES=0 python ./code/train_student/generate_val_data.py --dcgan_netG ./cifar100_run_models/alexnet/cifar_40/dcgan/netG_epoch_199.pth --dcgan_out ./val_data/dcgan_val_data_cifar_40_resnet.pkl --teacher_path ./teacher_models/cifar10_resnet18_teacher_93.pth --network resnet 7 | 8 | 9 | # Hard label runs 10 | 11 | CUDA_VISIBLE_DEVICES=0 python ./code/train_student/train_student.py 
--from_scratch --dcgan_netG ./cifar100_run_models/alexnet/cifar_40/dcgan/netG_epoch_199.pth --count 1 --network resnet --dcgan_data --dcgan_data_ratio 0.5 --proxy_data --proxy_data_ratio 1 --pad_crop --name proxy_dcgan_45k_40_class_rand --proxy_ds_name 40_class --val_data_dcgan ./val_data/dcgan_val_data_cifar_40_resnet.pkl --teacher_path ./teacher_models/cifar10_resnet18_teacher_93.pth --true_dataset cifar10 --num_classes 10 --proxy_dataset cifar100 12 | 13 | 14 | # ./train_student/checkpoint/resnet/cifar10_cifar100/1/last_epoch_best_student_model_proxy_dcgan_45k_40_class_rand.pth 15 | 16 | 17 | # Run div gan 18 | 19 | CUDA_VISIBLE_DEVICES=0 python ./code/train_generator/train_gen.py --dataroot ./data/ --dataset cifar100 --imageSize 32 --cuda --outf ./cifar100_run_models/resnet/out_step2_0_10/40_class/2/ --manualSeed 108 --niter 100 --batchSize 64 --netG ./cifar100_run_models/alexnet/cifar_40/dcgan/netG_epoch_199.pth --student_path ./train_student/checkpoint/resnet/cifar10_cifar100/1/last_epoch_best_student_model_proxy_dcgan_45k_40_class_rand.pth --network resnet --c_l 0 --d_l 10 --proxy_ds_name 40_class --true_dataset cifar10 --num_classes 10 20 | 21 | 22 | # DivGAN + Random Crop (student trained from scratch on 0.5*Proxy + 0.5*DivGAN ) 23 | CUDA_VISIBLE_DEVICES=0 python code/train_student/train_student.py --div_gan_netG ./cifar100_run_models/resnet/out_step2_0_10/40_class/2/netG_epoch_99.pth --dcgan_netG ./cifar100_run_models/alexnet/cifar_40/dcgan/netG_epoch_199.pth --count 2 --network resnet --pad_crop --div_gan_data --div_gan_data_ratio 0.5 --proxy_data --proxy_data_ratio 1 --from_scratch --name from_scratch_div_gan_05_40_class --proxy_ds_name 40_class --val_data_dcgan ./val_data/dcgan_val_data_cifar_40_resnet.pkl --teacher_path ./teacher_models/cifar10_resnet18_teacher_93.pth --proxy_dataset cifar100 --true_dataset cifar10 --num_classes 10 24 | 25 | 26 | CUDA_VISIBLE_DEVICES=0 python ./code/train_student/generate_val_data.py --div_gan_netG 
./cifar100_run_models/resnet/out_step2_0_10/40_class/2/netG_epoch_99.pth --network resnet --div_gan_out ./val_data/degan_val_data_cifar_40_resnet.pkl --teacher_path ./teacher_models/cifar10_resnet18_teacher_93.pth 27 | 28 | 29 | # ./train_student/checkpoint/resnet/cifar10_cifar100_div_gan/2/last_epoch_best_student_model_from_scratch_div_gan_05_40_class.pth 30 | 31 | 32 | # Alternate training 33 | 34 | CUDA_VISIBLE_DEVICES=0 python ./code/train_generator/train_generator_clone.py --dataroot ./data/ --dataset cifar100 --imageSize 32 --cuda --outf ./cifar100_run_models/resnet/out_step2/cifar_40/0_500/ --manualSeed 108 --niter 400 --batchSize 64 --netG ./cifar100_run_models/resnet/out_step2_0_10/40_class/2/netG_epoch_99.pth --student_path ./train_student/checkpoint/resnet/cifar10_cifar100_div_gan/2/last_epoch_best_student_model_from_scratch_div_gan_05_40_class.pth --network resnet --teacher_path ./teacher_models/cifar10_resnet18_teacher_93.pth --warmup --name warmup_altr_0_500_1_p10 --auto-augment --c_l 0 --d_l 500 --proxy_ds_name 40_class --val_data_dcgan ./val_data/dcgan_val_data_cifar_40_resnet.pkl --val_data_degan ./val_data/degan_val_data_cifar_40_resnet.pkl --true_dataset cifar10 --num_classes 10 -------------------------------------------------------------------------------- /code/train_generator/models/dpn.py: -------------------------------------------------------------------------------- 1 | '''Dual Path Networks in PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Bottleneck(nn.Module): 8 | def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer): 9 | super(Bottleneck, self).__init__() 10 | self.out_planes = out_planes 11 | self.dense_depth = dense_depth 12 | 13 | self.conv1 = nn.Conv2d(last_planes, in_planes, kernel_size=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(in_planes) 15 | self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, 
groups=32, bias=False) 16 | self.bn2 = nn.BatchNorm2d(in_planes) 17 | self.conv3 = nn.Conv2d(in_planes, out_planes+dense_depth, kernel_size=1, bias=False) 18 | self.bn3 = nn.BatchNorm2d(out_planes+dense_depth) 19 | 20 | self.shortcut = nn.Sequential() 21 | if first_layer: 22 | self.shortcut = nn.Sequential( 23 | nn.Conv2d(last_planes, out_planes+dense_depth, kernel_size=1, stride=stride, bias=False), 24 | nn.BatchNorm2d(out_planes+dense_depth) 25 | ) 26 | 27 | def forward(self, x): 28 | out = F.relu(self.bn1(self.conv1(x))) 29 | out = F.relu(self.bn2(self.conv2(out))) 30 | out = self.bn3(self.conv3(out)) 31 | x = self.shortcut(x) 32 | d = self.out_planes 33 | out = torch.cat([x[:,:d,:,:]+out[:,:d,:,:], x[:,d:,:,:], out[:,d:,:,:]], 1) 34 | out = F.relu(out) 35 | return out 36 | 37 | 38 | class DPN(nn.Module): 39 | def __init__(self, cfg): 40 | super(DPN, self).__init__() 41 | in_planes, out_planes = cfg['in_planes'], cfg['out_planes'] 42 | num_blocks, dense_depth = cfg['num_blocks'], cfg['dense_depth'] 43 | 44 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(64) 46 | self.last_planes = 64 47 | self.layer1 = self._make_layer(in_planes[0], out_planes[0], num_blocks[0], dense_depth[0], stride=1) 48 | self.layer2 = self._make_layer(in_planes[1], out_planes[1], num_blocks[1], dense_depth[1], stride=2) 49 | self.layer3 = self._make_layer(in_planes[2], out_planes[2], num_blocks[2], dense_depth[2], stride=2) 50 | self.layer4 = self._make_layer(in_planes[3], out_planes[3], num_blocks[3], dense_depth[3], stride=2) 51 | self.linear = nn.Linear(out_planes[3]+(num_blocks[3]+1)*dense_depth[3], 10) 52 | 53 | def _make_layer(self, in_planes, out_planes, num_blocks, dense_depth, stride): 54 | strides = [stride] + [1]*(num_blocks-1) 55 | layers = [] 56 | for i,stride in enumerate(strides): 57 | layers.append(Bottleneck(self.last_planes, in_planes, out_planes, dense_depth, stride, i==0)) 58 | self.last_planes = out_planes + 
(i+2) * dense_depth 59 | return nn.Sequential(*layers) 60 | 61 | def forward(self, x): 62 | out = F.relu(self.bn1(self.conv1(x))) 63 | out = self.layer1(out) 64 | out = self.layer2(out) 65 | out = self.layer3(out) 66 | out = self.layer4(out) 67 | out = F.avg_pool2d(out, 4) 68 | out = out.view(out.size(0), -1) 69 | out = self.linear(out) 70 | return out 71 | 72 | 73 | def DPN26(): 74 | cfg = { 75 | 'in_planes': (96,192,384,768), 76 | 'out_planes': (256,512,1024,2048), 77 | 'num_blocks': (2,2,2,2), 78 | 'dense_depth': (16,32,24,128) 79 | } 80 | return DPN(cfg) 81 | 82 | def DPN92(): 83 | cfg = { 84 | 'in_planes': (96,192,384,768), 85 | 'out_planes': (256,512,1024,2048), 86 | 'num_blocks': (3,4,20,3), 87 | 'dense_depth': (16,32,24,128) 88 | } 89 | return DPN(cfg) 90 | 91 | 92 | def test(): 93 | net = DPN92() 94 | x = torch.randn(1,3,32,32) 95 | y = net(x) 96 | print(y) 97 | 98 | # test() 99 | -------------------------------------------------------------------------------- /code/train_student/models/dpn.py: -------------------------------------------------------------------------------- 1 | '''Dual Path Networks in PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Bottleneck(nn.Module): 8 | def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer): 9 | super(Bottleneck, self).__init__() 10 | self.out_planes = out_planes 11 | self.dense_depth = dense_depth 12 | 13 | self.conv1 = nn.Conv2d(last_planes, in_planes, kernel_size=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(in_planes) 15 | self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=32, bias=False) 16 | self.bn2 = nn.BatchNorm2d(in_planes) 17 | self.conv3 = nn.Conv2d(in_planes, out_planes+dense_depth, kernel_size=1, bias=False) 18 | self.bn3 = nn.BatchNorm2d(out_planes+dense_depth) 19 | 20 | self.shortcut = nn.Sequential() 21 | if first_layer: 22 | self.shortcut = nn.Sequential( 23 | 
nn.Conv2d(last_planes, out_planes+dense_depth, kernel_size=1, stride=stride, bias=False), 24 | nn.BatchNorm2d(out_planes+dense_depth) 25 | ) 26 | 27 | def forward(self, x): 28 | out = F.relu(self.bn1(self.conv1(x))) 29 | out = F.relu(self.bn2(self.conv2(out))) 30 | out = self.bn3(self.conv3(out)) 31 | x = self.shortcut(x) 32 | d = self.out_planes 33 | out = torch.cat([x[:,:d,:,:]+out[:,:d,:,:], x[:,d:,:,:], out[:,d:,:,:]], 1) 34 | out = F.relu(out) 35 | return out 36 | 37 | 38 | class DPN(nn.Module): 39 | def __init__(self, cfg): 40 | super(DPN, self).__init__() 41 | in_planes, out_planes = cfg['in_planes'], cfg['out_planes'] 42 | num_blocks, dense_depth = cfg['num_blocks'], cfg['dense_depth'] 43 | 44 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(64) 46 | self.last_planes = 64 47 | self.layer1 = self._make_layer(in_planes[0], out_planes[0], num_blocks[0], dense_depth[0], stride=1) 48 | self.layer2 = self._make_layer(in_planes[1], out_planes[1], num_blocks[1], dense_depth[1], stride=2) 49 | self.layer3 = self._make_layer(in_planes[2], out_planes[2], num_blocks[2], dense_depth[2], stride=2) 50 | self.layer4 = self._make_layer(in_planes[3], out_planes[3], num_blocks[3], dense_depth[3], stride=2) 51 | self.linear = nn.Linear(out_planes[3]+(num_blocks[3]+1)*dense_depth[3], 10) 52 | 53 | def _make_layer(self, in_planes, out_planes, num_blocks, dense_depth, stride): 54 | strides = [stride] + [1]*(num_blocks-1) 55 | layers = [] 56 | for i,stride in enumerate(strides): 57 | layers.append(Bottleneck(self.last_planes, in_planes, out_planes, dense_depth, stride, i==0)) 58 | self.last_planes = out_planes + (i+2) * dense_depth 59 | return nn.Sequential(*layers) 60 | 61 | def forward(self, x): 62 | out = F.relu(self.bn1(self.conv1(x))) 63 | out = self.layer1(out) 64 | out = self.layer2(out) 65 | out = self.layer3(out) 66 | out = self.layer4(out) 67 | out = F.avg_pool2d(out, 4) 68 | out = out.view(out.size(0), -1) 69 
| out = self.linear(out) 70 | return out 71 | 72 | 73 | def DPN26(): 74 | cfg = { 75 | 'in_planes': (96,192,384,768), 76 | 'out_planes': (256,512,1024,2048), 77 | 'num_blocks': (2,2,2,2), 78 | 'dense_depth': (16,32,24,128) 79 | } 80 | return DPN(cfg) 81 | 82 | def DPN92(): 83 | cfg = { 84 | 'in_planes': (96,192,384,768), 85 | 'out_planes': (256,512,1024,2048), 86 | 'num_blocks': (3,4,20,3), 87 | 'dense_depth': (16,32,24,128) 88 | } 89 | return DPN(cfg) 90 | 91 | 92 | def test(): 93 | net = DPN92() 94 | x = torch.randn(1,3,32,32) 95 | y = net(x) 96 | print(y) 97 | 98 | # test() 99 | -------------------------------------------------------------------------------- /run_cifar40_classes_alexnet.sh: -------------------------------------------------------------------------------- 1 | # CIFAR-100 40 class ALEXNET runs 2 | 3 | # train dcgan for cifar-100 40 classes 4 | CUDA_VISIBLE_DEVICES=0 python ./code/train_generator/dcgan.py --dataroot ./data/ --dataset cifar100 --imageSize 32 --cuda --outf ./cifar100_run_models/alexnet/cifar_40/dcgan/ --manualSeed 108 --niter 200 --batchSize 64 --network alexnet --proxy_ds_name 40_class --teacher_path ./teacher_models/cifar10_alexnet_teacher_79.pth --num_classes 10 5 | 6 | #mkdir val_data 7 | 8 | CUDA_VISIBLE_DEVICES=0 python ./code/train_student/generate_val_data.py --dcgan_netG ./cifar100_run_models/alexnet/cifar_40/dcgan/netG_epoch_199.pth --dcgan_out ./val_data/dcgan_val_data_cifar_40_alexnet.pkl --teacher_path ./teacher_models/cifar10_alexnet_teacher_79.pth --network alexnet 9 | 10 | 11 | # Hard label runs 12 | 13 | CUDA_VISIBLE_DEVICES=0 python ./code/train_student/train_student.py --from_scratch --dcgan_netG ./cifar100_run_models/alexnet/cifar_40/dcgan/netG_epoch_199.pth --count 1 --network alexnet --dcgan_data --dcgan_data_ratio 0.5 --proxy_data --proxy_data_ratio 1 --pad_crop --name proxy_dcgan_45k_40_class_rand --proxy_ds_name 40_class --val_data_dcgan ./val_data/dcgan_val_data_cifar_40_alexnet.pkl --teacher_path
./teacher_models/cifar10_alexnet_teacher_79.pth --true_dataset cifar10 --num_classes 10 --proxy_dataset cifar100 14 | 15 | # ./train_student/checkpoint/alexnet/cifar10_cifar100/1/last_epoch_best_student_model_proxy_dcgan_45k_40_class_rand.pth 16 | 17 | 18 | # Run div gan 19 | 20 | CUDA_VISIBLE_DEVICES=0 python ./code/train_generator/train_gen.py --dataroot ./data/ --dataset cifar100 --imageSize 32 --cuda --outf ./cifar100_run_models/alexnet/out_step2_0_10/40_class/2/ --manualSeed 108 --niter 100 --batchSize 64 --netG ./cifar100_run_models/alexnet/cifar_40/dcgan/netG_epoch_199.pth --student_path ./train_student/checkpoint/alexnet/cifar10_cifar100/1/last_epoch_best_student_model_proxy_dcgan_45k_40_class_rand.pth --network alexnet --c_l 0 --d_l 10 --proxy_ds_name 40_class --true_dataset cifar10 --num_classes 10 21 | 22 | # DivGAN + Random Crop (student trained from scratch on 0.5*Proxy + 0.5*DivGAN ) 23 | CUDA_VISIBLE_DEVICES=0 python code/train_student/train_student.py --div_gan_netG ./cifar100_run_models/alexnet/out_step2_0_10/40_class/2/netG_epoch_99.pth --dcgan_netG ./cifar100_run_models/alexnet/cifar_40/dcgan/netG_epoch_199.pth --count 2 --network alexnet --pad_crop --div_gan_data --div_gan_data_ratio 0.5 --proxy_data --proxy_data_ratio 1 --from_scratch --name from_scratch_div_gan_05_40_class --proxy_ds_name 40_class --val_data_dcgan ./val_data/dcgan_val_data_cifar_40_alexnet.pkl --teacher_path ./teacher_models/cifar10_alexnet_teacher_79.pth --proxy_dataset cifar100 --true_dataset cifar10 --num_classes 10 24 | 25 | 26 | CUDA_VISIBLE_DEVICES=0 python ./code/train_student/generate_val_data.py --div_gan_netG ./cifar100_run_models/alexnet/out_step2_0_10/40_class/2/netG_epoch_99.pth --network alexnet --div_gan_out ./val_data/div_gan_val_data_cifar_40_alexnet.pkl --teacher_path ./teacher_models/cifar10_alexnet_teacher_79.pth 27 | 28 | 29 | # ./train_student/checkpoint/alexnet/cifar10_cifar100_div_gan/2/last_epoch_best_student_model_from_scratch_div_gan_05_40_class.pth 
30 | 31 | # Alternate training 32 | 33 | CUDA_VISIBLE_DEVICES=0 python ./code/train_generator/train_generator_clone.py --dataroot ./data/ --dataset cifar100 --imageSize 32 --cuda --outf ./cifar100_run_models/alexnet/out_step2/cifar_40/0_500/ --manualSeed 108 --niter 400 --batchSize 64 --netG ./cifar100_run_models/alexnet/out_step2_0_10/40_class/2/netG_epoch_99.pth --student_path ./train_student/checkpoint/alexnet/cifar10_cifar100_div_gan/2/last_epoch_best_student_model_from_scratch_div_gan_05_40_class.pth --network alexnet --teacher_path ./teacher_models/cifar10_alexnet_teacher_79.pth --warmup --name warmup_altr_0_500_1_p10 --auto-augment --c_l 0 --d_l 500 --proxy_ds_name 40_class --val_data_dcgan ./val_data/dcgan_val_data_cifar_40_alexnet.pkl --val_data_degan ./val_data/div_gan_val_data_cifar_40_alexnet.pkl --true_dataset cifar10 --num_classes 10 -------------------------------------------------------------------------------- /run_cifar10_rand_class_resnet.sh: -------------------------------------------------------------------------------- 1 | # train dcgan for cifar-100 10 classes 2 | CUDA_VISIBLE_DEVICES=0 python ./code/train_generator/dcgan.py --dataroot ./data/ --dataset cifar100 --imageSize 32 --cuda --outf ./cifar100_run_models/resnet/cifar_10_rand/dcgan/ --manualSeed 108 --niter 200 --batchSize 64 --network resnet --proxy_ds_name 10_class_rand --teacher_path ./teacher_models/cifar10_resnet18_teacher_93.pth --num_classes 10 3 | 4 | CUDA_VISIBLE_DEVICES=0 python ./code/train_student/generate_val_data.py --dcgan_netG ./cifar100_run_models/resnet/cifar_10_rand/dcgan/netG_epoch_199.pth --dcgan_out ./val_data/dcgan_val_data_cifar10_rand_resnet.pkl --teacher_path ./teacher_models/cifar10_resnet18_teacher_93.pth --network resnet 5 | 6 | 7 | # Hard label runs 8 | 9 | CUDA_VISIBLE_DEVICES=0 python ./code/train_student/train_student.py --from_scratch --dcgan_netG ./cifar100_run_models/resnet/cifar_10_rand/dcgan/netG_epoch_199.pth --count 1 --network resnet 
--dcgan_data --dcgan_data_ratio 0.8 --proxy_data --proxy_data_ratio 1 --pad_crop --name proxy_dcgan_45k_10_class_rand --proxy_ds_name 10_class_rand --val_data_dcgan ./val_data/dcgan_val_data_cifar10_rand_resnet.pkl --teacher_path ./teacher_models/cifar10_resnet18_teacher_93.pth --true_dataset cifar10 --num_classes 10 --proxy_dataset cifar100 10 | 11 | 12 | # ./train_student/checkpoint/resnet/cifar10_cifar100/1/last_epoch_best_student_model_proxy_dcgan_45k_10_class_rand.pth 13 | 14 | 15 | # Run degan 16 | 17 | CUDA_VISIBLE_DEVICES=0 python ./code/train_generator/train_gen.py --dataroot ./data/ --dataset cifar100 --imageSize 32 --cuda --outf ./cifar100_run_models/resnet/out_step2_0_10/10_class_rand/2/ --manualSeed 108 --niter 100 --batchSize 64 --netG ./cifar100_run_models/resnet/cifar_10_rand/dcgan/netG_epoch_199.pth --student_path ./train_student/checkpoint/resnet/cifar10_cifar100/1/last_epoch_best_student_model_proxy_dcgan_45k_10_class_rand.pth --network resnet --c_l 0 --d_l 10 --proxy_ds_name 10_class_rand --true_dataset cifar10 --num_classes 10 18 | 19 | 20 | # DivGAN + Random Crop (student trained from scratch on Proxy + DivGAN data; NOTE(review): the "0.5*Proxy + 0.5*DivGAN" wording matches the 40-class scripts, which pass --div_gan_data_ratio 0.5, but this run passes 0.8 - confirm the intended mix) 21 | CUDA_VISIBLE_DEVICES=0 python code/train_student/train_student.py --div_gan_netG ./cifar100_run_models/resnet/out_step2_0_10/10_class_rand/2/netG_epoch_99.pth --dcgan_netG ./cifar100_run_models/resnet/cifar_10_rand/dcgan/netG_epoch_199.pth --count 2 --network resnet --pad_crop --div_gan_data --div_gan_data_ratio 0.8 --proxy_data --proxy_data_ratio 1 --from_scratch --name from_scratch_div_gan_05_10_class_rand --proxy_ds_name 10_class_rand --val_data_dcgan ./val_data/dcgan_val_data_cifar10_rand_resnet.pkl --teacher_path ./teacher_models/cifar10_resnet18_teacher_93.pth --proxy_dataset cifar100 --true_dataset cifar10 --num_classes 10 22 | 23 | 24 | CUDA_VISIBLE_DEVICES=0 python ./code/train_student/generate_val_data.py --div_gan_netG ./cifar100_run_models/resnet/out_step2_0_10/10_class_rand/2/netG_epoch_99.pth --network resnet
--div_gan_out ./val_data/degan_val_data_cifar10_rand_resnet.pkl --teacher_path ./teacher_models/cifar10_resnet18_teacher_93.pth 25 | 26 | 27 | # ./train_student/checkpoint/resnet/cifar10_cifar100_div_gan/2/last_epoch_best_student_model_from_scratch_div_gan_05_10_class_rand.pth 28 | 29 | # Alternate training 30 | 31 | CUDA_VISIBLE_DEVICES=0 python ./code/train_generator/train_generator_clone.py --dataroot ./data/ --dataset cifar100 --imageSize 32 --cuda --outf ./cifar100_run_models/resnet/out_step2/cifar_10_rand/0_500/ --manualSeed 108 --niter 800 --batchSize 64 --netG ./cifar100_run_models/resnet/out_step2_0_10/10_class_rand/2/netG_epoch_99.pth --student_path ./train_student/checkpoint/resnet/cifar10_cifar100_div_gan/2/last_epoch_best_student_model_from_scratch_div_gan_05_10_class_rand.pth --network resnet --teacher_path ./teacher_models/cifar10_resnet18_teacher_93.pth --warmup --name warmup_altr_0_500_1_p10 --auto-augment --c_l 0 --d_l 500 --proxy_ds_name 10_class_rand --val_data_dcgan ./val_data/dcgan_val_data_cifar10_rand_resnet.pkl --val_data_degan ./val_data/degan_val_data_cifar10_rand_resnet.pkl --true_dataset cifar10 --num_classes 10 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 
5 | ''' 6 | import os 7 | import sys 8 | import time 9 | import math 10 | 11 | import torch.nn as nn 12 | import torch.nn.init as init 13 | 14 | 15 | def get_mean_and_std(dataset): 16 | '''Compute the mean and std value of dataset.''' 17 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 18 | mean = torch.zeros(3) 19 | std = torch.zeros(3) 20 | print('==> Computing mean and std..') 21 | for inputs, targets in dataloader: 22 | for i in range(3): 23 | mean[i] += inputs[:,i,:,:].mean() 24 | std[i] += inputs[:,i,:,:].std() 25 | mean.div_(len(dataset)) 26 | std.div_(len(dataset)) 27 | return mean, std 28 | 29 | def init_params(net): 30 | '''Init layer parameters.''' 31 | for m in net.modules(): 32 | if isinstance(m, nn.Conv2d): 33 | init.kaiming_normal(m.weight, mode='fan_out') 34 | if m.bias: 35 | init.constant(m.bias, 0) 36 | elif isinstance(m, nn.BatchNorm2d): 37 | init.constant(m.weight, 1) 38 | init.constant(m.bias, 0) 39 | elif isinstance(m, nn.Linear): 40 | init.normal(m.weight, std=1e-3) 41 | if m.bias: 42 | init.constant(m.bias, 0) 43 | 44 | 45 | _, term_width = os.popen('stty size', 'r').read().split() 46 | term_width = int(term_width) 47 | 48 | TOTAL_BAR_LENGTH = 65. 49 | last_time = time.time() 50 | begin_time = last_time 51 | def progress_bar(current, total, msg=None): 52 | global last_time, begin_time 53 | if current == 0: 54 | begin_time = time.time() # Reset for new bar. 
55 | 56 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 57 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 58 | 59 | sys.stdout.write(' [') 60 | for i in range(cur_len): 61 | sys.stdout.write('=') 62 | sys.stdout.write('>') 63 | for i in range(rest_len): 64 | sys.stdout.write('.') 65 | sys.stdout.write(']') 66 | 67 | cur_time = time.time() 68 | step_time = cur_time - last_time 69 | last_time = cur_time 70 | tot_time = cur_time - begin_time 71 | 72 | L = [] 73 | L.append(' Step: %s' % format_time(step_time)) 74 | L.append(' | Tot: %s' % format_time(tot_time)) 75 | if msg: 76 | L.append(' | ' + msg) 77 | 78 | msg = ''.join(L) 79 | sys.stdout.write(msg) 80 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 81 | sys.stdout.write(' ') 82 | 83 | # Go back to the center of the bar. 84 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 85 | sys.stdout.write('\b') 86 | sys.stdout.write(' %d/%d ' % (current+1, total)) 87 | 88 | if current < total-1: 89 | sys.stdout.write('\r') 90 | else: 91 | sys.stdout.write('\n') 92 | sys.stdout.flush() 93 | 94 | def format_time(seconds): 95 | days = int(seconds / 3600/24) 96 | seconds = seconds - days*3600*24 97 | hours = int(seconds / 3600) 98 | seconds = seconds - hours*3600 99 | minutes = int(seconds / 60) 100 | seconds = seconds - minutes*60 101 | secondsf = int(seconds) 102 | seconds = seconds - secondsf 103 | millis = int(seconds*1000) 104 | 105 | f = '' 106 | i = 1 107 | if days > 0: 108 | f += str(days) + 'D' 109 | i += 1 110 | if hours > 0 and i <= 2: 111 | f += str(hours) + 'h' 112 | i += 1 113 | if minutes > 0 and i <= 2: 114 | f += str(minutes) + 'm' 115 | i += 1 116 | if secondsf > 0 and i <= 2: 117 | f += str(secondsf) + 's' 118 | i += 1 119 | if millis > 0 and i <= 2: 120 | f += str(millis) + 'ms' 121 | i += 1 122 | if f == '': 123 | f = '0ms' 124 | return f 125 | -------------------------------------------------------------------------------- /dcgan_model.py: 
-------------------------------------------------------------------------------- 1 | # Network of DCGAN 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.utils import spectral_norm 5 | 6 | 7 | class Generator(nn.Module): 8 | def __init__(self, ngpu, nc=3, nz=100, ngf=64): 9 | super(Generator, self).__init__() 10 | self.ngpu = ngpu 11 | self.main = nn.Sequential( 12 | nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False), 13 | nn.BatchNorm2d(ngf * 8), 14 | nn.ReLU(True), 15 | 16 | nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), 17 | nn.BatchNorm2d(ngf * 4), 18 | nn.ReLU(True), 19 | 20 | nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), 21 | nn.BatchNorm2d(ngf * 2), 22 | nn.ReLU(True), 23 | 24 | nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), 25 | nn.BatchNorm2d(ngf), 26 | nn.ReLU(True), 27 | 28 | nn.ConvTranspose2d( ngf, nc, kernel_size=1, stride=1, padding=0, bias=False), 29 | nn.Tanh() 30 | ) 31 | 32 | def forward(self, input): 33 | if input.is_cuda and self.ngpu > 1: 34 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 35 | else: 36 | output = self.main(input) 37 | return output 38 | 39 | 40 | class Discriminator(nn.Module): 41 | def __init__(self, ngpu, nc=3, ndf=64): 42 | super(Discriminator, self).__init__() 43 | self.ngpu = ngpu 44 | self.main = nn.Sequential( 45 | nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), 46 | nn.LeakyReLU(0.2, inplace=True), 47 | 48 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), 49 | nn.BatchNorm2d(ndf * 2), 50 | nn.LeakyReLU(0.2, inplace=True), 51 | 52 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), 53 | nn.BatchNorm2d(ndf * 4), 54 | nn.LeakyReLU(0.2, inplace=True), 55 | 56 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), 57 | nn.BatchNorm2d(ndf * 8), 58 | nn.LeakyReLU(0.2, inplace=True), 59 | 60 | nn.Conv2d(ndf * 8, 1, 2, 2, 0, bias=False), 61 | nn.Sigmoid() 62 | ) 63 | 64 | def forward(self, input): 65 | if input.is_cuda and self.ngpu > 1: 66 | output = 
nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 67 | else: 68 | output = self.main(input) 69 | 70 | return output.view(-1, 1).squeeze(1) 71 | 72 | 73 | 74 | class Discriminator_SNGAN(nn.Module): 75 | def __init__(self, ngpu, nc=3, ndf=64): 76 | super(Discriminator_SNGAN, self).__init__() 77 | self.ngpu = ngpu 78 | self.main = nn.Sequential( 79 | 80 | spectral_norm(nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), eps=1e-6), 81 | nn.LeakyReLU(0.2, inplace=True), 82 | 83 | spectral_norm(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), eps=1e-6), 84 | #nn.BatchNorm2d(ndf * 2), 85 | nn.LeakyReLU(0.2, inplace=True), 86 | 87 | spectral_norm(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), eps=1e-6), 88 | #nn.BatchNorm2d(ndf * 4), 89 | nn.LeakyReLU(0.2, inplace=True), 90 | 91 | spectral_norm(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), eps=1e-6), 92 | #nn.BatchNorm2d(ndf * 8), 93 | nn.LeakyReLU(0.2, inplace=True), 94 | 95 | spectral_norm(nn.Conv2d(ndf * 8, 1, 2, 2, 0, bias=False), eps=1e-6), 96 | nn.Sigmoid() 97 | ) 98 | 99 | def forward(self, input): 100 | if input.is_cuda and self.ngpu > 1: 101 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 102 | else: 103 | output = self.main(input) 104 | 105 | return output.view(-1, 1).squeeze(1) 106 | 107 | -------------------------------------------------------------------------------- /models/densenet.py: -------------------------------------------------------------------------------- 1 | '''DenseNet in PyTorch.''' 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class Bottleneck(nn.Module): 10 | def __init__(self, in_planes, growth_rate): 11 | super(Bottleneck, self).__init__() 12 | self.bn1 = nn.BatchNorm2d(in_planes) 13 | self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(4*growth_rate) 15 | self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False) 
16 | 17 | def forward(self, x): 18 | out = self.conv1(F.relu(self.bn1(x))) 19 | out = self.conv2(F.relu(self.bn2(out))) 20 | out = torch.cat([out,x], 1) 21 | return out 22 | 23 | 24 | class Transition(nn.Module): 25 | def __init__(self, in_planes, out_planes): 26 | super(Transition, self).__init__() 27 | self.bn = nn.BatchNorm2d(in_planes) 28 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False) 29 | 30 | def forward(self, x): 31 | out = self.conv(F.relu(self.bn(x))) 32 | out = F.avg_pool2d(out, 2) 33 | return out 34 | 35 | 36 | class DenseNet(nn.Module): 37 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10): 38 | super(DenseNet, self).__init__() 39 | self.growth_rate = growth_rate 40 | 41 | num_planes = 2*growth_rate 42 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False) 43 | 44 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0]) 45 | num_planes += nblocks[0]*growth_rate 46 | out_planes = int(math.floor(num_planes*reduction)) 47 | self.trans1 = Transition(num_planes, out_planes) 48 | num_planes = out_planes 49 | 50 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1]) 51 | num_planes += nblocks[1]*growth_rate 52 | out_planes = int(math.floor(num_planes*reduction)) 53 | self.trans2 = Transition(num_planes, out_planes) 54 | num_planes = out_planes 55 | 56 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2]) 57 | num_planes += nblocks[2]*growth_rate 58 | out_planes = int(math.floor(num_planes*reduction)) 59 | self.trans3 = Transition(num_planes, out_planes) 60 | num_planes = out_planes 61 | 62 | self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3]) 63 | num_planes += nblocks[3]*growth_rate 64 | 65 | self.bn = nn.BatchNorm2d(num_planes) 66 | self.linear = nn.Linear(num_planes, num_classes) 67 | 68 | def _make_dense_layers(self, block, in_planes, nblock): 69 | layers = [] 70 | for i in range(nblock): 71 | 
layers.append(block(in_planes, self.growth_rate)) 72 | in_planes += self.growth_rate 73 | return nn.Sequential(*layers) 74 | 75 | def forward(self, x): 76 | out = self.conv1(x) 77 | out = self.trans1(self.dense1(out)) 78 | out = self.trans2(self.dense2(out)) 79 | out = self.trans3(self.dense3(out)) 80 | out = self.dense4(out) 81 | out = F.avg_pool2d(F.relu(self.bn(out)), 4) 82 | out = out.view(out.size(0), -1) 83 | out = self.linear(out) 84 | return out 85 | 86 | def DenseNet121(): 87 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32) 88 | 89 | def DenseNet169(): 90 | return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32) 91 | 92 | def DenseNet201(): 93 | return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32) 94 | 95 | def DenseNet161(): 96 | return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48) 97 | 98 | def densenet_cifar(): 99 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12) 100 | 101 | def test(): 102 | net = densenet_cifar() 103 | x = torch.randn(1,3,32,32) 104 | y = net(x) 105 | print(y) 106 | 107 | # test() 108 | -------------------------------------------------------------------------------- /models/shufflenet.py: -------------------------------------------------------------------------------- 1 | '''ShuffleNet in PyTorch. 2 | 3 | See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details. 
4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ShuffleBlock(nn.Module): 11 | def __init__(self, groups): 12 | super(ShuffleBlock, self).__init__() 13 | self.groups = groups 14 | 15 | def forward(self, x): 16 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 17 | N,C,H,W = x.size() 18 | g = self.groups 19 | return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).reshape(N,C,H,W) 20 | 21 | 22 | class Bottleneck(nn.Module): 23 | def __init__(self, in_planes, out_planes, stride, groups): 24 | super(Bottleneck, self).__init__() 25 | self.stride = stride 26 | 27 | mid_planes = out_planes/4 28 | g = 1 if in_planes==24 else groups 29 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False) 30 | self.bn1 = nn.BatchNorm2d(mid_planes) 31 | self.shuffle1 = ShuffleBlock(groups=g) 32 | self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False) 33 | self.bn2 = nn.BatchNorm2d(mid_planes) 34 | self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False) 35 | self.bn3 = nn.BatchNorm2d(out_planes) 36 | 37 | self.shortcut = nn.Sequential() 38 | if stride == 2: 39 | self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1)) 40 | 41 | def forward(self, x): 42 | out = F.relu(self.bn1(self.conv1(x))) 43 | out = self.shuffle1(out) 44 | out = F.relu(self.bn2(self.conv2(out))) 45 | out = self.bn3(self.conv3(out)) 46 | res = self.shortcut(x) 47 | out = F.relu(torch.cat([out,res], 1)) if self.stride==2 else F.relu(out+res) 48 | return out 49 | 50 | 51 | class ShuffleNet(nn.Module): 52 | def __init__(self, cfg): 53 | super(ShuffleNet, self).__init__() 54 | out_planes = cfg['out_planes'] 55 | num_blocks = cfg['num_blocks'] 56 | groups = cfg['groups'] 57 | 58 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 59 | self.bn1 = nn.BatchNorm2d(24) 60 | self.in_planes = 24 61 | self.layer1 
= self._make_layer(out_planes[0], num_blocks[0], groups) 62 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups) 63 | self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups) 64 | self.linear = nn.Linear(out_planes[2], 10) 65 | 66 | def _make_layer(self, out_planes, num_blocks, groups): 67 | layers = [] 68 | for i in range(num_blocks): 69 | stride = 2 if i == 0 else 1 70 | cat_planes = self.in_planes if i == 0 else 0 71 | layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, stride=stride, groups=groups)) 72 | self.in_planes = out_planes 73 | return nn.Sequential(*layers) 74 | 75 | def forward(self, x): 76 | out = F.relu(self.bn1(self.conv1(x))) 77 | out = self.layer1(out) 78 | out = self.layer2(out) 79 | out = self.layer3(out) 80 | out = F.avg_pool2d(out, 4) 81 | out = out.view(out.size(0), -1) 82 | out = self.linear(out) 83 | return out 84 | 85 | 86 | def ShuffleNetG2(): 87 | cfg = { 88 | 'out_planes': [200,400,800], 89 | 'num_blocks': [4,8,4], 90 | 'groups': 2 91 | } 92 | return ShuffleNet(cfg) 93 | 94 | def ShuffleNetG3(): 95 | cfg = { 96 | 'out_planes': [240,480,960], 97 | 'num_blocks': [4,8,4], 98 | 'groups': 3 99 | } 100 | return ShuffleNet(cfg) 101 | 102 | 103 | def test(): 104 | net = ShuffleNetG2() 105 | x = torch.randn(1,3,32,32) 106 | y = net(x) 107 | print(y) 108 | 109 | # test() 110 | -------------------------------------------------------------------------------- /code/train_student/models/densenet.py: -------------------------------------------------------------------------------- 1 | '''DenseNet in PyTorch.''' 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class Bottleneck(nn.Module): 10 | def __init__(self, in_planes, growth_rate): 11 | super(Bottleneck, self).__init__() 12 | self.bn1 = nn.BatchNorm2d(in_planes) 13 | self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(4*growth_rate) 15 | 
self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False) 16 | 17 | def forward(self, x): 18 | out = self.conv1(F.relu(self.bn1(x))) 19 | out = self.conv2(F.relu(self.bn2(out))) 20 | out = torch.cat([out,x], 1) 21 | return out 22 | 23 | 24 | class Transition(nn.Module): 25 | def __init__(self, in_planes, out_planes): 26 | super(Transition, self).__init__() 27 | self.bn = nn.BatchNorm2d(in_planes) 28 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False) 29 | 30 | def forward(self, x): 31 | out = self.conv(F.relu(self.bn(x))) 32 | out = F.avg_pool2d(out, 2) 33 | return out 34 | 35 | 36 | class DenseNet(nn.Module): 37 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10): 38 | super(DenseNet, self).__init__() 39 | self.growth_rate = growth_rate 40 | 41 | num_planes = 2*growth_rate 42 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False) 43 | 44 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0]) 45 | num_planes += nblocks[0]*growth_rate 46 | out_planes = int(math.floor(num_planes*reduction)) 47 | self.trans1 = Transition(num_planes, out_planes) 48 | num_planes = out_planes 49 | 50 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1]) 51 | num_planes += nblocks[1]*growth_rate 52 | out_planes = int(math.floor(num_planes*reduction)) 53 | self.trans2 = Transition(num_planes, out_planes) 54 | num_planes = out_planes 55 | 56 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2]) 57 | num_planes += nblocks[2]*growth_rate 58 | out_planes = int(math.floor(num_planes*reduction)) 59 | self.trans3 = Transition(num_planes, out_planes) 60 | num_planes = out_planes 61 | 62 | self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3]) 63 | num_planes += nblocks[3]*growth_rate 64 | 65 | self.bn = nn.BatchNorm2d(num_planes) 66 | self.linear = nn.Linear(num_planes, num_classes) 67 | 68 | def _make_dense_layers(self, block, 
in_planes, nblock): 69 | layers = [] 70 | for i in range(nblock): 71 | layers.append(block(in_planes, self.growth_rate)) 72 | in_planes += self.growth_rate 73 | return nn.Sequential(*layers) 74 | 75 | def forward(self, x): 76 | out = self.conv1(x) 77 | out = self.trans1(self.dense1(out)) 78 | out = self.trans2(self.dense2(out)) 79 | out = self.trans3(self.dense3(out)) 80 | out = self.dense4(out) 81 | out = F.avg_pool2d(F.relu(self.bn(out)), 4) 82 | out = out.view(out.size(0), -1) 83 | out = self.linear(out) 84 | return out 85 | 86 | def DenseNet121(): 87 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32) 88 | 89 | def DenseNet169(): 90 | return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32) 91 | 92 | def DenseNet201(): 93 | return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32) 94 | 95 | def DenseNet161(): 96 | return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48) 97 | 98 | def densenet_cifar(): 99 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12) 100 | 101 | def test(): 102 | net = densenet_cifar() 103 | x = torch.randn(1,3,32,32) 104 | y = net(x) 105 | print(y) 106 | 107 | # test() 108 | -------------------------------------------------------------------------------- /code/train_student/utils.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 
5 | ''' 6 | import os 7 | import sys 8 | import time 9 | import math 10 | 11 | import torch.nn as nn 12 | import torch.nn.init as init 13 | 14 | 15 | def get_mean_and_std(dataset): 16 | '''Compute the mean and std value of dataset.''' 17 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 18 | mean = torch.zeros(3) 19 | std = torch.zeros(3) 20 | print('==> Computing mean and std..') 21 | for inputs, targets in dataloader: 22 | for i in range(3): 23 | mean[i] += inputs[:,i,:,:].mean() 24 | std[i] += inputs[:,i,:,:].std() 25 | mean.div_(len(dataset)) 26 | std.div_(len(dataset)) 27 | return mean, std 28 | 29 | def init_params(net): 30 | '''Init layer parameters.''' 31 | for m in net.modules(): 32 | if isinstance(m, nn.Conv2d): 33 | init.kaiming_normal(m.weight, mode='fan_out') 34 | if m.bias: 35 | init.constant(m.bias, 0) 36 | elif isinstance(m, nn.BatchNorm2d): 37 | init.constant(m.weight, 1) 38 | init.constant(m.bias, 0) 39 | elif isinstance(m, nn.Linear): 40 | init.normal(m.weight, std=1e-3) 41 | if m.bias: 42 | init.constant(m.bias, 0) 43 | 44 | 45 | _, term_width = os.popen('stty size', 'r').read().split() 46 | term_width = int(term_width) 47 | 48 | TOTAL_BAR_LENGTH = 65. 49 | last_time = time.time() 50 | begin_time = last_time 51 | def progress_bar(current, total, msg=None): 52 | global last_time, begin_time 53 | if current == 0: 54 | begin_time = time.time() # Reset for new bar. 
55 | 56 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 57 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 58 | 59 | sys.stdout.write(' [') 60 | for i in range(cur_len): 61 | sys.stdout.write('=') 62 | sys.stdout.write('>') 63 | for i in range(rest_len): 64 | sys.stdout.write('.') 65 | sys.stdout.write(']') 66 | 67 | cur_time = time.time() 68 | step_time = cur_time - last_time 69 | last_time = cur_time 70 | tot_time = cur_time - begin_time 71 | 72 | L = [] 73 | L.append(' Step: %s' % format_time(step_time)) 74 | L.append(' | Tot: %s' % format_time(tot_time)) 75 | if msg: 76 | L.append(' | ' + msg) 77 | 78 | msg = ''.join(L) 79 | sys.stdout.write(msg) 80 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 81 | sys.stdout.write(' ') 82 | 83 | # Go back to the center of the bar. 84 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 85 | sys.stdout.write('\b') 86 | sys.stdout.write(' %d/%d ' % (current+1, total)) 87 | 88 | if current < total-1: 89 | sys.stdout.write('\r') 90 | else: 91 | sys.stdout.write('\n') 92 | sys.stdout.flush() 93 | 94 | def format_time(seconds): 95 | days = int(seconds / 3600/24) 96 | seconds = seconds - days*3600*24 97 | hours = int(seconds / 3600) 98 | seconds = seconds - hours*3600 99 | minutes = int(seconds / 60) 100 | seconds = seconds - minutes*60 101 | secondsf = int(seconds) 102 | seconds = seconds - secondsf 103 | millis = int(seconds*1000) 104 | 105 | f = '' 106 | i = 1 107 | if days > 0: 108 | f += str(days) + 'D' 109 | i += 1 110 | if hours > 0 and i <= 2: 111 | f += str(hours) + 'h' 112 | i += 1 113 | if minutes > 0 and i <= 2: 114 | f += str(minutes) + 'm' 115 | i += 1 116 | if secondsf > 0 and i <= 2: 117 | f += str(secondsf) + 's' 118 | i += 1 119 | if millis > 0 and i <= 2: 120 | f += str(millis) + 'ms' 121 | i += 1 122 | if f == '': 123 | f = '0ms' 124 | return f 125 | -------------------------------------------------------------------------------- /code/train_generator/dcgan_model.py: 
-------------------------------------------------------------------------------- 1 | # Network of DCGAN 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.utils import spectral_norm 5 | 6 | 7 | class Generator(nn.Module): 8 | def __init__(self, ngpu, nc=3, nz=100, ngf=64): 9 | super(Generator, self).__init__() 10 | self.ngpu = ngpu 11 | self.main = nn.Sequential( 12 | nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False), 13 | nn.BatchNorm2d(ngf * 8), 14 | nn.ReLU(True), 15 | 16 | nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), 17 | nn.BatchNorm2d(ngf * 4), 18 | nn.ReLU(True), 19 | 20 | nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), 21 | nn.BatchNorm2d(ngf * 2), 22 | nn.ReLU(True), 23 | 24 | nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), 25 | nn.BatchNorm2d(ngf), 26 | nn.ReLU(True), 27 | 28 | nn.ConvTranspose2d( ngf, nc, kernel_size=1, stride=1, padding=0, bias=False), 29 | nn.Tanh() 30 | ) 31 | 32 | def forward(self, input): 33 | if input.is_cuda and self.ngpu > 1: 34 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 35 | else: 36 | output = self.main(input) 37 | return output 38 | 39 | 40 | class Discriminator(nn.Module): 41 | def __init__(self, ngpu, nc=3, ndf=64): 42 | super(Discriminator, self).__init__() 43 | self.ngpu = ngpu 44 | self.main = nn.Sequential( 45 | nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), 46 | nn.LeakyReLU(0.2, inplace=True), 47 | 48 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), 49 | nn.BatchNorm2d(ndf * 2), 50 | nn.LeakyReLU(0.2, inplace=True), 51 | 52 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), 53 | nn.BatchNorm2d(ndf * 4), 54 | nn.LeakyReLU(0.2, inplace=True), 55 | 56 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), 57 | nn.BatchNorm2d(ndf * 8), 58 | nn.LeakyReLU(0.2, inplace=True), 59 | 60 | nn.Conv2d(ndf * 8, 1, 2, 2, 0, bias=False), 61 | nn.Sigmoid() 62 | ) 63 | 64 | def forward(self, input): 65 | if input.is_cuda and self.ngpu > 1: 66 | output = 
nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 67 | else: 68 | output = self.main(input) 69 | 70 | return output.view(-1, 1).squeeze(1) 71 | 72 | 73 | 74 | class Discriminator_SNGAN(nn.Module): 75 | def __init__(self, ngpu, nc=3, ndf=64): 76 | super(Discriminator_SNGAN, self).__init__() 77 | self.ngpu = ngpu 78 | self.main = nn.Sequential( 79 | 80 | spectral_norm(nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), eps=1e-6), 81 | nn.LeakyReLU(0.2, inplace=True), 82 | 83 | spectral_norm(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), eps=1e-6), 84 | #nn.BatchNorm2d(ndf * 2), 85 | nn.LeakyReLU(0.2, inplace=True), 86 | 87 | spectral_norm(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), eps=1e-6), 88 | #nn.BatchNorm2d(ndf * 4), 89 | nn.LeakyReLU(0.2, inplace=True), 90 | 91 | spectral_norm(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), eps=1e-6), 92 | #nn.BatchNorm2d(ndf * 8), 93 | nn.LeakyReLU(0.2, inplace=True), 94 | 95 | spectral_norm(nn.Conv2d(ndf * 8, 1, 2, 2, 0, bias=False), eps=1e-6), 96 | nn.Sigmoid() 97 | ) 98 | 99 | def forward(self, input): 100 | if input.is_cuda and self.ngpu > 1: 101 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 102 | else: 103 | output = self.main(input) 104 | 105 | return output.view(-1, 1).squeeze(1) 106 | 107 | -------------------------------------------------------------------------------- /code/train_generator/models/densenet.py: -------------------------------------------------------------------------------- 1 | '''DenseNet in PyTorch.''' 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class Bottleneck(nn.Module): 10 | def __init__(self, in_planes, growth_rate): 11 | super(Bottleneck, self).__init__() 12 | self.bn1 = nn.BatchNorm2d(in_planes) 13 | self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(4*growth_rate) 15 | self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, 
padding=1, bias=False) 16 | 17 | def forward(self, x): 18 | out = self.conv1(F.relu(self.bn1(x))) 19 | out = self.conv2(F.relu(self.bn2(out))) 20 | out = torch.cat([out,x], 1) 21 | return out 22 | 23 | 24 | class Transition(nn.Module): 25 | def __init__(self, in_planes, out_planes): 26 | super(Transition, self).__init__() 27 | self.bn = nn.BatchNorm2d(in_planes) 28 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False) 29 | 30 | def forward(self, x): 31 | out = self.conv(F.relu(self.bn(x))) 32 | out = F.avg_pool2d(out, 2) 33 | return out 34 | 35 | 36 | class DenseNet(nn.Module): 37 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10): 38 | super(DenseNet, self).__init__() 39 | self.growth_rate = growth_rate 40 | 41 | num_planes = 2*growth_rate 42 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False) 43 | 44 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0]) 45 | num_planes += nblocks[0]*growth_rate 46 | out_planes = int(math.floor(num_planes*reduction)) 47 | self.trans1 = Transition(num_planes, out_planes) 48 | num_planes = out_planes 49 | 50 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1]) 51 | num_planes += nblocks[1]*growth_rate 52 | out_planes = int(math.floor(num_planes*reduction)) 53 | self.trans2 = Transition(num_planes, out_planes) 54 | num_planes = out_planes 55 | 56 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2]) 57 | num_planes += nblocks[2]*growth_rate 58 | out_planes = int(math.floor(num_planes*reduction)) 59 | self.trans3 = Transition(num_planes, out_planes) 60 | num_planes = out_planes 61 | 62 | self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3]) 63 | num_planes += nblocks[3]*growth_rate 64 | 65 | self.bn = nn.BatchNorm2d(num_planes) 66 | self.linear = nn.Linear(num_planes, num_classes) 67 | 68 | def _make_dense_layers(self, block, in_planes, nblock): 69 | layers = [] 70 | for i in range(nblock): 
71 | layers.append(block(in_planes, self.growth_rate)) 72 | in_planes += self.growth_rate 73 | return nn.Sequential(*layers) 74 | 75 | def forward(self, x): 76 | out = self.conv1(x) 77 | out = self.trans1(self.dense1(out)) 78 | out = self.trans2(self.dense2(out)) 79 | out = self.trans3(self.dense3(out)) 80 | out = self.dense4(out) 81 | out = F.avg_pool2d(F.relu(self.bn(out)), 4) 82 | out = out.view(out.size(0), -1) 83 | out = self.linear(out) 84 | return out 85 | 86 | def DenseNet121(): 87 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32) 88 | 89 | def DenseNet169(): 90 | return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32) 91 | 92 | def DenseNet201(): 93 | return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32) 94 | 95 | def DenseNet161(): 96 | return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48) 97 | 98 | def densenet_cifar(): 99 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12) 100 | 101 | def test(): 102 | net = densenet_cifar() 103 | x = torch.randn(1,3,32,32) 104 | y = net(x) 105 | print(y) 106 | 107 | # test() 108 | -------------------------------------------------------------------------------- /code/train_student/dcgan_model.py: -------------------------------------------------------------------------------- 1 | # Network of DCGAN 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.utils import spectral_norm 5 | 6 | 7 | class Generator(nn.Module): 8 | def __init__(self, ngpu, nc=3, nz=100, ngf=64): 9 | super(Generator, self).__init__() 10 | self.ngpu = ngpu 11 | self.main = nn.Sequential( 12 | nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False), 13 | nn.BatchNorm2d(ngf * 8), 14 | nn.ReLU(True), 15 | 16 | nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), 17 | nn.BatchNorm2d(ngf * 4), 18 | nn.ReLU(True), 19 | 20 | nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), 21 | nn.BatchNorm2d(ngf * 2), 22 | nn.ReLU(True), 23 | 24 | nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), 25 | 
nn.BatchNorm2d(ngf), 26 | nn.ReLU(True), 27 | 28 | nn.ConvTranspose2d( ngf, nc, kernel_size=1, stride=1, padding=0, bias=False), 29 | nn.Tanh() 30 | ) 31 | 32 | def forward(self, input): 33 | if input.is_cuda and self.ngpu > 1: 34 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 35 | else: 36 | output = self.main(input) 37 | return output 38 | 39 | 40 | class Discriminator(nn.Module): 41 | def __init__(self, ngpu, nc=3, ndf=64): 42 | super(Discriminator, self).__init__() 43 | self.ngpu = ngpu 44 | self.main = nn.Sequential( 45 | nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), 46 | nn.LeakyReLU(0.2, inplace=True), 47 | 48 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), 49 | nn.BatchNorm2d(ndf * 2), 50 | nn.LeakyReLU(0.2, inplace=True), 51 | 52 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), 53 | nn.BatchNorm2d(ndf * 4), 54 | nn.LeakyReLU(0.2, inplace=True), 55 | 56 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), 57 | nn.BatchNorm2d(ndf * 8), 58 | nn.LeakyReLU(0.2, inplace=True), 59 | 60 | nn.Conv2d(ndf * 8, 1, 2, 2, 0, bias=False), 61 | nn.Sigmoid() 62 | ) 63 | 64 | def forward(self, input): 65 | if input.is_cuda and self.ngpu > 1: 66 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 67 | else: 68 | output = self.main(input) 69 | 70 | return output.view(-1, 1).squeeze(1) 71 | 72 | 73 | 74 | class Discriminator_SNGAN(nn.Module): 75 | def __init__(self, ngpu, nc=3, ndf=64): 76 | super(Discriminator_SNGAN, self).__init__() 77 | self.ngpu = ngpu 78 | self.main = nn.Sequential( 79 | 80 | spectral_norm(nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), eps=1e-6), 81 | nn.LeakyReLU(0.2, inplace=True), 82 | 83 | spectral_norm(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), eps=1e-6), 84 | #nn.BatchNorm2d(ndf * 2), 85 | nn.LeakyReLU(0.2, inplace=True), 86 | 87 | spectral_norm(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), eps=1e-6), 88 | #nn.BatchNorm2d(ndf * 4), 89 | nn.LeakyReLU(0.2, inplace=True), 90 | 91 | 
spectral_norm(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), eps=1e-6), 92 | #nn.BatchNorm2d(ndf * 8), 93 | nn.LeakyReLU(0.2, inplace=True), 94 | 95 | spectral_norm(nn.Conv2d(ndf * 8, 1, 2, 2, 0, bias=False), eps=1e-6), 96 | nn.Sigmoid() 97 | ) 98 | 99 | def forward(self, input): 100 | if input.is_cuda and self.ngpu > 1: 101 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 102 | else: 103 | output = self.main(input) 104 | 105 | return output.view(-1, 1).squeeze(1) 106 | 107 | -------------------------------------------------------------------------------- /code/train_student/models/shufflenet.py: -------------------------------------------------------------------------------- 1 | '''ShuffleNet in PyTorch. 2 | 3 | See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ShuffleBlock(nn.Module): 11 | def __init__(self, groups): 12 | super(ShuffleBlock, self).__init__() 13 | self.groups = groups 14 | 15 | def forward(self, x): 16 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 17 | N,C,H,W = x.size() 18 | g = self.groups 19 | return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).reshape(N,C,H,W) 20 | 21 | 22 | class Bottleneck(nn.Module): 23 | def __init__(self, in_planes, out_planes, stride, groups): 24 | super(Bottleneck, self).__init__() 25 | self.stride = stride 26 | 27 | mid_planes = out_planes/4 28 | g = 1 if in_planes==24 else groups 29 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False) 30 | self.bn1 = nn.BatchNorm2d(mid_planes) 31 | self.shuffle1 = ShuffleBlock(groups=g) 32 | self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False) 33 | self.bn2 = nn.BatchNorm2d(mid_planes) 34 | self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, 
groups=groups, bias=False) 35 | self.bn3 = nn.BatchNorm2d(out_planes) 36 | 37 | self.shortcut = nn.Sequential() 38 | if stride == 2: 39 | self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1)) 40 | 41 | def forward(self, x): 42 | out = F.relu(self.bn1(self.conv1(x))) 43 | out = self.shuffle1(out) 44 | out = F.relu(self.bn2(self.conv2(out))) 45 | out = self.bn3(self.conv3(out)) 46 | res = self.shortcut(x) 47 | out = F.relu(torch.cat([out,res], 1)) if self.stride==2 else F.relu(out+res) 48 | return out 49 | 50 | 51 | class ShuffleNet(nn.Module): 52 | def __init__(self, cfg): 53 | super(ShuffleNet, self).__init__() 54 | out_planes = cfg['out_planes'] 55 | num_blocks = cfg['num_blocks'] 56 | groups = cfg['groups'] 57 | 58 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 59 | self.bn1 = nn.BatchNorm2d(24) 60 | self.in_planes = 24 61 | self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups) 62 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups) 63 | self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups) 64 | self.linear = nn.Linear(out_planes[2], 10) 65 | 66 | def _make_layer(self, out_planes, num_blocks, groups): 67 | layers = [] 68 | for i in range(num_blocks): 69 | stride = 2 if i == 0 else 1 70 | cat_planes = self.in_planes if i == 0 else 0 71 | layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, stride=stride, groups=groups)) 72 | self.in_planes = out_planes 73 | return nn.Sequential(*layers) 74 | 75 | def forward(self, x): 76 | out = F.relu(self.bn1(self.conv1(x))) 77 | out = self.layer1(out) 78 | out = self.layer2(out) 79 | out = self.layer3(out) 80 | out = F.avg_pool2d(out, 4) 81 | out = out.view(out.size(0), -1) 82 | out = self.linear(out) 83 | return out 84 | 85 | 86 | def ShuffleNetG2(): 87 | cfg = { 88 | 'out_planes': [200,400,800], 89 | 'num_blocks': [4,8,4], 90 | 'groups': 2 91 | } 92 | return ShuffleNet(cfg) 93 | 94 | def ShuffleNetG3(): 95 | cfg = { 96 | 'out_planes': 
[240,480,960], 97 | 'num_blocks': [4,8,4], 98 | 'groups': 3 99 | } 100 | return ShuffleNet(cfg) 101 | 102 | 103 | def test(): 104 | net = ShuffleNetG2() 105 | x = torch.randn(1,3,32,32) 106 | y = net(x) 107 | print(y) 108 | 109 | # test() 110 | -------------------------------------------------------------------------------- /code/train_generator/models/shufflenet.py: -------------------------------------------------------------------------------- 1 | '''ShuffleNet in PyTorch. 2 | 3 | See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ShuffleBlock(nn.Module): 11 | def __init__(self, groups): 12 | super(ShuffleBlock, self).__init__() 13 | self.groups = groups 14 | 15 | def forward(self, x): 16 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 17 | N,C,H,W = x.size() 18 | g = self.groups 19 | return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).reshape(N,C,H,W) 20 | 21 | 22 | class Bottleneck(nn.Module): 23 | def __init__(self, in_planes, out_planes, stride, groups): 24 | super(Bottleneck, self).__init__() 25 | self.stride = stride 26 | 27 | mid_planes = out_planes/4 28 | g = 1 if in_planes==24 else groups 29 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False) 30 | self.bn1 = nn.BatchNorm2d(mid_planes) 31 | self.shuffle1 = ShuffleBlock(groups=g) 32 | self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False) 33 | self.bn2 = nn.BatchNorm2d(mid_planes) 34 | self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False) 35 | self.bn3 = nn.BatchNorm2d(out_planes) 36 | 37 | self.shortcut = nn.Sequential() 38 | if stride == 2: 39 | self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1)) 40 | 41 | def forward(self, x): 42 | out = 
F.relu(self.bn1(self.conv1(x))) 43 | out = self.shuffle1(out) 44 | out = F.relu(self.bn2(self.conv2(out))) 45 | out = self.bn3(self.conv3(out)) 46 | res = self.shortcut(x) 47 | out = F.relu(torch.cat([out,res], 1)) if self.stride==2 else F.relu(out+res) 48 | return out 49 | 50 | 51 | class ShuffleNet(nn.Module): 52 | def __init__(self, cfg): 53 | super(ShuffleNet, self).__init__() 54 | out_planes = cfg['out_planes'] 55 | num_blocks = cfg['num_blocks'] 56 | groups = cfg['groups'] 57 | 58 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 59 | self.bn1 = nn.BatchNorm2d(24) 60 | self.in_planes = 24 61 | self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups) 62 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups) 63 | self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups) 64 | self.linear = nn.Linear(out_planes[2], 10) 65 | 66 | def _make_layer(self, out_planes, num_blocks, groups): 67 | layers = [] 68 | for i in range(num_blocks): 69 | stride = 2 if i == 0 else 1 70 | cat_planes = self.in_planes if i == 0 else 0 71 | layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, stride=stride, groups=groups)) 72 | self.in_planes = out_planes 73 | return nn.Sequential(*layers) 74 | 75 | def forward(self, x): 76 | out = F.relu(self.bn1(self.conv1(x))) 77 | out = self.layer1(out) 78 | out = self.layer2(out) 79 | out = self.layer3(out) 80 | out = F.avg_pool2d(out, 4) 81 | out = out.view(out.size(0), -1) 82 | out = self.linear(out) 83 | return out 84 | 85 | 86 | def ShuffleNetG2(): 87 | cfg = { 88 | 'out_planes': [200,400,800], 89 | 'num_blocks': [4,8,4], 90 | 'groups': 2 91 | } 92 | return ShuffleNet(cfg) 93 | 94 | def ShuffleNetG3(): 95 | cfg = { 96 | 'out_planes': [240,480,960], 97 | 'num_blocks': [4,8,4], 98 | 'groups': 3 99 | } 100 | return ShuffleNet(cfg) 101 | 102 | 103 | def test(): 104 | net = ShuffleNetG2() 105 | x = torch.randn(1,3,32,32) 106 | y = net(x) 107 | print(y) 108 | 109 | # test() 110 | 
-------------------------------------------------------------------------------- /run_cifar10_rand_class_alexnet.sh: -------------------------------------------------------------------------------- 1 | # AlexNet (CIFAR-10 teacher) runs with a 10-class CIFAR-100 proxy set (10_class_rand) 2 | 3 | # train dcgan for cifar-100 10 classes 4 | CUDA_VISIBLE_DEVICES=0 python ./code/train_generator/dcgan.py --dataroot ./data/ --dataset cifar100 --imageSize 32 --cuda --outf ./cifar100_run_models/alexnet/cifar_10_rand/dcgan/ --manualSeed 108 --niter 200 --batchSize 64 --network alexnet --proxy_ds_name 10_class_rand --teacher_path ./teacher_models/cifar10_alexnet_teacher_79.pth --num_classes 10 5 | 6 | mkdir -p val_data 7 | 8 | CUDA_VISIBLE_DEVICES=0 python ./code/train_student/generate_val_data.py --dcgan_netG ./cifar100_run_models/alexnet/cifar_10_rand/dcgan/netG_epoch_199.pth --dcgan_out ./val_data/dcgan_val_data_cifar10_rand_alexnet.pkl --teacher_path ./teacher_models/cifar10_alexnet_teacher_79.pth --network alexnet 9 | 10 | 11 | # Hard label runs 12 | 13 | CUDA_VISIBLE_DEVICES=0 python ./code/train_student/train_student.py --from_scratch --dcgan_netG ./cifar100_run_models/alexnet/cifar_10_rand/dcgan/netG_epoch_199.pth --count 1 --network alexnet --dcgan_data --dcgan_data_ratio 0.8 --proxy_data --proxy_data_ratio 1 --pad_crop --name proxy_dcgan_45k_10_class_rand --proxy_ds_name 10_class_rand --val_data_dcgan ./val_data/dcgan_val_data_cifar10_rand_alexnet.pkl --teacher_path ./teacher_models/cifar10_alexnet_teacher_79.pth --true_dataset cifar10 --num_classes 10 --proxy_dataset cifar100 14 | 15 | # ./train_student/checkpoint/alexnet/cifar10_cifar100/1/last_epoch_best_student_model_proxy_dcgan_45k_10_class_rand.pth 16 | 17 | 18 | # Run div gan 19 | 20 | CUDA_VISIBLE_DEVICES=0 python ./code/train_generator/train_gen.py --dataroot ./data/ --dataset cifar100 --imageSize 32 --cuda --outf ./cifar100_run_models/alexnet/out_step2_0_10/10_class_rand/2/ --manualSeed 108 --niter 100 --batchSize 64 --netG
./cifar100_run_models/alexnet/cifar_10_rand/dcgan/netG_epoch_199.pth --student_path ./train_student/checkpoint/alexnet/cifar10_cifar100/1/last_epoch_best_student_model_proxy_dcgan_45k_10_class_rand.pth --network alexnet --c_l 0 --d_l 10 --proxy_ds_name 10_class_rand --true_dataset cifar10 --num_classes 10 21 | 22 | 23 | # Div GAN + Random Crop (student trained from scratch on proxy data (--proxy_data_ratio 1) + DivGAN data (--div_gan_data_ratio 0.8)) 24 | CUDA_VISIBLE_DEVICES=0 python code/train_student/train_student.py --div_gan_netG ./cifar100_run_models/alexnet/out_step2_0_10/10_class_rand/2/netG_epoch_99.pth --dcgan_netG ./cifar100_run_models/alexnet/cifar_10_rand/dcgan/netG_epoch_199.pth --count 2 --network alexnet --pad_crop --div_gan_data --div_gan_data_ratio 0.8 --proxy_data --proxy_data_ratio 1 --from_scratch --name from_scratch_div_gan_05_10_class_rand --proxy_ds_name 10_class_rand --val_data_dcgan ./val_data/dcgan_val_data_cifar10_rand_alexnet.pkl --teacher_path ./teacher_models/cifar10_alexnet_teacher_79.pth --proxy_dataset cifar100 --true_dataset cifar10 --num_classes 10 25 | 26 | 27 | CUDA_VISIBLE_DEVICES=0 python ./code/train_student/generate_val_data.py --div_gan_netG ./cifar100_run_models/alexnet/out_step2_0_10/10_class_rand/2/netG_epoch_99.pth --network alexnet --div_gan_out ./val_data/div_gan_val_data_cifar10_rand_alexnet.pkl --teacher_path ./teacher_models/cifar10_alexnet_teacher_79.pth 28 | 29 | 30 | # ./train_student/checkpoint/alexnet/cifar10_cifar100_div_gan/2/last_epoch_best_student_model_from_scratch_div_gan_05_10_class_rand.pth 31 | # Alternate training 32 | 33 | CUDA_VISIBLE_DEVICES=0 python ./code/train_generator/train_generator_clone.py --dataroot ./data/ --dataset cifar100 --imageSize 32 --cuda --outf ./cifar100_run_models/alexnet/out_step2/cifar_10_rand/0_500/ --manualSeed 108 --niter 800 --batchSize 64 --netG ./cifar100_run_models/alexnet/out_step2_0_10/10_class_rand/2/netG_epoch_99.pth --student_path
./train_student/checkpoint/alexnet/cifar10_cifar100_div_gan/2/last_epoch_best_student_model_from_scratch_div_gan_05_10_class_rand.pth --network alexnet --teacher_path ./teacher_models/cifar10_alexnet_teacher_79.pth --warmup --name warmup_altr_0_500_1_p10 --auto-augment --c_l 0 --d_l 500 --proxy_ds_name 10_class_rand --val_data_dcgan ./val_data/dcgan_val_data_cifar10_rand_alexnet.pkl --val_data_degan ./val_data/div_gan_val_data_cifar10_rand_alexnet.pkl --true_dataset cifar10 --num_classes 10 34 | -------------------------------------------------------------------------------- /run_synthetic_alexnet.sh: -------------------------------------------------------------------------------- 1 | 2 | # train dcgan for synthetic data 3 | CUDA_VISIBLE_DEVICES=0 python ./code/train_generator/dcgan.py --dataroot ./data/ --dataset synthetic --imageSize 32 --cuda --outf ./cifar100_run_models/alexnet/synthetic_grey/dcgan/ --manualSeed 108 --niter 200 --batchSize 64 --network alexnet --teacher_path ./teacher_models/cifar10_alexnet_teacher_79.pth --num_classes 10 --synthetic_dir ./synthetic_dataset/50k_samples/ --total_synth_samples 50000 --grey_scale 4 | 5 | mkdir -p val_data 6 | 7 | CUDA_VISIBLE_DEVICES=0 python ./code/train_student/generate_val_data.py --dcgan_netG ./cifar100_run_models/alexnet/synthetic_grey/dcgan/netG_epoch_199.pth --dcgan_out ./val_data/dcgan_val_data_synthetic_alexnet_grey.pkl --teacher_path ./teacher_models/cifar10_alexnet_teacher_79.pth --network alexnet 8 | 9 | # train student 10 | CUDA_VISIBLE_DEVICES=0 python ./code/train_student/train_student.py --from_scratch --dcgan_netG ./cifar100_run_models/alexnet/synthetic_grey/dcgan/netG_epoch_199.pth --count 1 --dcgan_data --dcgan_data_ratio 0.5 --proxy_data --proxy_data_ratio 0.5 --pad_crop --name proxy_dcgan_45k_synthetic_grey_imgs --val_data_dcgan ./val_data/dcgan_val_data_synthetic_alexnet_grey.pkl --teacher_path ./teacher_models/cifar10_alexnet_teacher_79.pth --network alexnet --proxy_dataset synthetic
--true_dataset cifar10 --num_classes 10 --synthetic_dir ./synthetic_dataset/50k_samples/ --total_synth_samples 50000 --grey_scale 11 | 12 | # ./train_student/checkpoint/alexnet/cifar10_synthetic/1/last_epoch_best_student_model_proxy_dcgan_45k_synthetic_grey_imgs.pth 13 | 14 | # Run div_gan 15 | 16 | CUDA_VISIBLE_DEVICES=0 python ./code/train_generator/train_gen.py --dataroot ./data/ --dataset synthetic --imageSize 32 --cuda --outf ./cifar100_run_models/alexnet/out_step2_0_10/synthetic_grey/2/ --manualSeed 108 --niter 100 --batchSize 64 --netG ./cifar100_run_models/alexnet/synthetic_grey/dcgan/netG_epoch_199.pth --student_path ./train_student/checkpoint/alexnet/cifar10_synthetic/1/last_epoch_best_student_model_proxy_dcgan_45k_synthetic_grey_imgs.pth --network alexnet --c_l 0 --d_l 10 --true_dataset cifar10 --num_classes 10 --synthetic_dir ./synthetic_dataset/50k_samples/ --total_synth_samples 50000 --grey_scale 17 | 18 | # DivGAN + Random Crop (student trained from scratch on 0.5*Proxy + 0.5*DivGAN ) 19 | CUDA_VISIBLE_DEVICES=0 python ./code/train_student/train_student.py --div_gan_netG ./cifar100_run_models/alexnet/out_step2_0_10/synthetic_grey/2/netG_epoch_99.pth --dcgan_netG ./cifar100_run_models/alexnet/synthetic_grey/dcgan/netG_epoch_199.pth --count 2 --network alexnet --pad_crop --div_gan_data --div_gan_data_ratio 0.5 --proxy_data --proxy_data_ratio 0.5 --from_scratch --name from_scratch_div_gan_05_synthetic_grey_imgs --val_data_dcgan ./val_data/dcgan_val_data_synthetic_alexnet_grey.pkl --teacher_path ./teacher_models/cifar10_alexnet_teacher_79.pth --proxy_dataset synthetic --true_dataset cifar10 --num_classes 10 --synthetic_dir ./synthetic_dataset/50k_samples/ --total_synth_samples 50000 --grey_scale 20 | 21 | 22 | CUDA_VISIBLE_DEVICES=0 python ./code/train_student/generate_val_data.py --div_gan_netG ./cifar100_run_models/alexnet/out_step2_0_10/synthetic_grey/2/netG_epoch_99.pth --network alexnet --div_gan_out
./val_data/degan_val_data_synthetic_alexnet_grey.pkl --teacher_path ./teacher_models/cifar10_alexnet_teacher_79.pth 23 | 24 | # ./train_student/checkpoint/alexnet/cifar10_synthetic_div_gan/2/last_epoch_best_student_model_from_scratch_div_gan_05_synthetic_grey_imgs.pth 25 | 26 | 27 | CUDA_VISIBLE_DEVICES=0 python ./code/train_generator/train_generator_clone.py --dataroot ./data/ --dataset synthetic --imageSize 32 --cuda --outf ./cifar100_run_models/alexnet/out_step2/synthetic_grey/0_500/ --manualSeed 108 --niter 150 --batchSize 64 --netG ./cifar100_run_models/alexnet/out_step2_0_10/synthetic_grey/2/netG_epoch_99.pth --student_path ./train_student/checkpoint/alexnet/cifar10_synthetic_div_gan/2/last_epoch_best_student_model_from_scratch_div_gan_05_synthetic_grey_imgs.pth --network alexnet --teacher_path ./teacher_models/cifar10_alexnet_teacher_79.pth --warmup --name warmup_altr_0_10_p10 --auto-augment --c_l 0 --d_l 500 --val_data_dcgan ./val_data/dcgan_val_data_synthetic_alexnet_grey.pkl --val_data_degan ./val_data/degan_val_data_synthetic_alexnet_grey.pkl --true_dataset cifar10 --num_classes 10 --synthetic_dir ./synthetic_dataset/50k_samples/ --total_synth_samples 50000 --grey_scale -------------------------------------------------------------------------------- /run_synthetic_resnet.sh: -------------------------------------------------------------------------------- 1 | # train dcgan for synthetic data (ResNet teacher; output kept separate from the AlexNet run) 2 | CUDA_VISIBLE_DEVICES=0 python ./code/train_generator/dcgan.py --dataroot ./data/ --dataset synthetic --imageSize 32 --cuda --outf ./cifar100_run_models/resnet/synthetic_grey/dcgan/ --manualSeed 108 --niter 200 --batchSize 64 --network resnet --teacher_path ./teacher_models/cifar10_resnet18_teacher_93.pth --num_classes 10 --synthetic_dir ./synthetic_dataset/50k_samples/ --grey_scale --total_synth_samples 50000 3 | 4 | mkdir -p val_data 5 | 6 | CUDA_VISIBLE_DEVICES=0 python ./code/train_student/generate_val_data.py --dcgan_netG
./cifar100_run_models/resnet/synthetic_grey/dcgan/netG_epoch_199.pth --dcgan_out ./val_data/dcgan_val_data_synthetic_resnet_grey.pkl --teacher_path ./teacher_models/cifar10_resnet18_teacher_93.pth --network resnet 7 | 8 | 9 | # Hard label runs 10 | 11 | CUDA_VISIBLE_DEVICES=0 python ./code/train_student/train_student.py --from_scratch --dcgan_netG ./cifar100_run_models/resnet/synthetic_grey/dcgan/netG_epoch_199.pth --count 1 --dcgan_data --dcgan_data_ratio 0.5 --proxy_data --proxy_data_ratio 0.5 --pad_crop --name proxy_dcgan_45k_synthetic_05_grey --val_data_dcgan ./val_data/dcgan_val_data_synthetic_resnet_grey.pkl --teacher_path ./teacher_models/cifar10_resnet18_teacher_93.pth --network resnet --proxy_dataset synthetic --true_dataset cifar10 --num_classes 10 --synthetic_dir ./synthetic_dataset/50k_samples/ --grey_scale --total_synth_samples 50000 12 | 13 | 14 | # ./train_student/checkpoint/resnet/cifar10_synthetic/1/last_epoch_best_student_model_proxy_dcgan_45k_synthetic_05_grey.pth 15 | 16 | 17 | # Run div_gan 18 | 19 | CUDA_VISIBLE_DEVICES=0 python ./code/train_generator/train_gen.py --dataroot ./data/ --dataset synthetic --imageSize 32 --cuda --outf ./cifar100_run_models/resnet/out_step2_0_10/synthetic_grey/2/ --manualSeed 108 --niter 100 --batchSize 64 --netG ./cifar100_run_models/resnet/synthetic_grey/dcgan/netG_epoch_199.pth --student_path ./train_student/checkpoint/resnet/cifar10_synthetic/1/last_epoch_best_student_model_proxy_dcgan_45k_synthetic_05_grey.pth --network resnet --c_l 0 --d_l 10 --true_dataset cifar10 --num_classes 10 --synthetic_dir ./synthetic_dataset/50k_samples/ --grey_scale --total_synth_samples 50000 20 | 21 | # DivGAN + Random Crop (student trained from scratch on 0.5*Proxy + 0.5*DivGAN ) 22 | CUDA_VISIBLE_DEVICES=0 python ./code/train_student/train_student.py --div_gan_netG ./cifar100_run_models/resnet/out_step2_0_10/synthetic_grey/2/netG_epoch_99.pth --dcgan_netG ./cifar100_run_models/resnet/synthetic_grey/dcgan/netG_epoch_199.pth
--count 2 --network resnet --pad_crop --div_gan_data --div_gan_data_ratio 0.5 --proxy_data --proxy_data_ratio 0.5 --from_scratch --name from_scratch_div_gan_05_synthetic_grey --val_data_dcgan ./val_data/dcgan_val_data_synthetic_resnet_grey.pkl --teacher_path ./teacher_models/cifar10_resnet18_teacher_93.pth --proxy_dataset synthetic --true_dataset cifar10 --num_classes 10 --synthetic_dir ./synthetic_dataset/50k_samples/ --grey_scale --total_synth_samples 50000 23 | 24 | 25 | CUDA_VISIBLE_DEVICES=0 python ./code/train_student/generate_val_data.py --div_gan_netG ./cifar100_run_models/resnet/out_step2_0_10/synthetic_grey/2/netG_epoch_99.pth --network resnet --div_gan_out ./val_data/div_gan_val_data_synthetic_resnet_grey.pkl --teacher_path ./teacher_models/cifar10_resnet18_teacher_93.pth 26 | 27 | 28 | # ./train_student/checkpoint/resnet/cifar10_synthetic_div_gan/2/last_epoch_best_student_model_from_scratch_div_gan_05_synthetic_grey.pth 29 | 30 | # Alternate training 31 | 32 | CUDA_VISIBLE_DEVICES=0 python ./code/train_generator/train_generator_clone.py --dataroot ./data/ --dataset synthetic --imageSize 32 --cuda --outf ./cifar100_run_models/resnet/out_step2/synthetic_grey/0_500/ --manualSeed 108 --niter 150 --batchSize 64 --netG ./cifar100_run_models/resnet/out_step2_0_10/synthetic_grey/2/netG_epoch_99.pth --student_path ./train_student/checkpoint/resnet/cifar10_synthetic_div_gan/2/last_epoch_best_student_model_from_scratch_div_gan_05_synthetic_grey.pth --network resnet --teacher_path ./teacher_models/cifar10_resnet18_teacher_93.pth --warmup --name warmup_altr_0_10_p10 --auto-augment --c_l 0 --d_l 500 --val_data_dcgan ./val_data/dcgan_val_data_synthetic_resnet_grey.pkl --val_data_degan ./val_data/div_gan_val_data_synthetic_resnet_grey.pkl --true_dataset cifar10 --num_classes 10 --synthetic_dir ./synthetic_dataset/50k_samples/ --grey_scale --total_synth_samples 50000 -------------------------------------------------------------------------------- /models/senet.py:
-------------------------------------------------------------------------------- 1 | '''SENet in PyTorch. 2 | 3 | SENet is the winner of ImageNet-2017. The paper is not released yet. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | def __init__(self, in_planes, planes, stride=1): 12 | super(BasicBlock, self).__init__() 13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(planes) 15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(planes) 17 | 18 | self.shortcut = nn.Sequential() 19 | if stride != 1 or in_planes != planes: 20 | self.shortcut = nn.Sequential( 21 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), 22 | nn.BatchNorm2d(planes) 23 | ) 24 | 25 | # SE layers 26 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) # Use nn.Conv2d instead of nn.Linear 27 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = self.bn2(self.conv2(out)) 32 | 33 | # Squeeze 34 | w = F.avg_pool2d(out, out.size(2)) 35 | w = F.relu(self.fc1(w)) 36 | w = F.sigmoid(self.fc2(w)) 37 | # Excitation 38 | out = out * w # New broadcasting feature from v0.2! 
39 | 40 | out += self.shortcut(x) 41 | out = F.relu(out) 42 | return out 43 | 44 | 45 | class PreActBlock(nn.Module): 46 | def __init__(self, in_planes, planes, stride=1): 47 | super(PreActBlock, self).__init__() 48 | self.bn1 = nn.BatchNorm2d(in_planes) 49 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 50 | self.bn2 = nn.BatchNorm2d(planes) 51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 52 | 53 | if stride != 1 or in_planes != planes: 54 | self.shortcut = nn.Sequential( 55 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False) 56 | ) 57 | 58 | # SE layers 59 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) 60 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) 61 | 62 | def forward(self, x): 63 | out = F.relu(self.bn1(x)) 64 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 65 | out = self.conv1(out) 66 | out = self.conv2(F.relu(self.bn2(out))) 67 | 68 | # Squeeze 69 | w = F.avg_pool2d(out, out.size(2)) 70 | w = F.relu(self.fc1(w)) 71 | w = F.sigmoid(self.fc2(w)) 72 | # Excitation 73 | out = out * w 74 | 75 | out += shortcut 76 | return out 77 | 78 | 79 | class SENet(nn.Module): 80 | def __init__(self, block, num_blocks, num_classes=10): 81 | super(SENet, self).__init__() 82 | self.in_planes = 64 83 | 84 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 85 | self.bn1 = nn.BatchNorm2d(64) 86 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 87 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 88 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 89 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 90 | self.linear = nn.Linear(512, num_classes) 91 | 92 | def _make_layer(self, block, planes, num_blocks, stride): 93 | strides = [stride] + [1]*(num_blocks-1) 94 | layers = [] 95 | for stride in strides: 96 | 
layers.append(block(self.in_planes, planes, stride)) 97 | self.in_planes = planes 98 | return nn.Sequential(*layers) 99 | 100 | def forward(self, x): 101 | out = F.relu(self.bn1(self.conv1(x))) 102 | out = self.layer1(out) 103 | out = self.layer2(out) 104 | out = self.layer3(out) 105 | out = self.layer4(out) 106 | out = F.avg_pool2d(out, 4) 107 | out = out.view(out.size(0), -1) 108 | out = self.linear(out) 109 | return out 110 | 111 | 112 | def SENet18(): 113 | return SENet(PreActBlock, [2,2,2,2]) 114 | 115 | 116 | def test(): 117 | net = SENet18() 118 | y = net(torch.randn(1,3,32,32)) 119 | print(y.size()) 120 | 121 | # test() 122 | -------------------------------------------------------------------------------- /code/train_generator/models/senet.py: -------------------------------------------------------------------------------- 1 | '''SENet in PyTorch. 2 | 3 | SENet is the winner of ImageNet-2017. The paper is not released yet. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | def __init__(self, in_planes, planes, stride=1): 12 | super(BasicBlock, self).__init__() 13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(planes) 15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(planes) 17 | 18 | self.shortcut = nn.Sequential() 19 | if stride != 1 or in_planes != planes: 20 | self.shortcut = nn.Sequential( 21 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), 22 | nn.BatchNorm2d(planes) 23 | ) 24 | 25 | # SE layers 26 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) # Use nn.Conv2d instead of nn.Linear 27 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = self.bn2(self.conv2(out)) 32 | 33 | # Squeeze 34 | w = 
F.avg_pool2d(out, out.size(2)) 35 | w = F.relu(self.fc1(w)) 36 | w = F.sigmoid(self.fc2(w)) 37 | # Excitation 38 | out = out * w # New broadcasting feature from v0.2! 39 | 40 | out += self.shortcut(x) 41 | out = F.relu(out) 42 | return out 43 | 44 | 45 | class PreActBlock(nn.Module): 46 | def __init__(self, in_planes, planes, stride=1): 47 | super(PreActBlock, self).__init__() 48 | self.bn1 = nn.BatchNorm2d(in_planes) 49 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 50 | self.bn2 = nn.BatchNorm2d(planes) 51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 52 | 53 | if stride != 1 or in_planes != planes: 54 | self.shortcut = nn.Sequential( 55 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False) 56 | ) 57 | 58 | # SE layers 59 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) 60 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) 61 | 62 | def forward(self, x): 63 | out = F.relu(self.bn1(x)) 64 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 65 | out = self.conv1(out) 66 | out = self.conv2(F.relu(self.bn2(out))) 67 | 68 | # Squeeze 69 | w = F.avg_pool2d(out, out.size(2)) 70 | w = F.relu(self.fc1(w)) 71 | w = F.sigmoid(self.fc2(w)) 72 | # Excitation 73 | out = out * w 74 | 75 | out += shortcut 76 | return out 77 | 78 | 79 | class SENet(nn.Module): 80 | def __init__(self, block, num_blocks, num_classes=10): 81 | super(SENet, self).__init__() 82 | self.in_planes = 64 83 | 84 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 85 | self.bn1 = nn.BatchNorm2d(64) 86 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 87 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 88 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 89 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 90 | self.linear = nn.Linear(512, num_classes) 91 | 92 
| def _make_layer(self, block, planes, num_blocks, stride): 93 | strides = [stride] + [1]*(num_blocks-1) 94 | layers = [] 95 | for stride in strides: 96 | layers.append(block(self.in_planes, planes, stride)) 97 | self.in_planes = planes 98 | return nn.Sequential(*layers) 99 | 100 | def forward(self, x): 101 | out = F.relu(self.bn1(self.conv1(x))) 102 | out = self.layer1(out) 103 | out = self.layer2(out) 104 | out = self.layer3(out) 105 | out = self.layer4(out) 106 | out = F.avg_pool2d(out, 4) 107 | out = out.view(out.size(0), -1) 108 | out = self.linear(out) 109 | return out 110 | 111 | 112 | def SENet18(): 113 | return SENet(PreActBlock, [2,2,2,2]) 114 | 115 | 116 | def test(): 117 | net = SENet18() 118 | y = net(torch.randn(1,3,32,32)) 119 | print(y.size()) 120 | 121 | # test() 122 | -------------------------------------------------------------------------------- /code/train_student/models/senet.py: -------------------------------------------------------------------------------- 1 | '''SENet in PyTorch. 2 | 3 | SENet is the winner of ImageNet-2017. The paper is not released yet. 
4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | def __init__(self, in_planes, planes, stride=1): 12 | super(BasicBlock, self).__init__() 13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(planes) 15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(planes) 17 | 18 | self.shortcut = nn.Sequential() 19 | if stride != 1 or in_planes != planes: 20 | self.shortcut = nn.Sequential( 21 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), 22 | nn.BatchNorm2d(planes) 23 | ) 24 | 25 | # SE layers 26 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) # Use nn.Conv2d instead of nn.Linear 27 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = self.bn2(self.conv2(out)) 32 | 33 | # Squeeze 34 | w = F.avg_pool2d(out, out.size(2)) 35 | w = F.relu(self.fc1(w)) 36 | w = F.sigmoid(self.fc2(w)) 37 | # Excitation 38 | out = out * w # New broadcasting feature from v0.2! 
39 | 40 | out += self.shortcut(x) 41 | out = F.relu(out) 42 | return out 43 | 44 | 45 | class PreActBlock(nn.Module): 46 | def __init__(self, in_planes, planes, stride=1): 47 | super(PreActBlock, self).__init__() 48 | self.bn1 = nn.BatchNorm2d(in_planes) 49 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 50 | self.bn2 = nn.BatchNorm2d(planes) 51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 52 | 53 | if stride != 1 or in_planes != planes: 54 | self.shortcut = nn.Sequential( 55 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False) 56 | ) 57 | 58 | # SE layers 59 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) 60 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) 61 | 62 | def forward(self, x): 63 | out = F.relu(self.bn1(x)) 64 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 65 | out = self.conv1(out) 66 | out = self.conv2(F.relu(self.bn2(out))) 67 | 68 | # Squeeze 69 | w = F.avg_pool2d(out, out.size(2)) 70 | w = F.relu(self.fc1(w)) 71 | w = F.sigmoid(self.fc2(w)) 72 | # Excitation 73 | out = out * w 74 | 75 | out += shortcut 76 | return out 77 | 78 | 79 | class SENet(nn.Module): 80 | def __init__(self, block, num_blocks, num_classes=10): 81 | super(SENet, self).__init__() 82 | self.in_planes = 64 83 | 84 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 85 | self.bn1 = nn.BatchNorm2d(64) 86 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 87 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 88 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 89 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 90 | self.linear = nn.Linear(512, num_classes) 91 | 92 | def _make_layer(self, block, planes, num_blocks, stride): 93 | strides = [stride] + [1]*(num_blocks-1) 94 | layers = [] 95 | for stride in strides: 96 | 
layers.append(block(self.in_planes, planes, stride)) 97 | self.in_planes = planes 98 | return nn.Sequential(*layers) 99 | 100 | def forward(self, x): 101 | out = F.relu(self.bn1(self.conv1(x))) 102 | out = self.layer1(out) 103 | out = self.layer2(out) 104 | out = self.layer3(out) 105 | out = self.layer4(out) 106 | out = F.avg_pool2d(out, 4) 107 | out = out.view(out.size(0), -1) 108 | out = self.linear(out) 109 | return out 110 | 111 | 112 | def SENet18(): 113 | return SENet(PreActBlock, [2,2,2,2]) 114 | 115 | 116 | def test(): 117 | net = SENet18() 118 | y = net(torch.randn(1,3,32,32)) 119 | print(y.size()) 120 | 121 | # test() 122 | -------------------------------------------------------------------------------- /models/preact_resnet.py: -------------------------------------------------------------------------------- 1 | '''Pre-activation ResNet in PyTorch. 2 | 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Identity Mappings in Deep Residual Networks. arXiv:1603.05027 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class PreActBlock(nn.Module): 13 | '''Pre-activation version of the BasicBlock.''' 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1): 17 | super(PreActBlock, self).__init__() 18 | self.bn1 = nn.BatchNorm2d(in_planes) 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 26 | ) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.bn1(x)) 30 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 31 | out = self.conv1(out) 32 | out = self.conv2(F.relu(self.bn2(out))) 33 | 
out += shortcut 34 | return out 35 | 36 | 37 | class PreActBottleneck(nn.Module): 38 | '''Pre-activation version of the original Bottleneck module.''' 39 | expansion = 4 40 | 41 | def __init__(self, in_planes, planes, stride=1): 42 | super(PreActBottleneck, self).__init__() 43 | self.bn1 = nn.BatchNorm2d(in_planes) 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(planes) 48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 49 | 50 | if stride != 1 or in_planes != self.expansion*planes: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 53 | ) 54 | 55 | def forward(self, x): 56 | out = F.relu(self.bn1(x)) 57 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 58 | out = self.conv1(out) 59 | out = self.conv2(F.relu(self.bn2(out))) 60 | out = self.conv3(F.relu(self.bn3(out))) 61 | out += shortcut 62 | return out 63 | 64 | 65 | class PreActResNet(nn.Module): 66 | def __init__(self, block, num_blocks, num_classes=10): 67 | super(PreActResNet, self).__init__() 68 | self.in_planes = 64 69 | 70 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 72 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 73 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 74 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 75 | self.linear = nn.Linear(512*block.expansion, num_classes) 76 | 77 | def _make_layer(self, block, planes, num_blocks, stride): 78 | strides = [stride] + [1]*(num_blocks-1) 79 | layers = [] 80 | for stride in strides: 81 | layers.append(block(self.in_planes, planes, stride)) 82 | self.in_planes = planes * 
block.expansion 83 | return nn.Sequential(*layers) 84 | 85 | def forward(self, x): 86 | out = self.conv1(x) 87 | out = self.layer1(out) 88 | out = self.layer2(out) 89 | out = self.layer3(out) 90 | out = self.layer4(out) 91 | out = F.avg_pool2d(out, 4) 92 | out = out.view(out.size(0), -1) 93 | out = self.linear(out) 94 | return out 95 | 96 | 97 | def PreActResNet18(): 98 | return PreActResNet(PreActBlock, [2,2,2,2]) 99 | 100 | def PreActResNet34(): 101 | return PreActResNet(PreActBlock, [3,4,6,3]) 102 | 103 | def PreActResNet50(): 104 | return PreActResNet(PreActBottleneck, [3,4,6,3]) 105 | 106 | def PreActResNet101(): 107 | return PreActResNet(PreActBottleneck, [3,4,23,3]) 108 | 109 | def PreActResNet152(): 110 | return PreActResNet(PreActBottleneck, [3,8,36,3]) 111 | 112 | 113 | def test(): 114 | net = PreActResNet18() 115 | y = net((torch.randn(1,3,32,32))) 116 | print(y.size()) 117 | 118 | # test() 119 | -------------------------------------------------------------------------------- /code/train_generator/models/preact_resnet.py: -------------------------------------------------------------------------------- 1 | '''Pre-activation ResNet in PyTorch. 2 | 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Identity Mappings in Deep Residual Networks. 
arXiv:1603.05027 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class PreActBlock(nn.Module): 13 | '''Pre-activation version of the BasicBlock.''' 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1): 17 | super(PreActBlock, self).__init__() 18 | self.bn1 = nn.BatchNorm2d(in_planes) 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 26 | ) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.bn1(x)) 30 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 31 | out = self.conv1(out) 32 | out = self.conv2(F.relu(self.bn2(out))) 33 | out += shortcut 34 | return out 35 | 36 | 37 | class PreActBottleneck(nn.Module): 38 | '''Pre-activation version of the original Bottleneck module.''' 39 | expansion = 4 40 | 41 | def __init__(self, in_planes, planes, stride=1): 42 | super(PreActBottleneck, self).__init__() 43 | self.bn1 = nn.BatchNorm2d(in_planes) 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(planes) 48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 49 | 50 | if stride != 1 or in_planes != self.expansion*planes: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 53 | ) 54 | 55 | def forward(self, x): 56 | out = F.relu(self.bn1(x)) 57 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 58 | out = 
self.conv1(out) 59 | out = self.conv2(F.relu(self.bn2(out))) 60 | out = self.conv3(F.relu(self.bn3(out))) 61 | out += shortcut 62 | return out 63 | 64 | 65 | class PreActResNet(nn.Module): 66 | def __init__(self, block, num_blocks, num_classes=10): 67 | super(PreActResNet, self).__init__() 68 | self.in_planes = 64 69 | 70 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 72 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 73 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 74 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 75 | self.linear = nn.Linear(512*block.expansion, num_classes) 76 | 77 | def _make_layer(self, block, planes, num_blocks, stride): 78 | strides = [stride] + [1]*(num_blocks-1) 79 | layers = [] 80 | for stride in strides: 81 | layers.append(block(self.in_planes, planes, stride)) 82 | self.in_planes = planes * block.expansion 83 | return nn.Sequential(*layers) 84 | 85 | def forward(self, x): 86 | out = self.conv1(x) 87 | out = self.layer1(out) 88 | out = self.layer2(out) 89 | out = self.layer3(out) 90 | out = self.layer4(out) 91 | out = F.avg_pool2d(out, 4) 92 | out = out.view(out.size(0), -1) 93 | out = self.linear(out) 94 | return out 95 | 96 | 97 | def PreActResNet18(): 98 | return PreActResNet(PreActBlock, [2,2,2,2]) 99 | 100 | def PreActResNet34(): 101 | return PreActResNet(PreActBlock, [3,4,6,3]) 102 | 103 | def PreActResNet50(): 104 | return PreActResNet(PreActBottleneck, [3,4,6,3]) 105 | 106 | def PreActResNet101(): 107 | return PreActResNet(PreActBottleneck, [3,4,23,3]) 108 | 109 | def PreActResNet152(): 110 | return PreActResNet(PreActBottleneck, [3,8,36,3]) 111 | 112 | 113 | def test(): 114 | net = PreActResNet18() 115 | y = net((torch.randn(1,3,32,32))) 116 | print(y.size()) 117 | 118 | # test() 119 | 
-------------------------------------------------------------------------------- /code/train_student/models/preact_resnet.py: -------------------------------------------------------------------------------- 1 | '''Pre-activation ResNet in PyTorch. 2 | 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Identity Mappings in Deep Residual Networks. arXiv:1603.05027 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class PreActBlock(nn.Module): 13 | '''Pre-activation version of the BasicBlock.''' 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1): 17 | super(PreActBlock, self).__init__() 18 | self.bn1 = nn.BatchNorm2d(in_planes) 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 26 | ) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.bn1(x)) 30 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 31 | out = self.conv1(out) 32 | out = self.conv2(F.relu(self.bn2(out))) 33 | out += shortcut 34 | return out 35 | 36 | 37 | class PreActBottleneck(nn.Module): 38 | '''Pre-activation version of the original Bottleneck module.''' 39 | expansion = 4 40 | 41 | def __init__(self, in_planes, planes, stride=1): 42 | super(PreActBottleneck, self).__init__() 43 | self.bn1 = nn.BatchNorm2d(in_planes) 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(planes) 48 | self.conv3 = nn.Conv2d(planes, 
self.expansion*planes, kernel_size=1, bias=False) 49 | 50 | if stride != 1 or in_planes != self.expansion*planes: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 53 | ) 54 | 55 | def forward(self, x): 56 | out = F.relu(self.bn1(x)) 57 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 58 | out = self.conv1(out) 59 | out = self.conv2(F.relu(self.bn2(out))) 60 | out = self.conv3(F.relu(self.bn3(out))) 61 | out += shortcut 62 | return out 63 | 64 | 65 | class PreActResNet(nn.Module): 66 | def __init__(self, block, num_blocks, num_classes=10): 67 | super(PreActResNet, self).__init__() 68 | self.in_planes = 64 69 | 70 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 72 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 73 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 74 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 75 | self.linear = nn.Linear(512*block.expansion, num_classes) 76 | 77 | def _make_layer(self, block, planes, num_blocks, stride): 78 | strides = [stride] + [1]*(num_blocks-1) 79 | layers = [] 80 | for stride in strides: 81 | layers.append(block(self.in_planes, planes, stride)) 82 | self.in_planes = planes * block.expansion 83 | return nn.Sequential(*layers) 84 | 85 | def forward(self, x): 86 | out = self.conv1(x) 87 | out = self.layer1(out) 88 | out = self.layer2(out) 89 | out = self.layer3(out) 90 | out = self.layer4(out) 91 | out = F.avg_pool2d(out, 4) 92 | out = out.view(out.size(0), -1) 93 | out = self.linear(out) 94 | return out 95 | 96 | 97 | def PreActResNet18(): 98 | return PreActResNet(PreActBlock, [2,2,2,2]) 99 | 100 | def PreActResNet34(): 101 | return PreActResNet(PreActBlock, [3,4,6,3]) 102 | 103 | def PreActResNet50(): 104 | return PreActResNet(PreActBottleneck, [3,4,6,3]) 105 | 106 
| def PreActResNet101(): 107 | return PreActResNet(PreActBottleneck, [3,4,23,3]) 108 | 109 | def PreActResNet152(): 110 | return PreActResNet(PreActBottleneck, [3,8,36,3]) 111 | 112 | 113 | def test(): 114 | net = PreActResNet18() 115 | y = net((torch.randn(1,3,32,32))) 116 | print(y.size()) 117 | 118 | # test() 119 | -------------------------------------------------------------------------------- /models/dla_simple.py: -------------------------------------------------------------------------------- 1 | '''Simplified version of DLA in PyTorch. 2 | 3 | Note this implementation is not identical to the original paper version. 4 | But it seems works fine. 5 | 6 | See dla.py for the original paper version. 7 | 8 | Reference: 9 | Deep Layer Aggregation. https://arxiv.org/abs/1707.06484 10 | ''' 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, in_planes, planes, stride=1): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = nn.Conv2d( 22 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 25 | stride=1, padding=1, bias=False) 26 | self.bn2 = nn.BatchNorm2d(planes) 27 | 28 | self.shortcut = nn.Sequential() 29 | if stride != 1 or in_planes != self.expansion*planes: 30 | self.shortcut = nn.Sequential( 31 | nn.Conv2d(in_planes, self.expansion*planes, 32 | kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(self.expansion*planes) 34 | ) 35 | 36 | def forward(self, x): 37 | out = F.relu(self.bn1(self.conv1(x))) 38 | out = self.bn2(self.conv2(out)) 39 | out += self.shortcut(x) 40 | out = F.relu(out) 41 | return out 42 | 43 | 44 | class Root(nn.Module): 45 | def __init__(self, in_channels, out_channels, kernel_size=1): 46 | super(Root, self).__init__() 47 | self.conv = nn.Conv2d( 48 | in_channels, out_channels, 
kernel_size, 49 | stride=1, padding=(kernel_size - 1) // 2, bias=False) 50 | self.bn = nn.BatchNorm2d(out_channels) 51 | 52 | def forward(self, xs): 53 | x = torch.cat(xs, 1) 54 | out = F.relu(self.bn(self.conv(x))) 55 | return out 56 | 57 | 58 | class Tree(nn.Module): 59 | def __init__(self, block, in_channels, out_channels, level=1, stride=1): 60 | super(Tree, self).__init__() 61 | self.root = Root(2*out_channels, out_channels) 62 | if level == 1: 63 | self.left_tree = block(in_channels, out_channels, stride=stride) 64 | self.right_tree = block(out_channels, out_channels, stride=1) 65 | else: 66 | self.left_tree = Tree(block, in_channels, 67 | out_channels, level=level-1, stride=stride) 68 | self.right_tree = Tree(block, out_channels, 69 | out_channels, level=level-1, stride=1) 70 | 71 | def forward(self, x): 72 | out1 = self.left_tree(x) 73 | out2 = self.right_tree(out1) 74 | out = self.root([out1, out2]) 75 | return out 76 | 77 | 78 | class SimpleDLA(nn.Module): 79 | def __init__(self, block=BasicBlock, num_classes=10): 80 | super(SimpleDLA, self).__init__() 81 | self.base = nn.Sequential( 82 | nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False), 83 | nn.BatchNorm2d(16), 84 | nn.ReLU(True) 85 | ) 86 | 87 | self.layer1 = nn.Sequential( 88 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False), 89 | nn.BatchNorm2d(16), 90 | nn.ReLU(True) 91 | ) 92 | 93 | self.layer2 = nn.Sequential( 94 | nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False), 95 | nn.BatchNorm2d(32), 96 | nn.ReLU(True) 97 | ) 98 | 99 | self.layer3 = Tree(block, 32, 64, level=1, stride=1) 100 | self.layer4 = Tree(block, 64, 128, level=2, stride=2) 101 | self.layer5 = Tree(block, 128, 256, level=2, stride=2) 102 | self.layer6 = Tree(block, 256, 512, level=1, stride=2) 103 | self.linear = nn.Linear(512, num_classes) 104 | 105 | def forward(self, x): 106 | out = self.base(x) 107 | out = self.layer1(out) 108 | out = self.layer2(out) 109 | out = self.layer3(out) 110 
| out = self.layer4(out) 111 | out = self.layer5(out) 112 | out = self.layer6(out) 113 | out = F.avg_pool2d(out, 4) 114 | out = out.view(out.size(0), -1) 115 | out = self.linear(out) 116 | return out 117 | 118 | 119 | def test(): 120 | net = SimpleDLA() 121 | print(net) 122 | x = torch.randn(1, 3, 32, 32) 123 | y = net(x) 124 | print(y.size()) 125 | 126 | 127 | if __name__ == '__main__': 128 | test() 129 | -------------------------------------------------------------------------------- /code/train_student/models/dla_simple.py: -------------------------------------------------------------------------------- 1 | '''Simplified version of DLA in PyTorch. 2 | 3 | Note this implementation is not identical to the original paper version. 4 | But it seems works fine. 5 | 6 | See dla.py for the original paper version. 7 | 8 | Reference: 9 | Deep Layer Aggregation. https://arxiv.org/abs/1707.06484 10 | ''' 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, in_planes, planes, stride=1): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = nn.Conv2d( 22 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 25 | stride=1, padding=1, bias=False) 26 | self.bn2 = nn.BatchNorm2d(planes) 27 | 28 | self.shortcut = nn.Sequential() 29 | if stride != 1 or in_planes != self.expansion*planes: 30 | self.shortcut = nn.Sequential( 31 | nn.Conv2d(in_planes, self.expansion*planes, 32 | kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(self.expansion*planes) 34 | ) 35 | 36 | def forward(self, x): 37 | out = F.relu(self.bn1(self.conv1(x))) 38 | out = self.bn2(self.conv2(out)) 39 | out += self.shortcut(x) 40 | out = F.relu(out) 41 | return out 42 | 43 | 44 | class Root(nn.Module): 45 | def __init__(self, in_channels, out_channels, 
kernel_size=1): 46 | super(Root, self).__init__() 47 | self.conv = nn.Conv2d( 48 | in_channels, out_channels, kernel_size, 49 | stride=1, padding=(kernel_size - 1) // 2, bias=False) 50 | self.bn = nn.BatchNorm2d(out_channels) 51 | 52 | def forward(self, xs): 53 | x = torch.cat(xs, 1) 54 | out = F.relu(self.bn(self.conv(x))) 55 | return out 56 | 57 | 58 | class Tree(nn.Module): 59 | def __init__(self, block, in_channels, out_channels, level=1, stride=1): 60 | super(Tree, self).__init__() 61 | self.root = Root(2*out_channels, out_channels) 62 | if level == 1: 63 | self.left_tree = block(in_channels, out_channels, stride=stride) 64 | self.right_tree = block(out_channels, out_channels, stride=1) 65 | else: 66 | self.left_tree = Tree(block, in_channels, 67 | out_channels, level=level-1, stride=stride) 68 | self.right_tree = Tree(block, out_channels, 69 | out_channels, level=level-1, stride=1) 70 | 71 | def forward(self, x): 72 | out1 = self.left_tree(x) 73 | out2 = self.right_tree(out1) 74 | out = self.root([out1, out2]) 75 | return out 76 | 77 | 78 | class SimpleDLA(nn.Module): 79 | def __init__(self, block=BasicBlock, num_classes=10): 80 | super(SimpleDLA, self).__init__() 81 | self.base = nn.Sequential( 82 | nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False), 83 | nn.BatchNorm2d(16), 84 | nn.ReLU(True) 85 | ) 86 | 87 | self.layer1 = nn.Sequential( 88 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False), 89 | nn.BatchNorm2d(16), 90 | nn.ReLU(True) 91 | ) 92 | 93 | self.layer2 = nn.Sequential( 94 | nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False), 95 | nn.BatchNorm2d(32), 96 | nn.ReLU(True) 97 | ) 98 | 99 | self.layer3 = Tree(block, 32, 64, level=1, stride=1) 100 | self.layer4 = Tree(block, 64, 128, level=2, stride=2) 101 | self.layer5 = Tree(block, 128, 256, level=2, stride=2) 102 | self.layer6 = Tree(block, 256, 512, level=1, stride=2) 103 | self.linear = nn.Linear(512, num_classes) 104 | 105 | def forward(self, x): 106 | 
out = self.base(x) 107 | out = self.layer1(out) 108 | out = self.layer2(out) 109 | out = self.layer3(out) 110 | out = self.layer4(out) 111 | out = self.layer5(out) 112 | out = self.layer6(out) 113 | out = F.avg_pool2d(out, 4) 114 | out = out.view(out.size(0), -1) 115 | out = self.linear(out) 116 | return out 117 | 118 | 119 | def test(): 120 | net = SimpleDLA() 121 | print(net) 122 | x = torch.randn(1, 3, 32, 32) 123 | y = net(x) 124 | print(y.size()) 125 | 126 | 127 | if __name__ == '__main__': 128 | test() 129 | -------------------------------------------------------------------------------- /code/train_generator/models/dla_simple.py: -------------------------------------------------------------------------------- 1 | '''Simplified version of DLA in PyTorch. 2 | 3 | Note this implementation is not identical to the original paper version. 4 | But it seems works fine. 5 | 6 | See dla.py for the original paper version. 7 | 8 | Reference: 9 | Deep Layer Aggregation. https://arxiv.org/abs/1707.06484 10 | ''' 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, in_planes, planes, stride=1): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = nn.Conv2d( 22 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 25 | stride=1, padding=1, bias=False) 26 | self.bn2 = nn.BatchNorm2d(planes) 27 | 28 | self.shortcut = nn.Sequential() 29 | if stride != 1 or in_planes != self.expansion*planes: 30 | self.shortcut = nn.Sequential( 31 | nn.Conv2d(in_planes, self.expansion*planes, 32 | kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(self.expansion*planes) 34 | ) 35 | 36 | def forward(self, x): 37 | out = F.relu(self.bn1(self.conv1(x))) 38 | out = self.bn2(self.conv2(out)) 39 | out += self.shortcut(x) 40 | out = F.relu(out) 41 | 
return out 42 | 43 | 44 | class Root(nn.Module): 45 | def __init__(self, in_channels, out_channels, kernel_size=1): 46 | super(Root, self).__init__() 47 | self.conv = nn.Conv2d( 48 | in_channels, out_channels, kernel_size, 49 | stride=1, padding=(kernel_size - 1) // 2, bias=False) 50 | self.bn = nn.BatchNorm2d(out_channels) 51 | 52 | def forward(self, xs): 53 | x = torch.cat(xs, 1) 54 | out = F.relu(self.bn(self.conv(x))) 55 | return out 56 | 57 | 58 | class Tree(nn.Module): 59 | def __init__(self, block, in_channels, out_channels, level=1, stride=1): 60 | super(Tree, self).__init__() 61 | self.root = Root(2*out_channels, out_channels) 62 | if level == 1: 63 | self.left_tree = block(in_channels, out_channels, stride=stride) 64 | self.right_tree = block(out_channels, out_channels, stride=1) 65 | else: 66 | self.left_tree = Tree(block, in_channels, 67 | out_channels, level=level-1, stride=stride) 68 | self.right_tree = Tree(block, out_channels, 69 | out_channels, level=level-1, stride=1) 70 | 71 | def forward(self, x): 72 | out1 = self.left_tree(x) 73 | out2 = self.right_tree(out1) 74 | out = self.root([out1, out2]) 75 | return out 76 | 77 | 78 | class SimpleDLA(nn.Module): 79 | def __init__(self, block=BasicBlock, num_classes=10): 80 | super(SimpleDLA, self).__init__() 81 | self.base = nn.Sequential( 82 | nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False), 83 | nn.BatchNorm2d(16), 84 | nn.ReLU(True) 85 | ) 86 | 87 | self.layer1 = nn.Sequential( 88 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False), 89 | nn.BatchNorm2d(16), 90 | nn.ReLU(True) 91 | ) 92 | 93 | self.layer2 = nn.Sequential( 94 | nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False), 95 | nn.BatchNorm2d(32), 96 | nn.ReLU(True) 97 | ) 98 | 99 | self.layer3 = Tree(block, 32, 64, level=1, stride=1) 100 | self.layer4 = Tree(block, 64, 128, level=2, stride=2) 101 | self.layer5 = Tree(block, 128, 256, level=2, stride=2) 102 | self.layer6 = Tree(block, 256, 512, 
level=1, stride=2) 103 | self.linear = nn.Linear(512, num_classes) 104 | 105 | def forward(self, x): 106 | out = self.base(x) 107 | out = self.layer1(out) 108 | out = self.layer2(out) 109 | out = self.layer3(out) 110 | out = self.layer4(out) 111 | out = self.layer5(out) 112 | out = self.layer6(out) 113 | out = F.avg_pool2d(out, 4) 114 | out = out.view(out.size(0), -1) 115 | out = self.linear(out) 116 | return out 117 | 118 | 119 | def test(): 120 | net = SimpleDLA() 121 | print(net) 122 | x = torch.randn(1, 3, 32, 32) 123 | y = net(x) 124 | print(y.size()) 125 | 126 | 127 | if __name__ == '__main__': 128 | test() 129 | -------------------------------------------------------------------------------- /models/pnasnet.py: -------------------------------------------------------------------------------- 1 | '''PNASNet in PyTorch. 2 | 3 | Paper: Progressive Neural Architecture Search 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class SepConv(nn.Module): 11 | '''Separable Convolution.''' 12 | def __init__(self, in_planes, out_planes, kernel_size, stride): 13 | super(SepConv, self).__init__() 14 | self.conv1 = nn.Conv2d(in_planes, out_planes, 15 | kernel_size, stride, 16 | padding=(kernel_size-1)//2, 17 | bias=False, groups=in_planes) 18 | self.bn1 = nn.BatchNorm2d(out_planes) 19 | 20 | def forward(self, x): 21 | return self.bn1(self.conv1(x)) 22 | 23 | 24 | class CellA(nn.Module): 25 | def __init__(self, in_planes, out_planes, stride=1): 26 | super(CellA, self).__init__() 27 | self.stride = stride 28 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) 29 | if stride==2: 30 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 31 | self.bn1 = nn.BatchNorm2d(out_planes) 32 | 33 | def forward(self, x): 34 | y1 = self.sep_conv1(x) 35 | y2 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) 36 | if self.stride==2: 37 | y2 = 
self.bn1(self.conv1(y2)) 38 | return F.relu(y1+y2) 39 | 40 | class CellB(nn.Module): 41 | def __init__(self, in_planes, out_planes, stride=1): 42 | super(CellB, self).__init__() 43 | self.stride = stride 44 | # Left branch 45 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) 46 | self.sep_conv2 = SepConv(in_planes, out_planes, kernel_size=3, stride=stride) 47 | # Right branch 48 | self.sep_conv3 = SepConv(in_planes, out_planes, kernel_size=5, stride=stride) 49 | if stride==2: 50 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 51 | self.bn1 = nn.BatchNorm2d(out_planes) 52 | # Reduce channels 53 | self.conv2 = nn.Conv2d(2*out_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 54 | self.bn2 = nn.BatchNorm2d(out_planes) 55 | 56 | def forward(self, x): 57 | # Left branch 58 | y1 = self.sep_conv1(x) 59 | y2 = self.sep_conv2(x) 60 | # Right branch 61 | y3 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) 62 | if self.stride==2: 63 | y3 = self.bn1(self.conv1(y3)) 64 | y4 = self.sep_conv3(x) 65 | # Concat & reduce channels 66 | b1 = F.relu(y1+y2) 67 | b2 = F.relu(y3+y4) 68 | y = torch.cat([b1,b2], 1) 69 | return F.relu(self.bn2(self.conv2(y))) 70 | 71 | class PNASNet(nn.Module): 72 | def __init__(self, cell_type, num_cells, num_planes): 73 | super(PNASNet, self).__init__() 74 | self.in_planes = num_planes 75 | self.cell_type = cell_type 76 | 77 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, stride=1, padding=1, bias=False) 78 | self.bn1 = nn.BatchNorm2d(num_planes) 79 | 80 | self.layer1 = self._make_layer(num_planes, num_cells=6) 81 | self.layer2 = self._downsample(num_planes*2) 82 | self.layer3 = self._make_layer(num_planes*2, num_cells=6) 83 | self.layer4 = self._downsample(num_planes*4) 84 | self.layer5 = self._make_layer(num_planes*4, num_cells=6) 85 | 86 | self.linear = nn.Linear(num_planes*4, 10) 87 | 88 | def _make_layer(self, planes, num_cells): 89 | 
layers = [] 90 | for _ in range(num_cells): 91 | layers.append(self.cell_type(self.in_planes, planes, stride=1)) 92 | self.in_planes = planes 93 | return nn.Sequential(*layers) 94 | 95 | def _downsample(self, planes): 96 | layer = self.cell_type(self.in_planes, planes, stride=2) 97 | self.in_planes = planes 98 | return layer 99 | 100 | def forward(self, x): 101 | out = F.relu(self.bn1(self.conv1(x))) 102 | out = self.layer1(out) 103 | out = self.layer2(out) 104 | out = self.layer3(out) 105 | out = self.layer4(out) 106 | out = self.layer5(out) 107 | out = F.avg_pool2d(out, 8) 108 | out = self.linear(out.view(out.size(0), -1)) 109 | return out 110 | 111 | 112 | def PNASNetA(): 113 | return PNASNet(CellA, num_cells=6, num_planes=44) 114 | 115 | def PNASNetB(): 116 | return PNASNet(CellB, num_cells=6, num_planes=32) 117 | 118 | 119 | def test(): 120 | net = PNASNetB() 121 | x = torch.randn(1,3,32,32) 122 | y = net(x) 123 | print(y) 124 | 125 | # test() 126 | -------------------------------------------------------------------------------- /code/train_generator/models/pnasnet.py: -------------------------------------------------------------------------------- 1 | '''PNASNet in PyTorch. 
2 | 3 | Paper: Progressive Neural Architecture Search 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class SepConv(nn.Module): 11 | '''Separable Convolution.''' 12 | def __init__(self, in_planes, out_planes, kernel_size, stride): 13 | super(SepConv, self).__init__() 14 | self.conv1 = nn.Conv2d(in_planes, out_planes, 15 | kernel_size, stride, 16 | padding=(kernel_size-1)//2, 17 | bias=False, groups=in_planes) 18 | self.bn1 = nn.BatchNorm2d(out_planes) 19 | 20 | def forward(self, x): 21 | return self.bn1(self.conv1(x)) 22 | 23 | 24 | class CellA(nn.Module): 25 | def __init__(self, in_planes, out_planes, stride=1): 26 | super(CellA, self).__init__() 27 | self.stride = stride 28 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) 29 | if stride==2: 30 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 31 | self.bn1 = nn.BatchNorm2d(out_planes) 32 | 33 | def forward(self, x): 34 | y1 = self.sep_conv1(x) 35 | y2 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) 36 | if self.stride==2: 37 | y2 = self.bn1(self.conv1(y2)) 38 | return F.relu(y1+y2) 39 | 40 | class CellB(nn.Module): 41 | def __init__(self, in_planes, out_planes, stride=1): 42 | super(CellB, self).__init__() 43 | self.stride = stride 44 | # Left branch 45 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) 46 | self.sep_conv2 = SepConv(in_planes, out_planes, kernel_size=3, stride=stride) 47 | # Right branch 48 | self.sep_conv3 = SepConv(in_planes, out_planes, kernel_size=5, stride=stride) 49 | if stride==2: 50 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 51 | self.bn1 = nn.BatchNorm2d(out_planes) 52 | # Reduce channels 53 | self.conv2 = nn.Conv2d(2*out_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 54 | self.bn2 = nn.BatchNorm2d(out_planes) 55 | 56 | def forward(self, x): 57 
| # Left branch 58 | y1 = self.sep_conv1(x) 59 | y2 = self.sep_conv2(x) 60 | # Right branch 61 | y3 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) 62 | if self.stride==2: 63 | y3 = self.bn1(self.conv1(y3)) 64 | y4 = self.sep_conv3(x) 65 | # Concat & reduce channels 66 | b1 = F.relu(y1+y2) 67 | b2 = F.relu(y3+y4) 68 | y = torch.cat([b1,b2], 1) 69 | return F.relu(self.bn2(self.conv2(y))) 70 | 71 | class PNASNet(nn.Module): 72 | def __init__(self, cell_type, num_cells, num_planes): 73 | super(PNASNet, self).__init__() 74 | self.in_planes = num_planes 75 | self.cell_type = cell_type 76 | 77 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, stride=1, padding=1, bias=False) 78 | self.bn1 = nn.BatchNorm2d(num_planes) 79 | 80 | self.layer1 = self._make_layer(num_planes, num_cells=6) 81 | self.layer2 = self._downsample(num_planes*2) 82 | self.layer3 = self._make_layer(num_planes*2, num_cells=6) 83 | self.layer4 = self._downsample(num_planes*4) 84 | self.layer5 = self._make_layer(num_planes*4, num_cells=6) 85 | 86 | self.linear = nn.Linear(num_planes*4, 10) 87 | 88 | def _make_layer(self, planes, num_cells): 89 | layers = [] 90 | for _ in range(num_cells): 91 | layers.append(self.cell_type(self.in_planes, planes, stride=1)) 92 | self.in_planes = planes 93 | return nn.Sequential(*layers) 94 | 95 | def _downsample(self, planes): 96 | layer = self.cell_type(self.in_planes, planes, stride=2) 97 | self.in_planes = planes 98 | return layer 99 | 100 | def forward(self, x): 101 | out = F.relu(self.bn1(self.conv1(x))) 102 | out = self.layer1(out) 103 | out = self.layer2(out) 104 | out = self.layer3(out) 105 | out = self.layer4(out) 106 | out = self.layer5(out) 107 | out = F.avg_pool2d(out, 8) 108 | out = self.linear(out.view(out.size(0), -1)) 109 | return out 110 | 111 | 112 | def PNASNetA(): 113 | return PNASNet(CellA, num_cells=6, num_planes=44) 114 | 115 | def PNASNetB(): 116 | return PNASNet(CellB, num_cells=6, num_planes=32) 117 | 118 | 119 | def test(): 
120 | net = PNASNetB() 121 | x = torch.randn(1,3,32,32) 122 | y = net(x) 123 | print(y) 124 | 125 | # test() 126 | -------------------------------------------------------------------------------- /code/train_student/models/pnasnet.py: -------------------------------------------------------------------------------- 1 | '''PNASNet in PyTorch. 2 | 3 | Paper: Progressive Neural Architecture Search 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class SepConv(nn.Module): 11 | '''Separable Convolution.''' 12 | def __init__(self, in_planes, out_planes, kernel_size, stride): 13 | super(SepConv, self).__init__() 14 | self.conv1 = nn.Conv2d(in_planes, out_planes, 15 | kernel_size, stride, 16 | padding=(kernel_size-1)//2, 17 | bias=False, groups=in_planes) 18 | self.bn1 = nn.BatchNorm2d(out_planes) 19 | 20 | def forward(self, x): 21 | return self.bn1(self.conv1(x)) 22 | 23 | 24 | class CellA(nn.Module): 25 | def __init__(self, in_planes, out_planes, stride=1): 26 | super(CellA, self).__init__() 27 | self.stride = stride 28 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) 29 | if stride==2: 30 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 31 | self.bn1 = nn.BatchNorm2d(out_planes) 32 | 33 | def forward(self, x): 34 | y1 = self.sep_conv1(x) 35 | y2 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) 36 | if self.stride==2: 37 | y2 = self.bn1(self.conv1(y2)) 38 | return F.relu(y1+y2) 39 | 40 | class CellB(nn.Module): 41 | def __init__(self, in_planes, out_planes, stride=1): 42 | super(CellB, self).__init__() 43 | self.stride = stride 44 | # Left branch 45 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) 46 | self.sep_conv2 = SepConv(in_planes, out_planes, kernel_size=3, stride=stride) 47 | # Right branch 48 | self.sep_conv3 = SepConv(in_planes, out_planes, kernel_size=5, stride=stride) 49 | if stride==2: 
50 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 51 | self.bn1 = nn.BatchNorm2d(out_planes) 52 | # Reduce channels 53 | self.conv2 = nn.Conv2d(2*out_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 54 | self.bn2 = nn.BatchNorm2d(out_planes) 55 | 56 | def forward(self, x): 57 | # Left branch 58 | y1 = self.sep_conv1(x) 59 | y2 = self.sep_conv2(x) 60 | # Right branch 61 | y3 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) 62 | if self.stride==2: 63 | y3 = self.bn1(self.conv1(y3)) 64 | y4 = self.sep_conv3(x) 65 | # Concat & reduce channels 66 | b1 = F.relu(y1+y2) 67 | b2 = F.relu(y3+y4) 68 | y = torch.cat([b1,b2], 1) 69 | return F.relu(self.bn2(self.conv2(y))) 70 | 71 | class PNASNet(nn.Module): 72 | def __init__(self, cell_type, num_cells, num_planes): 73 | super(PNASNet, self).__init__() 74 | self.in_planes = num_planes 75 | self.cell_type = cell_type 76 | 77 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, stride=1, padding=1, bias=False) 78 | self.bn1 = nn.BatchNorm2d(num_planes) 79 | 80 | self.layer1 = self._make_layer(num_planes, num_cells=6) 81 | self.layer2 = self._downsample(num_planes*2) 82 | self.layer3 = self._make_layer(num_planes*2, num_cells=6) 83 | self.layer4 = self._downsample(num_planes*4) 84 | self.layer5 = self._make_layer(num_planes*4, num_cells=6) 85 | 86 | self.linear = nn.Linear(num_planes*4, 10) 87 | 88 | def _make_layer(self, planes, num_cells): 89 | layers = [] 90 | for _ in range(num_cells): 91 | layers.append(self.cell_type(self.in_planes, planes, stride=1)) 92 | self.in_planes = planes 93 | return nn.Sequential(*layers) 94 | 95 | def _downsample(self, planes): 96 | layer = self.cell_type(self.in_planes, planes, stride=2) 97 | self.in_planes = planes 98 | return layer 99 | 100 | def forward(self, x): 101 | out = F.relu(self.bn1(self.conv1(x))) 102 | out = self.layer1(out) 103 | out = self.layer2(out) 104 | out = self.layer3(out) 105 | out = 
self.layer4(out) 106 | out = self.layer5(out) 107 | out = F.avg_pool2d(out, 8) 108 | out = self.linear(out.view(out.size(0), -1)) 109 | return out 110 | 111 | 112 | def PNASNetA(): 113 | return PNASNet(CellA, num_cells=6, num_planes=44) 114 | 115 | def PNASNetB(): 116 | return PNASNet(CellB, num_cells=6, num_planes=32) 117 | 118 | 119 | def test(): 120 | net = PNASNetB() 121 | x = torch.randn(1,3,32,32) 122 | y = net(x) 123 | print(y) 124 | 125 | # test() 126 | -------------------------------------------------------------------------------- /code/train_student/models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = nn.Conv2d( 20 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 23 | stride=1, padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | 26 | self.shortcut = nn.Sequential() 27 | if stride != 1 or in_planes != self.expansion*planes: 28 | self.shortcut = nn.Sequential( 29 | nn.Conv2d(in_planes, self.expansion*planes, 30 | kernel_size=1, stride=stride, bias=False), 31 | nn.BatchNorm2d(self.expansion*planes) 32 | ) 33 | 34 | def forward(self, x): 35 | out = F.relu(self.bn1(self.conv1(x))) 36 | out = self.bn2(self.conv2(out)) 37 | out += self.shortcut(x) 38 | out = F.relu(out) 39 | return out 40 | 41 | 42 | class Bottleneck(nn.Module): 43 | expansion = 4 44 | 45 | def __init__(self, in_planes, 
planes, stride=1): 46 | super(Bottleneck, self).__init__() 47 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(planes) 49 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 50 | stride=stride, padding=1, bias=False) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | self.conv3 = nn.Conv2d(planes, self.expansion * 53 | planes, kernel_size=1, bias=False) 54 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 55 | 56 | self.shortcut = nn.Sequential() 57 | if stride != 1 or in_planes != self.expansion*planes: 58 | self.shortcut = nn.Sequential( 59 | nn.Conv2d(in_planes, self.expansion*planes, 60 | kernel_size=1, stride=stride, bias=False), 61 | nn.BatchNorm2d(self.expansion*planes) 62 | ) 63 | 64 | def forward(self, x): 65 | out = F.relu(self.bn1(self.conv1(x))) 66 | out = F.relu(self.bn2(self.conv2(out))) 67 | out = self.bn3(self.conv3(out)) 68 | out += self.shortcut(x) 69 | out = F.relu(out) 70 | return out 71 | 72 | 73 | class ResNet(nn.Module): 74 | def __init__(self, block, num_blocks, num_classes=10): 75 | super(ResNet, self).__init__() 76 | self.in_planes = 64 77 | 78 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, 79 | stride=1, padding=1, bias=False) 80 | self.bn1 = nn.BatchNorm2d(64) 81 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 82 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 83 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 84 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 85 | self.linear = nn.Linear(512*block.expansion, num_classes) 86 | 87 | def _make_layer(self, block, planes, num_blocks, stride): 88 | strides = [stride] + [1]*(num_blocks-1) 89 | layers = [] 90 | for stride in strides: 91 | layers.append(block(self.in_planes, planes, stride)) 92 | self.in_planes = planes * block.expansion 93 | return nn.Sequential(*layers) 94 | 95 | def forward(self, x): 96 | out = F.relu(self.bn1(self.conv1(x))) 97 | out = 
self.layer1(out) 98 | out = self.layer2(out) 99 | out = self.layer3(out) 100 | out = self.layer4(out) 101 | out = F.avg_pool2d(out, 4) 102 | out = out.view(out.size(0), -1) 103 | out = self.linear(out) 104 | return out 105 | 106 | 107 | def ResNet18(cl=10): 108 | return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=cl) 109 | 110 | 111 | def ResNet34(): 112 | return ResNet(BasicBlock, [3, 4, 6, 3]) 113 | 114 | 115 | def ResNet50(): 116 | return ResNet(Bottleneck, [3, 4, 6, 3]) 117 | 118 | 119 | def ResNet101(): 120 | return ResNet(Bottleneck, [3, 4, 23, 3]) 121 | 122 | 123 | def ResNet152(): 124 | return ResNet(Bottleneck, [3, 8, 36, 3]) 125 | 126 | 127 | def test(): 128 | net = ResNet18() 129 | y = net(torch.randn(1, 3, 32, 32)) 130 | print(y.size()) 131 | 132 | # test() 133 | -------------------------------------------------------------------------------- /code/train_generator/models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = nn.Conv2d( 20 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 23 | stride=1, padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | 26 | self.shortcut = nn.Sequential() 27 | if stride != 1 or in_planes != self.expansion*planes: 28 | self.shortcut = nn.Sequential( 29 | nn.Conv2d(in_planes, self.expansion*planes, 30 | kernel_size=1, stride=stride, bias=False), 31 | nn.BatchNorm2d(self.expansion*planes) 32 | ) 33 | 34 | def forward(self, x): 35 | out = F.relu(self.bn1(self.conv1(x))) 36 | out = self.bn2(self.conv2(out)) 37 | out += self.shortcut(x) 38 | out = F.relu(out) 39 | return out 40 | 41 | 42 | class Bottleneck(nn.Module): 43 | expansion = 4 44 | 45 | def __init__(self, in_planes, planes, stride=1): 46 | super(Bottleneck, self).__init__() 47 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(planes) 49 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 50 | stride=stride, padding=1, bias=False) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | self.conv3 = nn.Conv2d(planes, self.expansion * 53 | planes, kernel_size=1, bias=False) 54 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 55 | 56 | self.shortcut = nn.Sequential() 57 | if stride != 1 or in_planes != self.expansion*planes: 58 | self.shortcut = nn.Sequential( 59 | nn.Conv2d(in_planes, self.expansion*planes, 60 | kernel_size=1, stride=stride, bias=False), 61 | nn.BatchNorm2d(self.expansion*planes) 62 | ) 63 | 64 | def forward(self, x): 65 | out = F.relu(self.bn1(self.conv1(x))) 66 | out = F.relu(self.bn2(self.conv2(out))) 67 | out = 
self.bn3(self.conv3(out)) 68 | out += self.shortcut(x) 69 | out = F.relu(out) 70 | return out 71 | 72 | 73 | class ResNet(nn.Module): 74 | def __init__(self, block, num_blocks, num_classes=10): 75 | super(ResNet, self).__init__() 76 | self.in_planes = 64 77 | 78 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, 79 | stride=1, padding=1, bias=False) 80 | self.bn1 = nn.BatchNorm2d(64) 81 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 82 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 83 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 84 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 85 | self.linear = nn.Linear(512*block.expansion, num_classes) 86 | 87 | def _make_layer(self, block, planes, num_blocks, stride): 88 | strides = [stride] + [1]*(num_blocks-1) 89 | layers = [] 90 | for stride in strides: 91 | layers.append(block(self.in_planes, planes, stride)) 92 | self.in_planes = planes * block.expansion 93 | return nn.Sequential(*layers) 94 | 95 | def forward(self, x): 96 | out = F.relu(self.bn1(self.conv1(x))) 97 | out = self.layer1(out) 98 | out = self.layer2(out) 99 | out = self.layer3(out) 100 | out = self.layer4(out) 101 | out = F.avg_pool2d(out, 4) 102 | out = out.view(out.size(0), -1) 103 | out = self.linear(out) 104 | return out 105 | 106 | 107 | def ResNet18(cl=10): 108 | return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=cl) 109 | 110 | 111 | def ResNet34(): 112 | return ResNet(BasicBlock, [3, 4, 6, 3]) 113 | 114 | 115 | def ResNet50(): 116 | return ResNet(Bottleneck, [3, 4, 6, 3]) 117 | 118 | 119 | def ResNet101(): 120 | return ResNet(Bottleneck, [3, 4, 23, 3]) 121 | 122 | 123 | def ResNet152(): 124 | return ResNet(Bottleneck, [3, 8, 36, 3]) 125 | 126 | 127 | def test(): 128 | net = ResNet18() 129 | y = net(torch.randn(1, 3, 32, 32)) 130 | print(y.size()) 131 | 132 | # test() 133 | --------------------------------------------------------------------------------