├── LICENSE ├── README.md ├── data_loader.py ├── experiments ├── CIFAR10 │ ├── baseline │ │ ├── cnn │ │ │ └── params.json │ │ ├── preresnet20 │ │ │ └── params.json │ │ ├── preresnet32 │ │ │ └── params.json │ │ ├── resnet18 │ │ │ └── params.json │ │ └── resnext29 │ │ │ └── params.json │ ├── kd_nasty_resnet18 │ │ ├── cnn │ │ │ └── params.json │ │ ├── nasty_resnet18 │ │ │ └── params.json │ │ ├── preresnet20 │ │ │ └── params.json │ │ └── preresnet32 │ │ │ └── params.json │ └── kd_normal_resnet18 │ │ ├── cnn │ │ └── params.json │ │ ├── preresnet20 │ │ └── params.json │ │ └── preresnet32 │ │ └── params.json ├── CIFAR100 │ ├── baseline │ │ ├── mobilenetv2 │ │ │ └── params.json │ │ ├── resnet18 │ │ │ └── params.json │ │ ├── resnet50 │ │ │ └── params.json │ │ ├── resnext29 │ │ │ └── params.json │ │ └── shufflenetv2 │ │ │ └── params.json │ ├── kd_nasty_resnet18 │ │ ├── mobilenetv2 │ │ │ └── params.json │ │ ├── nasty_resnet18 │ │ │ └── params.json │ │ └── shufflenetv2 │ │ │ └── params.json │ ├── kd_nasty_resnet50 │ │ ├── mobilenetv2 │ │ │ └── params.json │ │ ├── nasty_resnet50 │ │ │ └── params.json │ │ ├── resnet18 │ │ │ └── params.json │ │ └── shufflenetv2 │ │ │ └── params.json │ ├── kd_nasty_resnext29 │ │ ├── mobilenetv2 │ │ │ └── params.json │ │ ├── nasty_resnext29 │ │ │ └── params.json │ │ ├── resnet18 │ │ │ └── params.json │ │ └── shufflenetv2 │ │ │ └── params.json │ ├── kd_normal_resnet18 │ │ ├── mobilenetv2 │ │ │ └── params.json │ │ └── shufflenetv2 │ │ │ └── params.json │ ├── kd_normal_resnet50 │ │ ├── mobilenetv2 │ │ │ └── params.json │ │ ├── resnet18 │ │ │ └── params.json │ │ └── shufflenetv2 │ │ │ └── params.json │ └── kd_normal_resnext29 │ │ ├── mobilenetv2 │ │ └── params.json │ │ ├── resnet18 │ │ └── params.json │ │ └── shufflenetv2 │ │ └── params.json └── TinyImageNet │ ├── baseline │ ├── mobilenetv2 │ │ └── params.json │ ├── resnet18 │ │ └── params.json │ ├── resnet50 │ │ └── params.json │ ├── resnext29 │ │ └── params.json │ └── shufflenetv2 │ │ └── params.json 
│ ├── kd_nasty_resnet18 │ ├── mobilenetv2 │ │ └── params.json │ ├── nasty_resnet18 │ │ └── params.json │ └── shufflenetv2 │ │ └── params.json │ ├── kd_nasty_resnet50 │ ├── mobilenetv2 │ │ └── params.json │ ├── nasty_resnet50 │ │ └── params.json │ ├── resnet18 │ │ └── params.json │ └── shufflenetv2 │ │ └── params.json │ ├── kd_nasty_resnext29 │ ├── mobilenetv2 │ │ └── params.json │ ├── nasty_resnext29 │ │ └── params.json │ ├── resnet18 │ │ └── params.json │ └── shufflenetv2 │ │ └── params.json │ ├── kd_normal_resnet18 │ ├── mobilenetv2 │ │ └── params.json │ └── shufflenetv2 │ │ └── params.json │ ├── kd_normal_resnet50 │ ├── mobilenetv2 │ │ └── params.json │ ├── resnet18 │ │ └── params.json │ └── shufflenetv2 │ │ └── params.json │ └── kd_normal_resnext29 │ ├── mobilenetv2 │ └── params.json │ ├── resnet18 │ └── params.json │ └── shufflenetv2 │ └── params.json ├── model ├── __init__.py ├── densenet.py ├── mlp.py ├── mobilenetv2.py ├── net.py ├── preresnet.py ├── resnet.py ├── resnext.py └── shufflenetv2.py ├── requirements.txt ├── train_kd.py ├── train_nasty.py ├── train_scratch.py └── utils └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 VITA 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Undistillable: Making A Nasty Teacher That CANNOT teach students 2 | [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT) 3 | 4 | ["Undistillable: Making A Nasty Teacher That CANNOT teach students"](https://openreview.net/forum?id=0zvfm-nZqQs) 5 | 6 | Haoyu Ma, Tianlong Chen, Ting-Kuei Hu, Chenyu You, Xiaohui Xie, Zhangyang Wang 7 | In ICLR 2021 Spotlight Oral 8 | 9 | 10 | 11 | ## Overview 12 | 13 | * We propose the concept of **Nasty Teacher**, a defensive approach to prevent knowledge leaking and unauthorized model cloning through KD without sacrificing performance. 14 | * We propose a simple yet efficient algorithm, called **self-undermining knowledge distillation**, to directly build a nasty teacher through self-training, requiring no additional dataset 15 | nor auxiliary network. 16 | 17 | 18 | ## Prerequisite 19 | We use Pytorch 1.4.0, and CUDA 10.1. You can install them with 20 | ~~~ 21 | conda install pytorch=1.4.0 torchvision=0.5.0 cudatoolkit=10.1 -c pytorch 22 | ~~~ 23 | It should also be applicable to other Pytorch and CUDA versions. 
24 | 25 | 26 | Then install the other packages with: 27 | ~~~ 28 | pip install -r requirements.txt 29 | ~~~ 30 | 31 | ## Usage 32 | 33 | 34 | ### Teacher networks 35 | 36 | ##### Step 1: Train a normal teacher network 37 | 38 | ~~~ 39 | python train_scratch.py --save_path [XXX] 40 | ~~~ 41 | Here, [XXX] specifies the directory containing `params.json`, which holds all hyperparameters needed to train a network. 42 | We already include all hyperparameters in `experiments` to reproduce the results in our paper. 43 | 44 | For example, to train a normal ResNet18 on CIFAR-10: 45 | ~~~ 46 | python train_scratch.py --save_path experiments/CIFAR10/baseline/resnet18 47 | ~~~ 48 | After training finishes, you will find `training.log` and `best_model.tar` in that directory. 49 | 50 | The normal teacher network will serve as the **adversarial network** for the training of the nasty teacher. 51 | 52 | 53 | 54 | ##### Step 2: Train a nasty teacher network 55 | ~~~ 56 | python train_nasty.py --save_path [XXX] 57 | ~~~ 58 | Again, [XXX] specifies the directory containing `params.json`, 59 | which holds the adversarial network settings and the training hyperparameters. 60 | You need to specify the architecture of the adversarial network and its checkpoint in this file. 
61 | 62 | 63 | For example, train a nasty ResNet18 64 | ~~~ 65 | python train_nasty.py --save_path experiments/CIFAR10/kd_nasty_resnet18/nasty_resnet18 66 | ~~~ 67 | 68 | 69 | ### Knowledge Distillation for Student networks 70 | 71 | You can train a student distilling from normal or nasty teachers by 72 | ~~~ 73 | python train_kd.py --save_path [XXX] 74 | ~~~ 75 | Again, [XXX] specifies the directory of `params.json`, 76 | which contains the information of student networks and teacher networks 77 | 78 | 79 | For example, 80 | * train a plain CNN distilling from a nasty ResNet18 81 | ~~~ 82 | python train_kd.py --save_path experiments/CIFAR10/kd_nasty_resnet18/cnn 83 | ~~~ 84 | 85 | * Train a plain CNN distilling from a normal ResNet18 86 | ~~~ 87 | python train_kd.py --save_path experiments/CIFAR10/kd_normal_resnet18/cnn 88 | ~~~ 89 | 90 | 91 | 92 | ## Citation 93 | ~~~ 94 | @inproceedings{ 95 | ma2021undistillable, 96 | title={Undistillable: Making A Nasty Teacher That {\{}CANNOT{\}} teach students}, 97 | author={Haoyu Ma and Tianlong Chen and Ting-Kuei Hu and Chenyu You and Xiaohui Xie and Zhangyang Wang}, 98 | booktitle={International Conference on Learning Representations}, 99 | year={2021}, 100 | url={https://openreview.net/forum?id=0zvfm-nZqQs} 101 | } 102 | ~~~ 103 | 104 | ## Acknowledgement 105 | * [Teacher-free KD](https://github.com/yuanli2333/Teacher-free-Knowledge-Distillation) 106 | * [DAFL](https://github.com/huawei-noah/Data-Efficient-Model-Compression/tree/master/DAFL) 107 | * [DeepInversion](https://github.com/NVlabs/DeepInversion) 108 | 109 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | CIFAR-10 CIFAR-100, Tiny-ImageNet data loader 3 | """ 4 | 5 | import random 6 | import os 7 | import numpy as np 8 | import torch 9 | import torchvision 10 | import torchvision.transforms as transforms 11 | from 
# Per-channel normalisation statistics used for the CIFAR-10 / CIFAR-100 pipelines.
_CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_STD = (0.247, 0.243, 0.261)

# Per-channel normalisation statistics used for the Tiny-ImageNet pipeline.
_TINY_IMAGENET_MEAN = [0.4802, 0.4481, 0.3975]
_TINY_IMAGENET_STD = [0.2302, 0.2265, 0.2262]
_TINY_IMAGENET_DIR = './data/tiny-imagenet-200/'

# Values of ``params.dataset`` that the loaders know how to build.
_SUPPORTED_DATASETS = ('cifar10', 'cifar100', 'tiny_imagenet')

# Fixed seed so that the KD training subset is identical across runs.
_SUBSET_SEED = 230


def _cifar_transforms(augment):
    """Return the ``(train, dev)`` transform pipelines for CIFAR images."""
    train_steps = []
    if augment:
        # random crops and horizontal flips are applied to the train set only
        train_steps.append(transforms.RandomCrop(32, padding=4))
        train_steps.append(transforms.RandomHorizontalFlip())
    train_steps.append(transforms.ToTensor())
    train_steps.append(transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD))

    dev_steps = [
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
    ]
    return transforms.Compose(train_steps), transforms.Compose(dev_steps)


def _tiny_imagenet_transforms():
    """Return the ``(train, dev)`` transform pipelines for Tiny-ImageNet images."""
    train_transformer = transforms.Compose([
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ToTensor(),
        transforms.Normalize(_TINY_IMAGENET_MEAN, _TINY_IMAGENET_STD),
    ])
    dev_transformer = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(_TINY_IMAGENET_MEAN, _TINY_IMAGENET_STD),
    ])
    return train_transformer, dev_transformer


def _build_datasets(params):
    """Build the ``(trainset, devset)`` pair selected by ``params.dataset``.

    Raises:
        ValueError: if ``params.dataset`` is not one of ``_SUPPORTED_DATASETS``.
    """
    name = params.dataset
    # Validate up front: previously an unknown name matched no branch and the
    # caller crashed later with an unhelpful UnboundLocalError on `trainset`.
    if name not in _SUPPORTED_DATASETS:
        raise ValueError(
            "Unsupported dataset {!r}; expected one of {}".format(
                name, ', '.join(_SUPPORTED_DATASETS)))

    if name == 'tiny_imagenet':
        # NOTE: params.augmentation is not consulted here; the Tiny-ImageNet
        # train pipeline always applies rotation + horizontal flip.
        train_transformer, dev_transformer = _tiny_imagenet_transforms()
        trainset = torchvision.datasets.ImageFolder(
            _TINY_IMAGENET_DIR + 'train/', train_transformer)
        devset = torchvision.datasets.ImageFolder(
            _TINY_IMAGENET_DIR + 'val/', dev_transformer)
        return trainset, devset

    train_transformer, dev_transformer = _cifar_transforms(params.augmentation)
    if name == 'cifar10':
        dataset_cls, root = torchvision.datasets.CIFAR10, './data/data-cifar10'
    else:
        dataset_cls, root = torchvision.datasets.CIFAR100, './data/data-cifar100'

    trainset = dataset_cls(root=root, train=True,
                           download=True, transform=train_transformer)
    devset = dataset_cls(root=root, train=False,
                         download=True, transform=dev_transformer)
    return trainset, devset


def fetch_dataloader(types, params):
    """
    Fetch and return the train or dev dataloader over the full dataset.

    Args:
        types: ``'train'`` selects the (shuffled) training loader; any other
            value selects the (unshuffled) dev loader.
        params: hyperparameter object; reads ``dataset``, ``augmentation``
            (CIFAR only), ``batch_size`` and ``num_workers``.

    Returns:
        A ``torch.utils.data.DataLoader`` for the requested split.

    Raises:
        ValueError: if ``params.dataset`` is not 'cifar10', 'cifar100' or
            'tiny_imagenet'.
    """
    # Both splits are always constructed, exactly as before, so the side
    # effects (e.g. CIFAR download) do not depend on `types`.
    trainset, devset = _build_datasets(params)

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=params.batch_size,
        shuffle=True, num_workers=params.num_workers)

    devloader = torch.utils.data.DataLoader(
        devset, batch_size=params.batch_size,
        shuffle=False, num_workers=params.num_workers)

    return trainloader if types == 'train' else devloader


def fetch_subset_dataloader(types, params):
    """
    Fetch a dataloader that trains on only a subset of the training set
    (used for KD training), sized by ``params.subset_percent``.

    Args:
        types: ``'train'`` selects the subset training loader; any other
            value selects the full, unshuffled dev loader.
        params: hyperparameter object; reads ``dataset``, ``augmentation``
            (CIFAR only), ``batch_size``, ``num_workers``, ``cuda`` (used as
            ``pin_memory``) and ``subset_percent`` (fraction of the training
            set to keep).

    Returns:
        A ``torch.utils.data.DataLoader`` for the requested split.

    Raises:
        ValueError: if ``params.dataset`` is not 'cifar10', 'cifar100' or
            'tiny_imagenet'.
    """
    trainset, devset = _build_datasets(params)

    trainset_size = len(trainset)
    indices = list(range(trainset_size))
    split = int(np.floor(params.subset_percent * trainset_size))
    # NOTE: this reseeds NumPy's *global* RNG on every call. It is kept as-is
    # so that the chosen subset (and any downstream NumPy randomness) stays
    # identical to previous runs of the experiments.
    np.random.seed(_SUBSET_SEED)
    np.random.shuffle(indices)

    # The subset membership is fixed by the seed above; the sampler only
    # randomises the visiting order within that subset.
    train_sampler = SubsetRandomSampler(indices[:split])

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=params.batch_size,
        sampler=train_sampler, num_workers=params.num_workers,
        pin_memory=params.cuda)

    devloader = torch.utils.data.DataLoader(
        devset, batch_size=params.batch_size,
        shuffle=False, num_workers=params.num_workers,
        pin_memory=params.cuda)

    return trainloader if types == 'train' else devloader
/experiments/CIFAR10/baseline/resnet18/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnet18", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [80, 120], 6 | "gamma": 0.1, 7 | "batch_size": 128, 8 | "num_epochs": 160, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar10" 14 | } 15 | 16 | -------------------------------------------------------------------------------- /experiments/CIFAR10/baseline/resnext29/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnext29", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [80, 120], 6 | "gamma": 0.1, 7 | "batch_size": 128, 8 | "num_epochs": 160, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar10" 14 | } 15 | 16 | -------------------------------------------------------------------------------- /experiments/CIFAR10/kd_nasty_resnet18/cnn/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "net", 3 | "num_channels": 32, 4 | "dropout_rate": 0.0, 5 | 6 | "learning_rate": 1e-3, 7 | "schedule": [999], 8 | "gamma": 0.1, 9 | "batch_size": 128, 10 | "num_epochs": 100, 11 | "num_workers": 4, 12 | "augmentation": 1, 13 | "cuda": 1, 14 | 15 | "dataset": "cifar10", 16 | 17 | "teacher_model": "resnet18", 18 | "teacher_resume": "experiments/CIFAR10/kd_nasty_resnet18/nasty_resnet18/best_model.tar", 19 | "temperature": 4, 20 | "alpha": 0.9 21 | } 22 | -------------------------------------------------------------------------------- /experiments/CIFAR10/kd_nasty_resnet18/nasty_resnet18/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnet18", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [80, 120], 6 | "gamma": 0.1, 7 | "batch_size": 128, 8 | "num_epochs": 160, 9 | "num_workers": 4, 10 | 
"augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar10", 14 | 15 | "adversarial_model": "resnet18", 16 | "adversarial_resume": "experiments/CIFAR10/baseline/resnet18/best_model.tar", 17 | "temperature": 4, 18 | "weight": 0.04 19 | } 20 | 21 | -------------------------------------------------------------------------------- /experiments/CIFAR10/kd_nasty_resnet18/preresnet20/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "preresnet20", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [80, 120], 6 | "gamma": 0.1, 7 | "batch_size": 128, 8 | "num_epochs": 160, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar10", 14 | 15 | "teacher_model": "resnet18", 16 | "teacher_resume": "experiments/CIFAR10/kd_nasty_resnet18/nasty_resnet18/best_model.tar", 17 | "temperature": 4, 18 | "alpha": 0.9 19 | } 20 | -------------------------------------------------------------------------------- /experiments/CIFAR10/kd_nasty_resnet18/preresnet32/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "preresnet32", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [80, 120], 6 | "gamma": 0.1, 7 | "batch_size": 128, 8 | "num_epochs": 160, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar10", 14 | 15 | "teacher_model": "resnet18", 16 | "teacher_resume": "experiments/CIFAR10/kd_nasty_resnet18/nasty_resnet18/best_model.tar", 17 | "temperature": 4, 18 | "alpha": 0.9 19 | } 20 | -------------------------------------------------------------------------------- /experiments/CIFAR10/kd_normal_resnet18/cnn/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "net", 3 | "num_channels": 32, 4 | "dropout_rate": 0.0, 5 | 6 | "learning_rate": 1e-3, 7 | "schedule": [999], 8 | "gamma": 0.1, 9 | "batch_size": 128, 10 | "num_epochs": 
100, 11 | "num_workers": 4, 12 | "augmentation": 1, 13 | "cuda": 1, 14 | 15 | "dataset": "cifar10", 16 | 17 | "teacher_model": "resnet18", 18 | "teacher_resume": "experiments/CIFAR10/baseline/resnet18/best_model.tar", 19 | "temperature": 4, 20 | "alpha": 0.9 21 | } 22 | -------------------------------------------------------------------------------- /experiments/CIFAR10/kd_normal_resnet18/preresnet20/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "preresnet20", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [80, 120], 6 | "gamma": 0.1, 7 | "batch_size": 128, 8 | "num_epochs": 160, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar10", 14 | 15 | "teacher_model": "resnet18", 16 | "teacher_resume": "experiments/CIFAR10/baseline/resnet18/best_model.tar", 17 | "temperature": 4, 18 | "alpha": 0.9 19 | } 20 | -------------------------------------------------------------------------------- /experiments/CIFAR10/kd_normal_resnet18/preresnet32/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "preresnet32", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [80, 120], 6 | "gamma": 0.1, 7 | "batch_size": 128, 8 | "num_epochs": 160, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar10", 14 | 15 | "teacher_model": "resnet18", 16 | "teacher_resume": "experiments/CIFAR10/baseline/resnet18/best_model.tar", 17 | "temperature": 4, 18 | "alpha": 0.9 19 | } 20 | -------------------------------------------------------------------------------- /experiments/CIFAR100/baseline/mobilenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "mobilenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | 
"augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar100" 14 | } 15 | -------------------------------------------------------------------------------- /experiments/CIFAR100/baseline/resnet18/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnet18", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar100" 14 | } 15 | 16 | -------------------------------------------------------------------------------- /experiments/CIFAR100/baseline/resnet50/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnet50", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar100" 14 | } 15 | 16 | -------------------------------------------------------------------------------- /experiments/CIFAR100/baseline/resnext29/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnext29", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar100" 14 | } -------------------------------------------------------------------------------- /experiments/CIFAR100/baseline/shufflenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "shufflenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | 
"dataset": "cifar100" 14 | } 15 | -------------------------------------------------------------------------------- /experiments/CIFAR100/kd_nasty_resnet18/mobilenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "mobilenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar100", 14 | 15 | "teacher_model": "resnet18", 16 | "teacher_resume": "experiments/CIFAR100/kd_nasty_resnet18/nasty_resnet18/best_model.tar", 17 | 18 | "temperature": 20, 19 | "alpha": 0.9 20 | } -------------------------------------------------------------------------------- /experiments/CIFAR100/kd_nasty_resnet18/nasty_resnet18/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnet18", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar100", 14 | 15 | "adversarial_model": "resnet18", 16 | "adversarial_resume": "experiments/CIFAR100/baseline/resnet18/best_model.tar", 17 | "temperature": 20, 18 | "weight": 0.005 19 | } 20 | 21 | -------------------------------------------------------------------------------- /experiments/CIFAR100/kd_nasty_resnet18/shufflenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "shufflenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar100", 14 | 15 | "teacher_model": "resnet18", 16 | "teacher_resume": 
"experiments/CIFAR100/kd_nasty_resnet18/nasty_resnet18/best_model.tar", 17 | 18 | "temperature": 20, 19 | "alpha": 0.9 20 | } -------------------------------------------------------------------------------- /experiments/CIFAR100/kd_nasty_resnet50/mobilenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "mobilenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar100", 14 | 15 | "teacher_model": "resnet50", 16 | "teacher_resume": "experiments/CIFAR100/kd_nasty_resnet50/nasty_resnet50/best_model.tar", 17 | 18 | "temperature": 20, 19 | "alpha": 0.9 20 | } -------------------------------------------------------------------------------- /experiments/CIFAR100/kd_nasty_resnet50/nasty_resnet50/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnet50", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar100", 14 | 15 | "adversarial_model": "resnet50", 16 | "adversarial_resume": "experiments/CIFAR100/baseline/resnet50/best_model.tar", 17 | "temperature": 20, 18 | "weight": 0.005 19 | } 20 | 21 | -------------------------------------------------------------------------------- /experiments/CIFAR100/kd_nasty_resnet50/resnet18/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnet18", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar100", 14 | 15 | "teacher_model": "resnet50", 16 
| "teacher_resume": "experiments/CIFAR100/kd_nasty_resnet50/nasty_resnet50/best_model.tar", 17 | 18 | "temperature": 20, 19 | "alpha": 0.9 20 | } -------------------------------------------------------------------------------- /experiments/CIFAR100/kd_nasty_resnet50/shufflenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "shufflenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar100", 14 | 15 | "teacher_model": "resnet50", 16 | "teacher_resume": "experiments/CIFAR100/kd_nasty_resnet50/nasty_resnet50/best_model.tar", 17 | 18 | "temperature": 20, 19 | "alpha": 0.9 20 | } -------------------------------------------------------------------------------- /experiments/CIFAR100/kd_nasty_resnext29/mobilenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "mobilenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar100", 14 | 15 | "teacher_model": "resnext29", 16 | "teacher_resume": "experiments/CIFAR100/kd_nasty_resnext29/nasty_resnext29/best_model.tar", 17 | 18 | "temperature": 20, 19 | "alpha": 0.9 20 | } -------------------------------------------------------------------------------- /experiments/CIFAR100/kd_nasty_resnext29/nasty_resnext29/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnext29", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar100", 14 | 15 
| "adversarial_model": "resnext29", 16 | "adversarial_resume": "experiments/CIFAR100/baseline/resnext29/best_model.tar", 17 | "temperature": 20, 18 | "weight": 0.005 19 | } 20 | 21 | -------------------------------------------------------------------------------- /experiments/CIFAR100/kd_nasty_resnext29/resnet18/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnet18", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar100", 14 | 15 | "teacher_model": "resnext29", 16 | "teacher_resume": "experiments/CIFAR100/kd_nasty_resnext29/nasty_resnext29/best_model.tar", 17 | 18 | "temperature": 20, 19 | "alpha": 0.9 20 | } -------------------------------------------------------------------------------- /experiments/CIFAR100/kd_nasty_resnext29/shufflenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "shufflenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar100", 14 | 15 | "teacher_model": "resnext29", 16 | "teacher_resume": "experiments/CIFAR100/kd_nasty_resnext29/nasty_resnext29/best_model.tar", 17 | 18 | "temperature": 20, 19 | "alpha": 0.9 20 | } -------------------------------------------------------------------------------- /experiments/CIFAR100/kd_normal_resnet18/mobilenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "mobilenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 
| "dataset": "cifar100", 14 | 15 | "teacher_model": "resnet18", 16 | "teacher_resume": "experiments/CIFAR100/baseline/resnet18/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } 20 | -------------------------------------------------------------------------------- /experiments/CIFAR100/kd_normal_resnet18/shufflenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "shufflenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar100", 14 | 15 | "teacher_model": "resnet18", 16 | "teacher_resume": "experiments/CIFAR100/baseline/resnet18/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } 20 | -------------------------------------------------------------------------------- /experiments/CIFAR100/kd_normal_resnet50/mobilenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "mobilenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar100", 14 | 15 | "teacher_model": "resnet50", 16 | "teacher_resume": "experiments/CIFAR100/baseline/resnet50/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } 20 | -------------------------------------------------------------------------------- /experiments/CIFAR100/kd_normal_resnet50/resnet18/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnet18", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": 
"cifar100", 14 | 15 | "teacher_model": "resnet50", 16 | "teacher_resume": "experiments/CIFAR100/baseline/resnet50/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } 20 | -------------------------------------------------------------------------------- /experiments/CIFAR100/kd_normal_resnet50/shufflenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "shufflenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar100", 14 | 15 | "teacher_model": "resnet50", 16 | "teacher_resume": "experiments/CIFAR100/baseline/resnet50/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } 20 | -------------------------------------------------------------------------------- /experiments/CIFAR100/kd_normal_resnext29/mobilenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "mobilenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar100", 14 | 15 | "teacher_model": "resnext29", 16 | "teacher_resume": "experiments/CIFAR100/baseline/resnext29/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } 20 | -------------------------------------------------------------------------------- /experiments/CIFAR100/kd_normal_resnext29/resnet18/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnet18", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar100", 14 | 
15 | "teacher_model": "resnext29", 16 | "teacher_resume": "experiments/CIFAR100/baseline/resnext29/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } 20 | -------------------------------------------------------------------------------- /experiments/CIFAR100/kd_normal_resnext29/shufflenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "shufflenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "cifar100", 14 | 15 | "teacher_model": "resnext29", 16 | "teacher_resume": "experiments/CIFAR100/baseline/resnext29/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } 20 | -------------------------------------------------------------------------------- /experiments/TinyImageNet/baseline/mobilenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "mobilenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet" 14 | } 15 | -------------------------------------------------------------------------------- /experiments/TinyImageNet/baseline/resnet18/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnet18", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet" 14 | } 15 | 16 | -------------------------------------------------------------------------------- /experiments/TinyImageNet/baseline/resnet50/params.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnet50", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet" 14 | } 15 | 16 | -------------------------------------------------------------------------------- /experiments/TinyImageNet/baseline/resnext29/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnext29", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet" 14 | } 15 | -------------------------------------------------------------------------------- /experiments/TinyImageNet/baseline/shufflenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "shufflenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet" 14 | } 15 | -------------------------------------------------------------------------------- /experiments/TinyImageNet/kd_nasty_resnet18/mobilenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "mobilenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet", 14 | 15 | "teacher_model": "resnet18", 16 | "teacher_resume": "experiments/TinyImageNet/kd_nasty_resnet18/nasty_resnet18/best_model.tar", 17 | 
"temperature": 20, 18 | "alpha": 0.9 19 | } -------------------------------------------------------------------------------- /experiments/TinyImageNet/kd_nasty_resnet18/nasty_resnet18/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnet18", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet", 14 | 15 | "adversarial_model": "resnet18", 16 | "adversarial_resume": "experiments/TinyImageNet/baseline/resnet18/best_model.tar", 17 | "temperature": 20, 18 | "weight": 0.01 19 | } 20 | 21 | -------------------------------------------------------------------------------- /experiments/TinyImageNet/kd_nasty_resnet18/shufflenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "shufflenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet", 14 | 15 | "teacher_model": "resnet18", 16 | "teacher_resume": "experiments/TinyImageNet/kd_nasty_resnet18/nasty_resnet18/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } -------------------------------------------------------------------------------- /experiments/TinyImageNet/kd_nasty_resnet50/mobilenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "mobilenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet", 14 | 15 | "teacher_model": "resnet50", 16 | "teacher_resume": 
"experiments/TinyImageNet/kd_nasty_resnet50/nasty_resnet50/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } -------------------------------------------------------------------------------- /experiments/TinyImageNet/kd_nasty_resnet50/nasty_resnet50/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnet50", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet", 14 | 15 | "adversarial_model": "resnet50", 16 | "adversarial_resume": "experiments/TinyImageNet/baseline/resnet50/best_model.tar", 17 | "temperature": 20, 18 | "weight": 0.01 19 | } 20 | 21 | -------------------------------------------------------------------------------- /experiments/TinyImageNet/kd_nasty_resnet50/resnet18/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnet18", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet", 14 | 15 | "teacher_model": "resnet50", 16 | "teacher_resume": "experiments/TinyImageNet/kd_nasty_resnet50/nasty_resnet50/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } -------------------------------------------------------------------------------- /experiments/TinyImageNet/kd_nasty_resnet50/shufflenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "shufflenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet", 14 | 15 | 
"teacher_model": "resnet50", 16 | "teacher_resume": "experiments/TinyImageNet/kd_nasty_resnet50/nasty_resnet50/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } -------------------------------------------------------------------------------- /experiments/TinyImageNet/kd_nasty_resnext29/mobilenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "mobilenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet", 14 | 15 | "teacher_model": "resnext29", 16 | "teacher_resume": "experiments/TinyImageNet/kd_nasty_resnext29/nasty_resnext29/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } -------------------------------------------------------------------------------- /experiments/TinyImageNet/kd_nasty_resnext29/nasty_resnext29/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnext29", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet", 14 | 15 | "adversarial_model": "resnext29", 16 | "adversarial_resume": "experiments/TinyImageNet/baseline/resnext29/best_model.tar", 17 | "temperature": 20, 18 | "weight": 0.01 19 | } 20 | 21 | -------------------------------------------------------------------------------- /experiments/TinyImageNet/kd_nasty_resnext29/resnet18/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnet18", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | 
"cuda": 1, 12 | 13 | "dataset": "tiny_imagenet", 14 | 15 | "teacher_model": "resnext29", 16 | "teacher_resume": "experiments/TinyImageNet/kd_nasty_resnext29/nasty_resnext29/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } -------------------------------------------------------------------------------- /experiments/TinyImageNet/kd_nasty_resnext29/shufflenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "shufflenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet", 14 | 15 | "teacher_model": "resnext29", 16 | "teacher_resume": "experiments/TinyImageNet/kd_nasty_resnext29/nasty_resnext29/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } -------------------------------------------------------------------------------- /experiments/TinyImageNet/kd_normal_resnet18/mobilenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "mobilenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet", 14 | 15 | "teacher_model": "resnet18", 16 | "teacher_resume": "experiments/TinyImageNet/baseline/resnet18/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } 20 | -------------------------------------------------------------------------------- /experiments/TinyImageNet/kd_normal_resnet18/shufflenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "shufflenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 
200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet", 14 | 15 | "teacher_model": "resnet18", 16 | "teacher_resume": "experiments/TinyImageNet/baseline/resnet18/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } 20 | -------------------------------------------------------------------------------- /experiments/TinyImageNet/kd_normal_resnet50/mobilenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "mobilenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet", 14 | 15 | "teacher_model": "resnet50", 16 | "teacher_resume": "experiments/TinyImageNet/baseline/resnet50/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } 20 | -------------------------------------------------------------------------------- /experiments/TinyImageNet/kd_normal_resnet50/resnet18/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnet18", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet", 14 | 15 | "teacher_model": "resnet50", 16 | "teacher_resume": "experiments/TinyImageNet/baseline/resnet50/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } 20 | -------------------------------------------------------------------------------- /experiments/TinyImageNet/kd_normal_resnet50/shufflenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "shufflenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 
8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet", 14 | 15 | "teacher_model": "resnet50", 16 | "teacher_resume": "experiments/TinyImageNet/baseline/resnet50/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } 20 | -------------------------------------------------------------------------------- /experiments/TinyImageNet/kd_normal_resnext29/mobilenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "mobilenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet", 14 | 15 | "teacher_model": "resnext29", 16 | "teacher_resume": "experiments/TinyImageNet/baseline/resnext29/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } 20 | -------------------------------------------------------------------------------- /experiments/TinyImageNet/kd_normal_resnext29/resnet18/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "resnet18", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet", 14 | 15 | "teacher_model": "resnext29", 16 | "teacher_resume": "experiments/TinyImageNet/baseline/resnext29/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } 20 | -------------------------------------------------------------------------------- /experiments/TinyImageNet/kd_normal_resnext29/shufflenetv2/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "shufflenetv2", 3 | 4 | "learning_rate": 0.1, 5 | "schedule": [60, 120, 160], 6 | "gamma": 
0.2, 7 | "batch_size": 128, 8 | "num_epochs": 200, 9 | "num_workers": 4, 10 | "augmentation": 1, 11 | "cuda": 1, 12 | 13 | "dataset": "tiny_imagenet", 14 | 15 | "teacher_model": "resnext29", 16 | "teacher_resume": "experiments/TinyImageNet/baseline/resnext29/best_model.tar", 17 | "temperature": 20, 18 | "alpha": 0.9 19 | } 20 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .mlp import MLP 2 | from .net import Net 3 | from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 4 | from .preresnet import PreResNet 5 | 6 | from .resnext import CifarResNeXt 7 | from .densenet import densenet121, densenet161, densenet169, densenet201 8 | from .mobilenetv2 import MobileNetV2 9 | from .shufflenetv2 import shufflenetv2 10 | 11 | -------------------------------------------------------------------------------- /model/densenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | dense net in pytorch 3 | [1] Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger. 4 | Densely Connected Convolutional Networks 5 | https://arxiv.org/abs/1608.06993v5 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | 13 | #"""Bottleneck layers. Although each layer only produces k 14 | #output feature-maps, it typically has many more inputs. 
class Bottleneck(nn.Module):
    """DenseNet-B composite layer: BN-ReLU-Conv(1x1) then BN-ReLU-Conv(3x3).

    The 1x1 convolution maps the input to ``4 * growth_rate`` channels so the
    following 3x3 convolution operates on fewer feature maps; the 3x3
    convolution then emits ``growth_rate`` new feature maps.

    Args:
        in_channels: number of channels of the incoming feature map.
        growth_rate: number of feature maps this layer produces (``k``).
    """

    def __init__(self, in_channels, growth_rate):
        super().__init__()
        # Each 1x1 convolution produces 4k feature maps.
        bottleneck_width = growth_rate * 4

        # BN-ReLU-Conv(1x1): squeeze the concatenated input.
        squeeze_stage = [
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, bottleneck_width, kernel_size=1, bias=False),
        ]
        # BN-ReLU-Conv(3x3): padding=1 keeps the spatial size unchanged.
        grow_stage = [
            nn.BatchNorm2d(bottleneck_width),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck_width, growth_rate, kernel_size=3, padding=1, bias=False),
        ]
        self.bottle_neck = nn.Sequential(*squeeze_stage, *grow_stage)

    def forward(self, x):
        """Return ``x`` concatenated along the channel axis with the new maps."""
        new_features = self.bottle_neck(x)
        return torch.cat((x, new_features), dim=1)
51 | self.down_sample = nn.Sequential( 52 | nn.BatchNorm2d(in_channels), 53 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 54 | nn.AvgPool2d(2, stride=2) 55 | ) 56 | 57 | def forward(self, x): 58 | return self.down_sample(x) 59 | 60 | #DesneNet-BC 61 | #B stands for bottleneck layer(BN-RELU-CONV(1x1)-BN-RELU-CONV(3x3)) 62 | #C stands for compression factor(0<=theta<=1) 63 | class DenseNet(nn.Module): 64 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_class=100): 65 | super().__init__() 66 | self.growth_rate = growth_rate 67 | 68 | #"""Before entering the first dense block, a convolution 69 | #with 16 (or twice the growth rate for DenseNet-BC) 70 | #output channels is performed on the input images.""" 71 | inner_channels = 2 * growth_rate 72 | 73 | #For convolutional layers with kernel size 3×3, each 74 | #side of the inputs is zero-padded by one pixel to keep 75 | #the feature-map size fixed. 76 | self.conv1 = nn.Conv2d(3, inner_channels, kernel_size=3, padding=1, bias=False) 77 | 78 | self.features = nn.Sequential() 79 | 80 | for index in range(len(nblocks) - 1): 81 | self.features.add_module("dense_block_layer_{}".format(index), self._make_dense_layers(block, inner_channels, nblocks[index])) 82 | inner_channels += growth_rate * nblocks[index] 83 | 84 | #"""If a dense block contains m feature-maps, we let the 85 | #following transition layer generate θm output feature- 86 | #maps, where 0 < θ ≤ 1 is referred to as the compression 87 | #fac-tor. 
class MLP(nn.Module):
    """Three-layer fully connected classifier for 32x32 RGB images.

    Args:
        num_class: size of the output layer (number of classes).
    """

    def __init__(self, num_class=10):
        super(MLP, self).__init__()
        # linear layer (3 * 32 * 32 flattened pixels -> 512 hidden units)
        self.fc1 = nn.Linear(32 * 32 * 3, 512)
        # linear layer (512 hidden -> 512 hidden)
        self.fc2 = nn.Linear(512, 512)
        # linear layer (512 hidden -> num_class outputs)
        self.fc3 = nn.Linear(512, num_class)
        # dropout layer (p=0.2)
        # NOTE(review): this module is registered but never applied in
        # forward(); confirm whether that is intentional.
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Flatten each sample and run it through the three linear layers.

        Args:
            x: input tensor whose per-sample element count equals
                ``fc1.in_features`` (e.g. ``(N, 3, 32, 32)``).

        Returns:
            Tensor with one row of ``num_class`` values per sample.

        Raises:
            ValueError: if a batched input does not carry exactly
                ``fc1.in_features`` elements per sample.
        """
        in_features = self.fc1.in_features
        # A bare view(-1, in_features) silently folds surplus pixels into the
        # batch dimension (a (2, 3, 64, 64) batch would become 8 rows), so
        # reject mismatched per-sample sizes explicitly.
        if x.dim() > 1 and x.shape[1:].numel() != in_features:
            raise ValueError(
                "MLP expects {} features per sample, got input of shape {}".format(
                    in_features, tuple(x.shape)
                )
            )
        # flatten image input
        x = x.view(-1, in_features)
        # hidden layers with relu activation function
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # NOTE(review): the ReLU on the output layer clamps the returned values
        # to be non-negative. Kept as-is to preserve existing numerics; confirm
        # whether raw (unclamped) outputs were intended.
        x = F.relu(self.fc3(x))
        return x
class MobileNetV2(nn.Module):
    """MobileNetV2-style classifier assembled from LinearBottleNeck stages.

    Args:
        class_num: number of output channels of the final 1x1 convolution,
            i.e. the number of classes.
    """

    def __init__(self, class_num=100):
        super().__init__()

        # NOTE(review): a 1x1 convolution with padding=1 enlarges each spatial
        # side by 2 (32x32 -> 34x34). Left unchanged to preserve the existing
        # parameter shapes and outputs.
        self.pre = nn.Sequential(
            nn.Conv2d(3, 32, 1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True)
        )

        self.stage1 = LinearBottleNeck(32, 16, 1, 1)
        self.stage2 = self._make_stage(2, 16, 24, 2, 6)
        self.stage3 = self._make_stage(3, 24, 32, 2, 6)
        self.stage4 = self._make_stage(4, 32, 64, 2, 6)
        self.stage5 = self._make_stage(3, 64, 96, 1, 6)
        self.stage6 = self._make_stage(3, 96, 160, 1, 6)
        self.stage7 = LinearBottleNeck(160, 320, 1, 6)

        self.conv1 = nn.Sequential(
            nn.Conv2d(320, 1280, 1),
            nn.BatchNorm2d(1280),
            nn.ReLU6(inplace=True)
        )

        # 1x1 convolution used as the classifier on the pooled 1x1 feature map.
        self.conv2 = nn.Conv2d(1280, class_num, 1)

    def forward(self, x):
        """Return a (batch, class_num) tensor for a batch of 3-channel images."""
        x = self.pre(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.stage6(x)
        x = self.stage7(x)
        x = self.conv1(x)
        x = F.adaptive_avg_pool2d(x, 1)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)

        return x

    def _make_stage(self, repeat, in_channels, out_channels, stride, t):
        """Stack bottlenecks: one strided unit, then ``repeat - 1`` stride-1 units.

        Args:
            repeat: total number of bottlenecks wanted in the stage.
            in_channels: channels entering the first bottleneck.
            out_channels: channels produced by every bottleneck of the stage.
            stride: stride of the first bottleneck only.
            t: expansion factor forwarded to each LinearBottleNeck.
        """
        layers = [LinearBottleNeck(in_channels, out_channels, stride, t)]
        # A bounded range replaces ``while repeat - 1:``, which never
        # terminated when ``repeat`` was smaller than 1.
        for _ in range(repeat - 1):
            layers.append(LinearBottleNeck(out_channels, out_channels, 1, t))

        return nn.Sequential(*layers)
class Net(nn.Module):
    """Small convolutional classifier for 3x32x32 images.

    Three (conv -> batch norm -> 2x2 max pool -> ReLU) stages are followed by
    two fully connected layers with dropout in between.

    The documentation for the available components is here:
    http://pytorch.org/docs/master/nn.html
    """

    def __init__(self, num_class, params):
        """Build the layers.

        Args:
            num_class: size of the final fully connected layer.
            params: object exposing ``num_channels`` (width of the first
                convolution) and ``dropout_rate``.
        """
        super(Net, self).__init__()
        self.num_channels = params.num_channels

        # Each convolution takes (input_channels, output_channels, filter_size,
        # stride, padding); the batch-norm layers follow each convolution.
        self.conv1 = nn.Conv2d(3, self.num_channels, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(self.num_channels)
        self.conv2 = nn.Conv2d(self.num_channels, self.num_channels*2, 3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(self.num_channels*2)
        self.conv3 = nn.Conv2d(self.num_channels*2, self.num_channels*4, 3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(self.num_channels*4)

        # 2 fully connected layers map the 4x4 feature map to the final output.
        self.fc1 = nn.Linear(4*4*self.num_channels*4, self.num_channels*4)
        self.fcbn1 = nn.BatchNorm1d(self.num_channels*4)
        self.fc2 = nn.Linear(self.num_channels*4, num_class)
        self.dropout_rate = params.dropout_rate

    def forward(self, s):
        """Run the network on a batch of images.

        Args:
            s: batch of images of dimension batch_size x 3 x 32 x 32.

        Returns:
            Tensor of dimension batch_size x num_class holding the raw output
            of the last linear layer (no softmax is applied here).

        Raises:
            RuntimeError: if the feature map is not 4x4 after the three
                pooling steps, i.e. the input is not 32x32.
        """
        #                                                  -> batch_size x 3 x 32 x 32
        s = self.bn1(self.conv1(s))                         # batch_size x num_channels x 32 x 32
        s = F.relu(F.max_pool2d(s, 2))                      # batch_size x num_channels x 16 x 16
        s = self.bn2(self.conv2(s))                         # batch_size x num_channels*2 x 16 x 16
        s = F.relu(F.max_pool2d(s, 2))                      # batch_size x num_channels*2 x 8 x 8
        s = self.bn3(self.conv3(s))                         # batch_size x num_channels*4 x 8 x 8
        s = F.relu(F.max_pool2d(s, 2))                      # batch_size x num_channels*4 x 4 x 4

        # Flatten per sample. ``view(-1, ...)`` used to fold larger feature
        # maps into extra batch rows without any error; keeping the batch
        # dimension fixed makes an unexpected input size fail loudly instead.
        s = s.view(s.size(0), 4*4*self.num_channels*4)      # batch_size x 4*4*num_channels*4

        # apply 2 fully connected layers with dropout
        s = F.dropout(F.relu(self.fcbn1(self.fc1(s))),
                      p=self.dropout_rate, training=self.training)  # batch_size x num_channels*4
        s = self.fc2(s)                                     # batch_size x num_class

        return s
def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """Pre-activation residual block: (BN -> ReLU -> 3x3 conv) twice plus a skip."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        # Optional module applied to the block input before the addition.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual

        return out


class Bottleneck(nn.Module):
    """Pre-activation bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a skip.

    The last convolution outputs ``planes * 4`` channels.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        # Optional module applied to the block input before the addition.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual

        return out


class PreResNet(nn.Module):
    """Pre-activation ResNet with three stages of 16/32/64 base planes.

    Args:
        depth: network depth; must satisfy ``depth = 6n + 2``. Depths of 44
            and above use Bottleneck blocks, smaller depths use BasicBlock.
        num_classes: size of the final linear layer.
    """

    def __init__(self, depth, num_classes=10):
        super(PreResNet, self).__init__()
        # Model type specifies number of layers for CIFAR-10 model
        assert (depth - 2) % 6 == 0, 'depth should be 6n+2'
        n = (depth - 2) // 6

        block = Bottleneck if depth >= 44 else BasicBlock

        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
                               bias=False)
        self.layer1 = self._make_layer(block, 16, n)
        self.layer2 = self._make_layer(block, 32, n, stride=2)
        self.layer3 = self._make_layer(block, 64, n, stride=2)
        self.bn = nn.BatchNorm2d(64 * block.expansion)
        self.relu = nn.ReLU(inplace=True)
        # Global average pooling. For 32x32 inputs the final map is 8x8, so
        # this matches the former ``nn.AvgPool2d(8)``; unlike the fixed window
        # it also reduces larger maps (e.g. from 64x64 inputs) to 1x1, so the
        # flattened size always equals ``fc.in_features``.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std sqrt(2 / (k_h * k_w * out_channels)).
                # A dedicated name avoids shadowing the block count ``n``.
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and, when the stride or the
        channel count changes, a 1x1 convolution on the skip path.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return a (batch, num_classes) tensor for a batch of 3-channel images."""
        x = self.conv1(x)

        x = self.layer1(x)  # same resolution as the input
        x = self.layer2(x)  # 1/2 resolution
        x = self.layer3(x)  # 1/4 resolution
        x = self.bn(x)
        x = self.relu(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
class BasicBlock(nn.Module):
    """Two 3x3 conv/batch-norm layers with an additive shortcut and final ReLU.

    The shortcut is the identity unless the stride or the channel count
    changes, in which case it is a strided 1x1 convolution plus batch norm.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # An empty Sequential acts as the identity mapping.
        projection = []
        needs_projection = not (stride == 1 and in_planes == out_planes)
        if needs_projection:
            projection.append(nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False))
            projection.append(nn.BatchNorm2d(out_planes))
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        main_path = self.bn2(self.conv2(hidden))
        main_path += self.shortcut(x)
        return F.relu(main_path)
class ResNet(nn.Module):
    """ResNet backbone with a 3x3 stem, four residual stages and a linear head.

    Args:
        block: residual block class exposing an ``expansion`` attribute and
            constructed as ``block(in_planes, planes, stride)``.
        num_blocks: sequence giving the block count of each of the 4 stages.
        num_classes: size of the final linear layer.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # (width, stride of the first block) for layer1 .. layer4.
        stage_plan = ((64, 1), (128, 2), (256, 2), (512, 2))
        for idx, (width, first_stride) in enumerate(stage_plan):
            stage = self._make_layer(block, width, num_blocks[idx], stride=first_stride)
            setattr(self, 'layer{}'.format(idx + 1), stage)

        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``."""
        stride_plan = [stride]
        stride_plan.extend(1 for _ in range(num_blocks - 1))

        stage = []
        for block_stride in stride_plan:
            stage.append(block(self.in_planes, planes, block_stride))
            # Later blocks consume the expanded width produced by this one.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        # Adaptive pooling keeps the head valid for 64 * 64 input (TinyImageNet).
        pooled = F.adaptive_avg_pool2d(features, output_size=(1, 1))
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)
class ResNeXtBottleneck(nn.Module):
    """
    RexNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua)
    """
    def __init__(self, in_channels, out_channels, stride, cardinality, widen_factor):
        """ Constructor
        Args:
            in_channels: input channel dimensionality
            out_channels: output channel dimensionality
            stride: conv stride. Replaces pooling layer.
            cardinality: num of convolution groups.
            widen_factor: factor to reduce the input dimensionality before convolution.
        """
        super(ResNeXtBottleneck, self).__init__()
        # Width of the grouped 3x3 convolution.
        D = cardinality * out_channels // widen_factor
        self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn_reduce = nn.BatchNorm2d(D)
        self.conv_conv = nn.Conv2d(D, D, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
        self.bn = nn.BatchNorm2d(D)
        self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn_expand = nn.BatchNorm2d(out_channels)

        # Identity shortcut unless the channel count changes.
        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut.add_module('shortcut_conv', nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False))
            self.shortcut.add_module('shortcut_bn', nn.BatchNorm2d(out_channels))

    def forward(self, x):
        # Modules are invoked through __call__ (not .forward) so hooks run.
        bottleneck = self.conv_reduce(x)
        bottleneck = F.relu(self.bn_reduce(bottleneck), inplace=True)
        bottleneck = self.conv_conv(bottleneck)
        bottleneck = F.relu(self.bn(bottleneck), inplace=True)
        bottleneck = self.conv_expand(bottleneck)
        bottleneck = self.bn_expand(bottleneck)
        residual = self.shortcut(x)
        return F.relu(residual + bottleneck, inplace=True)


class CifarResNeXt(nn.Module):
    """
    ResNext optimized for the Cifar dataset, as specified in
    https://arxiv.org/pdf/1611.05431.pdf
    """
    def __init__(self, cardinality, depth, num_classes, widen_factor=4, dropRate=0):
        """ Constructor
        Args:
            cardinality: number of convolution groups.
            depth: number of layers.
            num_classes: number of classes
            widen_factor: factor to adjust the channel dimensionality
            dropRate: accepted for interface compatibility; not used by this class.
        """
        super(CifarResNeXt, self).__init__()
        self.cardinality = cardinality
        self.depth = depth
        self.block_depth = (self.depth - 2) // 9
        self.widen_factor = widen_factor
        self.num_classes = num_classes
        self.output_size = 64
        self.stages = [64, 64 * self.widen_factor, 128 * self.widen_factor, 256 * self.widen_factor]

        self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.bn_1 = nn.BatchNorm2d(64)
        self.stage_1 = self.block('stage_1', self.stages[0], self.stages[1], 1)
        self.stage_2 = self.block('stage_2', self.stages[1], self.stages[2], 2)
        self.stage_3 = self.block('stage_3', self.stages[2], self.stages[3], 2)
        # The classifier width follows the last stage (1024 for the default
        # widen_factor=4). It used to be hard-coded to 1024, which broke every
        # other widen_factor.
        self.classifier = nn.Linear(self.stages[3], num_classes)
        # kaiming_normal_ is the in-place replacement of the deprecated
        # init.kaiming_normal.
        init.kaiming_normal_(self.classifier.weight)

        # state_dict() tensors share storage with the parameters, so the
        # in-place writes below initialise the model itself.
        for key, tensor in self.state_dict().items():
            leaf = key.split('.')[-1]
            if leaf == 'weight':
                if 'conv' in key:
                    init.kaiming_normal_(tensor, mode='fan_out')
                if 'bn' in key:
                    tensor[...] = 1
            elif leaf == 'bias':
                tensor[...] = 0

    def block(self, name, in_channels, out_channels, pool_stride=2):
        """ Stack n bottleneck modules where n is inferred from the depth of the network.
        Args:
            name: string name of the current block.
            in_channels: number of input channels
            out_channels: number of output channels
            pool_stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
        Returns: a Module consisting of n sequential bottlenecks.
        """
        stage = nn.Sequential()
        for bottleneck in range(self.block_depth):
            name_ = '%s_bottleneck_%d' % (name, bottleneck)
            if bottleneck == 0:
                stage.add_module(name_, ResNeXtBottleneck(in_channels, out_channels, pool_stride, self.cardinality,
                                                          self.widen_factor))
            else:
                stage.add_module(name_,
                                 ResNeXtBottleneck(out_channels, out_channels, 1, self.cardinality, self.widen_factor))
        return stage

    def forward(self, x):
        """Return a (batch, num_classes) tensor for a batch of 3-channel images."""
        x = self.conv_1_3x3(x)
        x = F.relu(self.bn_1(x), inplace=True)
        x = self.stage_1(x)
        x = self.stage_2(x)
        x = self.stage_3(x)
        # Global pooling makes it suitable for 64 * 64 input (TinyImageNet).
        x = F.adaptive_avg_pool2d(x, output_size=(1, 1))
        # Flatten per sample instead of assuming 1024 channels.
        x = x.view(x.size(0), -1)
        return self.classifier(x)
24 | Args: 25 | x: input tensor 26 | groups: input branch number 27 | """ 28 | 29 | batch_size, channels, height, width = x.size() 30 | channels_per_group = int(channels / groups) 31 | 32 | x = x.view(batch_size, groups, channels_per_group, height, width) 33 | x = x.transpose(1, 2).contiguous() 34 | x = x.view(batch_size, -1, height, width) 35 | 36 | return x 37 | 38 | 39 | class ShuffleUnit(nn.Module): 40 | 41 | def __init__(self, in_channels, out_channels, stride): 42 | super().__init__() 43 | 44 | self.stride = stride 45 | self.in_channels = in_channels 46 | self.out_channels = out_channels 47 | 48 | if stride != 1 or in_channels != out_channels: 49 | self.residual = nn.Sequential( 50 | nn.Conv2d(in_channels, in_channels, 1), 51 | nn.BatchNorm2d(in_channels), 52 | nn.ReLU(inplace=True), 53 | nn.Conv2d(in_channels, in_channels, 3, stride=stride, padding=1, groups=in_channels), 54 | nn.BatchNorm2d(in_channels), 55 | nn.Conv2d(in_channels, int(out_channels / 2), 1), 56 | nn.BatchNorm2d(int(out_channels / 2)), 57 | nn.ReLU(inplace=True) 58 | ) 59 | 60 | self.shortcut = nn.Sequential( 61 | nn.Conv2d(in_channels, in_channels, 3, stride=stride, padding=1, groups=in_channels), 62 | nn.BatchNorm2d(in_channels), 63 | nn.Conv2d(in_channels, int(out_channels / 2), 1), 64 | nn.BatchNorm2d(int(out_channels / 2)), 65 | nn.ReLU(inplace=True) 66 | ) 67 | else: 68 | self.shortcut = nn.Sequential() 69 | 70 | in_channels = int(in_channels / 2) 71 | self.residual = nn.Sequential( 72 | nn.Conv2d(in_channels, in_channels, 1), 73 | nn.BatchNorm2d(in_channels), 74 | nn.ReLU(inplace=True), 75 | nn.Conv2d(in_channels, in_channels, 3, stride=stride, padding=1, groups=in_channels), 76 | nn.BatchNorm2d(in_channels), 77 | nn.Conv2d(in_channels, in_channels, 1), 78 | nn.BatchNorm2d(in_channels), 79 | nn.ReLU(inplace=True) 80 | ) 81 | 82 | def forward(self, x): 83 | 84 | if self.stride == 1 and self.out_channels == self.in_channels: 85 | shortcut, residual = channel_split(x, 
class ShuffleNetV2(nn.Module):
    """ShuffleNetV2-style classifier built from ShuffleUnit stages.

    Args:
        ratio: width multiplier; one of 0.5, 1, 1.5 or 2.
        class_num: size of the final linear layer.

    Raises:
        ValueError: if ``ratio`` is not one of the supported values.
    """

    def __init__(self, ratio=1, class_num=100):
        super().__init__()
        if ratio == 0.5:
            out_channels = [48, 96, 192, 1024]
        elif ratio == 1:
            out_channels = [116, 232, 464, 1024]
        elif ratio == 1.5:
            out_channels = [176, 352, 704, 1024]
        elif ratio == 2:
            out_channels = [244, 488, 976, 2048]
        else:
            # The exception used to be constructed without ``raise``, so an
            # unsupported ratio only failed later on the undefined
            # ``out_channels``.
            raise ValueError('unsupported ratio number')

        self.pre = nn.Sequential(
            nn.Conv2d(3, 24, 3, padding=1),
            nn.BatchNorm2d(24)
        )

        self.stage2 = self._make_stage(24, out_channels[0], 3)
        self.stage3 = self._make_stage(out_channels[0], out_channels[1], 7)
        self.stage4 = self._make_stage(out_channels[1], out_channels[2], 3)
        self.conv5 = nn.Sequential(
            nn.Conv2d(out_channels[2], out_channels[3], 1),
            nn.BatchNorm2d(out_channels[3]),
            nn.ReLU(inplace=True)
        )

        self.fc = nn.Linear(out_channels[3], class_num)

    def forward(self, x):
        """Return a (batch, class_num) tensor for a batch of 3-channel images."""
        x = self.pre(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

    def _make_stage(self, in_channels, out_channels, repeat):
        """Build one stage: a stride-2 unit followed by ``repeat`` stride-1 units."""
        layers = [ShuffleUnit(in_channels, out_channels, 2)]

        # A bounded range replaces ``while repeat:``, which never terminated
        # for a negative ``repeat``.
        for _ in range(repeat):
            layers.append(ShuffleUnit(out_channels, out_channels, 1))

        return nn.Sequential(*layers)
CifarResNeXt(cardinality=8, depth=29, num_classes=200) 159 | x = torch.randn(2,3,64,64) 160 | y = model(x) 161 | print(y.shape) 162 | 163 | x = torch.randn(2,3,32,32) 164 | y = model(x) 165 | print(y.shape) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | #torchvision==0.5.0 2 | numpy==1.17.2 3 | tqdm==4.36.1 4 | #torch==1.4.0 5 | scipy==1.3.1 6 | -------------------------------------------------------------------------------- /train_kd.py: -------------------------------------------------------------------------------- 1 | # train a student network distilling from teacher 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.optim import SGD, Adam 7 | 8 | 9 | from tqdm import tqdm 10 | import argparse 11 | import os 12 | import logging 13 | import numpy as np 14 | 15 | from utils.utils import RunningAverage, set_logger, Params 16 | from model import * 17 | from data_loader import fetch_dataloader 18 | 19 | 20 | # ************************** random seed ************************** 21 | seed = 0 22 | 23 | np.random.seed(seed) 24 | torch.manual_seed(seed) 25 | torch.cuda.manual_seed_all(seed) 26 | 27 | torch.backends.cudnn.deterministic = True 28 | torch.backends.cudnn.benchmark = False 29 | 30 | 31 | # ************************** parameters ************************** 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--save_path', default='experiments/CIFAR10/kd_normal/cnn', type=str) 34 | parser.add_argument('--teacher_resume', default=None, type=str, 35 | help='If you specify the teacher resume here, we will use it instead of parameters from json file') 36 | parser.add_argument('--resume', default=None, type=str) 37 | parser.add_argument('--gpu_id', default=[0], type=int, nargs='+', help='id(s) for CUDA_VISIBLE_DEVICES') 38 | args = parser.parse_args() 39 | 40 | device_ids = 
def evaluate(model, loss_fn, data_loader, params):
    """Evaluate ``model`` over every batch yielded by ``data_loader``.

    Accuracy and loss are averaged over samples. The previous implementation
    averaged the per-batch values, which over-weights a smaller final batch.

    Args:
        model: module mapping a data batch to per-class scores.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar
            tensor. The per-sample weighting assumes it is already averaged
            over the batch (as ``nn.CrossEntropyLoss()`` is by default).
        data_loader: iterable of ``(data_batch, labels_batch)`` pairs.
        params: object with a boolean ``cuda`` attribute; when true, batches
            are moved to the GPU before the forward pass.

    Returns:
        dict with ``'acc'`` (percentage of correct predictions) and
        ``'loss'`` (mean loss per sample).

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()

    correct = 0
    loss_sum = 0.0
    seen = 0

    with torch.no_grad():
        for data_batch, labels_batch in data_loader:
            if params.cuda:
                data_batch = data_batch.cuda()
                labels_batch = labels_batch.cuda()

            # compute model output
            output_batch = model(data_batch)
            loss = loss_fn(output_batch, labels_batch)

            batch_size = labels_batch.size(0)
            predictions = output_batch.argmax(dim=1)
            correct += (predictions == labels_batch).sum().item()
            # Undo the per-batch mean so every sample counts equally.
            loss_sum += loss.item() * batch_size
            seen += batch_size

    if seen == 0:
        raise ValueError('evaluate() received no samples from data_loader')

    return {'acc': 100.0 * correct / seen, 'loss': loss_sum / seen}
def adjust_learning_rate(opt, epoch, lr, params):
    """Apply the step learning-rate schedule for ``epoch``.

    When ``epoch`` is listed in ``params.schedule`` the rate is multiplied by
    ``params.gamma`` and written to every parameter group of ``opt``;
    otherwise nothing is touched.

    Returns:
        The learning rate in effect after the call.
    """
    if epoch not in params.schedule:
        return lr

    decayed = lr * params.gamma
    for group in opt.param_groups:
        group['lr'] = decayed
    return decayed
'tiny_imagenet': 197 | num_class = 200 198 | else: 199 | num_class = 10 200 | 201 | logging.info('Number of class: ' + str(num_class)) 202 | 203 | # ############################### Student Model ############################### 204 | logging.info('Create Student Model --- ' + params.model_name) 205 | 206 | # ResNet 18 / 34 / 50 **************************************** 207 | if params.model_name == 'resnet18': 208 | model = ResNet18(num_class=num_class) 209 | elif params.model_name == 'resnet34': 210 | model = ResNet34(num_class=num_class) 211 | elif params.model_name == 'resnet50': 212 | model = ResNet50(num_class=num_class) 213 | 214 | # PreResNet(ResNet for CIFAR-10) 20/32/56/110 *************** 215 | elif params.model_name.startswith('preresnet20'): 216 | model = PreResNet(depth=20, num_classes=num_class) 217 | elif params.model_name.startswith('preresnet32'): 218 | model = PreResNet(depth=32, num_classes=num_class) 219 | elif params.model_name.startswith('preresnet44'): 220 | model = PreResNet(depth=44, num_classes=num_class) 221 | elif params.model_name.startswith('preresnet56'): 222 | model = PreResNet(depth=56, num_classes=num_class) 223 | elif params.model_name.startswith('preresnet110'): 224 | model = PreResNet(depth=110, num_classes=num_class) 225 | 226 | 227 | # DenseNet ********************************************* 228 | elif params.model_name == 'densenet121': 229 | model = densenet121(num_class=num_class) 230 | elif params.model_name == 'densenet161': 231 | model = densenet161(num_class=num_class) 232 | elif params.model_name == 'densenet169': 233 | model = densenet169(num_class=num_class) 234 | 235 | # ResNeXt ********************************************* 236 | elif params.model_name == 'resnext29': 237 | model = CifarResNeXt(cardinality=8, depth=29, num_classes=num_class) 238 | 239 | elif params.model_name == 'mobilenetv2': 240 | model = MobileNetV2(class_num=num_class) 241 | 242 | elif params.model_name == 'shufflenetv2': 243 | model = 
shufflenetv2(class_num=num_class) 244 | 245 | # Basic neural network ******************************** 246 | elif params.model_name == 'net': 247 | model = Net(num_class, params) 248 | 249 | elif params.model_name == 'mlp': 250 | model = MLP(num_class=num_class) 251 | 252 | else: 253 | model = None 254 | print('Not support for model ' + str(params.model_name)) 255 | exit() 256 | 257 | # ############################### Teacher Model ############################### 258 | logging.info('Create Teacher Model --- ' + params.teacher_model) 259 | # ResNet 18 / 34 / 50 **************************************** 260 | if params.teacher_model == 'resnet18': 261 | teacher_model = ResNet18(num_class=num_class) 262 | elif params.teacher_model == 'resnet34': 263 | teacher_model = ResNet34(num_class=num_class) 264 | elif params.teacher_model == 'resnet50': 265 | teacher_model = ResNet50(num_class=num_class) 266 | 267 | # PreResNet(ResNet for CIFAR-10) 20/32/56/110 *************** 268 | elif params.teacher_model.startswith('preresnet20'): 269 | teacher_model = PreResNet(depth=20) 270 | elif params.teacher_model.startswith('preresnet32'): 271 | teacher_model = PreResNet(depth=32) 272 | elif params.teacher_model.startswith('preresnet56'): 273 | teacher_model = PreResNet(depth=56) 274 | elif params.teacher_model.startswith('preresnet110'): 275 | teacher_model = PreResNet(depth=110) 276 | 277 | # DenseNet ********************************************* 278 | elif params.teacher_model == 'densenet121': 279 | teacher_model = densenet121(num_class=num_class) 280 | elif params.teacher_model == 'densenet161': 281 | teacher_model = densenet161(num_class=num_class) 282 | elif params.teacher_model == 'densenet169': 283 | teacher_model = densenet169(num_class=num_class) 284 | 285 | # ResNeXt ********************************************* 286 | elif params.teacher_model == 'resnext29': 287 | teacher_model = CifarResNeXt(cardinality=8, depth=29, num_classes=num_class) 288 | 289 | elif 
params.teacher_model == 'mobilenetv2': 290 | teacher_model = MobileNetV2(class_num=num_class) 291 | 292 | elif params.teacher_model == 'shufflenetv2': 293 | teacher_model = shufflenetv2(class_num=num_class) 294 | 295 | elif params.teacher_model == 'net': 296 | teacher_model = Net(num_class, args) 297 | 298 | elif params.teacher_model == 'mlp': 299 | teacher_model = MLP(num_class=num_class) 300 | 301 | else: 302 | teacher_model = None 303 | exit() 304 | 305 | if params.cuda: 306 | model = model.cuda() 307 | teacher_model = teacher_model.cuda() 308 | 309 | if len(args.gpu_id) > 1: 310 | model = nn.DataParallel(model, device_ids=device_ids) 311 | teacher_model = nn.DataParallel(teacher_model, device_ids=device_ids) 312 | 313 | # checkpoint ******************************** 314 | if args.resume: 315 | logging.info('- Load checkpoint model from {}'.format(args.resume)) 316 | checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage) 317 | model.load_state_dict(checkpoint['state_dict']) 318 | else: 319 | logging.info('- Train from scratch ') 320 | 321 | # load teacher model 322 | if args.teacher_resume: 323 | teacher_resume = args.teacher_resume 324 | logging.info('------ Teacher Resume from system parameters!') 325 | else: 326 | teacher_resume = params.teacher_resume 327 | logging.info('- Load Trained teacher model from {}'.format(teacher_resume)) 328 | checkpoint = torch.load(teacher_resume) 329 | teacher_model.load_state_dict(checkpoint['state_dict']) 330 | 331 | # ############################### Optimizer ############################### 332 | if params.model_name == 'net' or params.model_name == 'mlp': 333 | optimizer = Adam(model.parameters(), lr=params.learning_rate) 334 | logging.info('Optimizer: Adam') 335 | else: 336 | optimizer = SGD(model.parameters(), lr=params.learning_rate, momentum=0.9, weight_decay=5e-4) 337 | logging.info('Optimizer: SGD') 338 | 339 | # ************************** LOSS ************************** 340 | criterion = 
loss_fn_kd 341 | 342 | # ************************** Teacher ACC ************************** 343 | logging.info("- Teacher Model Evaluation ....") 344 | val_metrics = evaluate(teacher_model, nn.CrossEntropyLoss(), devloader, params) # {'acc':acc, 'loss':loss} 345 | metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in val_metrics.items()) 346 | logging.info("- Teacher Model Eval metrics : " + metrics_string) 347 | 348 | # ************************** train and evaluate ************************** 349 | train_and_eval_kd(model, teacher_model, optimizer, criterion, trainloader, devloader, params) 350 | 351 | 352 | -------------------------------------------------------------------------------- /train_nasty.py: -------------------------------------------------------------------------------- 1 | # train a nasty teacher with an adversarial network 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.optim import SGD, Adam 7 | 8 | 9 | from tqdm import tqdm 10 | import argparse 11 | import os 12 | import logging 13 | import numpy as np 14 | 15 | from utils.utils import RunningAverage, set_logger, Params 16 | from model import * 17 | from data_loader import fetch_dataloader 18 | 19 | # ************************** random seed ************************** 20 | seed = 0 21 | 22 | np.random.seed(seed) 23 | torch.manual_seed(seed) 24 | torch.cuda.manual_seed_all(seed) 25 | 26 | torch.backends.cudnn.deterministic = True 27 | torch.backends.cudnn.benchmark = False 28 | 29 | 30 | # ************************** parameters ************************** 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--save_path', default='experiments/CIFAR10/adversarial_teacher/resnet18_self', type=str) 33 | parser.add_argument('--resume', default=None, type=str) 34 | parser.add_argument('--gpu_id', default=[0], type=int, nargs='+', help='id(s) for CUDA_VISIBLE_DEVICES') 35 | args = parser.parse_args() 36 | 37 | device_ids = args.gpu_id 
# ************************** training function **************************
def train_epoch_kd_adv(model, model_ad, optim, data_loader, epoch, params):
    """Train the (nasty) teacher for one epoch against a frozen adversarial network.

    The objective is ``CE(teacher, labels) - params.weight * KL + 100.0`` where
    the KL term is ``nn.KLDivLoss(reduction='batchmean')`` evaluated on the
    temperature-softened outputs (adversary log-probabilities as input, teacher
    probabilities as target) and scaled by ``T * T``. Minimising the total
    therefore keeps the teacher accurate while maximising the KL term.

    Args:
        model: teacher network being trained.
        model_ad: adversarial network; put in eval mode and run without gradients.
        optim: optimizer that updates ``model``.
        data_loader: iterable of ``(inputs, labels)`` batches.
        epoch: unused; kept so existing callers do not need to change.
        params: hyper-parameters; reads ``cuda``, ``temperature`` and ``weight``.

    Returns:
        Tuple ``(total_loss, teacher_ce_loss, adversarial_kl_loss)`` of running
        averages over the epoch.
    """
    model.train()
    model_ad.eval()
    tch_loss_avg = RunningAverage()
    ad_loss_avg = RunningAverage()
    loss_avg = RunningAverage()

    # Loss modules are stateless, so build them once instead of once per batch.
    ce_loss = nn.CrossEntropyLoss()
    kl_div = nn.KLDivLoss(reduction='batchmean')
    T = params.temperature

    with tqdm(total=len(data_loader)) as t:  # Use tqdm for progress bar
        for train_batch, labels_batch in data_loader:
            if params.cuda:
                train_batch = train_batch.cuda()
                labels_batch = labels_batch.cuda()

            # teacher forward pass (logits, no softmax) and its supervised loss
            output_tch = model(train_batch)
            tch_loss = ce_loss(output_tch, labels_batch)

            # adversary forward pass; no gradient flows into the adversary
            with torch.no_grad():
                output_stu = model_ad(train_batch)
            output_stu = output_stu.detach()

            # adversarial loss: wish to max this item
            adv_loss = kl_div(F.log_softmax(output_stu / T, dim=1),
                              F.softmax(output_tch / T, dim=1)) * (T * T)

            # total loss; the constant only keeps the logged value positive
            loss = tch_loss - params.weight * adv_loss + 100.0

            optim.zero_grad()
            loss.backward()
            optim.step()

            # update the average loss
            loss_avg.update(loss.item())
            tch_loss_avg.update(tch_loss.item())
            ad_loss_avg.update(adv_loss.item())

            # tqdm setting
            t.set_postfix(loss='{:05.3f}'.format(loss_avg()))
            t.update()
    return loss_avg(), tch_loss_avg(), ad_loss_avg()


def evaluate(model, loss_fn, data_loader, params):
    """Compute accuracy (in percent) and mean loss of ``model`` over ``data_loader``.

    Both metrics are weighted by batch size, so a smaller final batch does not
    skew the result the way an unweighted mean of per-batch values does.

    Args:
        model: network to evaluate; switched to eval mode.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
            Assumed to average over the batch (the callers in this file pass
            ``nn.CrossEntropyLoss()``).
        data_loader: iterable of ``(inputs, labels)`` batches.
        params: object whose ``cuda`` flag selects moving batches to the GPU.

    Returns:
        ``{'acc': <percent correct>, 'loss': <mean loss per sample>}``.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()
    n_samples = 0
    n_correct = 0
    loss_sum = 0.0

    with torch.no_grad():
        for data_batch, labels_batch in data_loader:
            if params.cuda:
                data_batch = data_batch.cuda()
                labels_batch = labels_batch.cuda()

            # compute model output
            output_batch = model(data_batch)
            loss = loss_fn(output_batch, labels_batch)

            # move to cpu / numpy and count correct predictions
            predictions = np.argmax(output_batch.cpu().numpy(), axis=1)
            labels_np = labels_batch.cpu().numpy()
            batch_size = int(labels_np.shape[0])

            n_correct += int(np.sum(predictions == labels_np))
            loss_sum += loss.item() * batch_size  # undo the per-batch mean
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError('evaluate() received an empty data_loader')

    return {'acc': 100.0 * n_correct / n_samples, 'loss': loss_sum / n_samples}


def train_and_eval_kd_adv(model, model_ad, optim, train_loader, dev_loader, params):
    """Train the nasty teacher for ``params.num_epochs`` epochs with per-epoch validation.

    Every epoch applies the learning-rate schedule, trains with
    ``train_epoch_kd_adv``, evaluates on ``dev_loader`` with cross-entropy and
    writes ``last_model.tar`` under ``args.save_path``. ``best_model.tar`` is
    rewritten whenever validation accuracy ties or beats the best so far.

    Args:
        model: teacher network being trained.
        model_ad: frozen adversarial network.
        optim: optimizer that updates ``model``.
        train_loader: iterable of training batches.
        dev_loader: iterable of validation batches.
        params: hyper-parameters (``learning_rate``, ``num_epochs``, plus what
            the called helpers read).

    Note:
        Reads the module-level ``args`` (parsed command line) for ``save_path``.
    """
    best_val_acc = -1
    best_epo = -1
    lr = params.learning_rate
    last_path = os.path.join(args.save_path, 'last_model.tar')
    best_path = os.path.join(args.save_path, 'best_model.tar')

    for epoch in range(params.num_epochs):
        lr = adjust_learning_rate(optim, epoch, lr, params)
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))
        logging.info('Learning Rate {}'.format(lr))

        # one full pass over the training set
        train_loss, train_tloss, train_aloss = train_epoch_kd_adv(model, model_ad, optim,
                                                                  train_loader, epoch, params)
        logging.info("- Train loss : {:05.3f}".format(train_loss))
        logging.info("- Train teacher loss : {:05.3f}".format(train_tloss))
        logging.info("- Train adversarial loss : {:05.3f}".format(train_aloss))

        # evaluate for one epoch on the validation set -> {'acc': acc, 'loss': loss}
        val_metrics = evaluate(model, nn.CrossEntropyLoss(), dev_loader, params)
        metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in val_metrics.items())
        logging.info("- Eval metrics : " + metrics_string)

        # The same snapshot serves both files: nothing is trained between the two saves.
        snapshot = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optim_dict': optim.state_dict(),
        }
        torch.save(snapshot, last_path)

        val_acc = val_metrics['acc']
        if val_acc >= best_val_acc:  # ties favour the most recent epoch
            best_epo = epoch + 1
            best_val_acc = val_acc
            logging.info('- New best model ')
            torch.save(snapshot, best_path)

        logging.info('- So far best epoch: {}, best acc: {:05.3f}'.format(best_epo, best_val_acc))


def adjust_learning_rate(opt, epoch, lr, params):
    """Apply the step schedule and return the learning rate for ``epoch``.

    When ``epoch`` is listed in ``params.schedule`` the rate is multiplied by
    ``params.gamma``. The resulting rate is written into every parameter group
    of ``opt`` so the optimizer always matches the returned value.

    Args:
        opt: optimizer exposing ``param_groups`` (a list of dicts with an ``'lr'`` key).
        epoch: zero-based index of the epoch about to run.
        lr: learning rate used for the previous epoch.
        params: object providing ``schedule`` and ``gamma``.

    Returns:
        The learning rate to use for ``epoch``.
    """
    if epoch in params.schedule:
        lr = lr * params.gamma
    for param_group in opt.param_groups:
        param_group['lr'] = lr
    return lr


def _build_model(name, num_class, params):
    """Instantiate the network called ``name`` with ``num_class`` output classes.

    Used for both the trained model and the adversarial model so that the two
    are always constructed with the same arguments. Names beginning with
    ``preresnet<depth>`` are matched by prefix; every other name must match
    exactly. An unsupported name prints a message and exits with status 1.
    """
    # Lambdas keep construction lazy: only the requested network is built.
    factories = {
        'resnet18': lambda: ResNet18(num_class=num_class),
        'resnet34': lambda: ResNet34(num_class=num_class),
        'resnet50': lambda: ResNet50(num_class=num_class),
        'densenet121': lambda: densenet121(num_class=num_class),
        'densenet161': lambda: densenet161(num_class=num_class),
        'densenet169': lambda: densenet169(num_class=num_class),
        'resnext29': lambda: CifarResNeXt(cardinality=8, depth=29, num_classes=num_class),
        'mobilenetv2': lambda: MobileNetV2(class_num=num_class),
        'shufflenetv2': lambda: shufflenetv2(class_num=num_class),
        'net': lambda: Net(num_class, params),
        'mlp': lambda: MLP(num_class=num_class),
    }
    factory = factories.get(name)
    if factory is not None:
        return factory()

    # PreResNet (ResNet for CIFAR) 20/32/44/56/110, selected by name prefix.
    for depth in (20, 32, 44, 56, 110):
        if name.startswith('preresnet{}'.format(depth)):
            return PreResNet(depth=depth, num_classes=num_class)

    print('Not support for model ' + str(name))
    raise SystemExit(1)  # non-zero status: an unsupported model is a failure


if __name__ == "__main__":
    # ************************** set log **************************
    set_logger(os.path.join(args.save_path, 'training.log'))

    # ************************** load the parameters from json file **************************
    json_path = os.path.join(args.save_path, 'params.json')
    assert os.path.isfile(json_path), "No json configuration file found at {}".format(json_path)
    params = Params(json_path)

    params.cuda = torch.cuda.is_available()  # use GPU if available

    # Select the default GPU only when one exists, so the script can also
    # start on a CPU-only machine.
    if params.cuda:
        torch.cuda.set_device(device_ids[0])

    for k, v in params.__dict__.items():
        logging.info('{}:{}'.format(k, v))

    # ************************** dataset **************************
    trainloader = fetch_dataloader('train', params)
    devloader = fetch_dataloader('dev', params)

    # ************************** model **************************
    # Unknown dataset names fall back to 10 classes.
    num_class = {'cifar10': 10, 'cifar100': 100, 'tiny_imagenet': 200}.get(params.dataset, 10)
    logging.info('Number of class: ' + str(num_class))

    logging.info('Create Model --- ' + params.model_name)
    model = _build_model(params.model_name, num_class, params)

    logging.info('Create Adversarial Model --- ' + params.adversarial_model)
    adversarial_model = _build_model(params.adversarial_model, num_class, params)

    if params.cuda:
        model = model.cuda()
        adversarial_model = adversarial_model.cuda()

    if len(args.gpu_id) > 1:
        model = nn.DataParallel(model, device_ids=device_ids)
        adversarial_model = nn.DataParallel(adversarial_model, device_ids=device_ids)

    # ************************** checkpoint **************************
    if args.resume:
        logging.info('- Load checkpoint from {}'.format(args.resume))
        checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
    else:
        logging.info('- Train from scratch ')

    # ************************** load trained adversarial model **************************
    logging.info('- Load Trained adversarial model from {}'.format(params.adversarial_resume))
    # Deserialize onto CPU first so the checkpoint loads regardless of the
    # device it was saved from.
    checkpoint = torch.load(params.adversarial_resume, map_location=lambda storage, loc: storage)
    adversarial_model.load_state_dict(checkpoint['state_dict'])

    # ************************** optimizer **************************
    if params.model_name == 'net' or params.model_name == 'mlp':
        optimizer = Adam(model.parameters(), lr=params.learning_rate)
        logging.info('Optimizer: Adam')
    else:
        optimizer = SGD(model.parameters(), lr=params.learning_rate, momentum=0.9, weight_decay=5e-4)
        logging.info('Optimizer: SGD')

    # ************************** train and evaluate **************************
    train_and_eval_kd_adv(model, adversarial_model, optimizer, trainloader, devloader, params)
############################### Optimizer ############################### 330 | if params.model_name == 'net' or params.model_name == 'mlp': 331 | optimizer = Adam(model.parameters(), lr=params.learning_rate) 332 | logging.info('Optimizer: Adam') 333 | else: 334 | optimizer = SGD(model.parameters(), lr=params.learning_rate, momentum=0.9, weight_decay=5e-4) 335 | logging.info('Optimizer: SGD') 336 | 337 | # ************************** train and evaluate ************************** 338 | train_and_eval_kd_adv(model, adversarial_model, optimizer, trainloader, devloader, params) 339 | 340 | -------------------------------------------------------------------------------- /train_scratch.py: -------------------------------------------------------------------------------- 1 | # train a baseline model from scratch 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.optim import SGD, Adam 6 | 7 | from tqdm import tqdm 8 | import argparse 9 | import os 10 | import logging 11 | import numpy as np 12 | 13 | from utils.utils import RunningAverage, set_logger, Params 14 | from model import * 15 | from data_loader import fetch_dataloader 16 | 17 | 18 | # ************************** random seed ************************** 19 | seed = 0 20 | 21 | np.random.seed(seed) 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed_all(seed) 24 | 25 | torch.backends.cudnn.deterministic = True 26 | torch.backends.cudnn.benchmark = False 27 | 28 | # ************************** parameters ************************** 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--save_path', default='experiments/CIFAR10/baseline/resnet18', type=str) 31 | parser.add_argument('--resume', default=None, type=str) 32 | parser.add_argument('--gpu_id', default=[0], type=int, nargs='+', help='id(s) for CUDA_VISIBLE_DEVICES') 33 | args = parser.parse_args() 34 | 35 | device_ids = args.gpu_id 36 | torch.cuda.set_device(device_ids[0]) 37 | 38 | 39 | # ************************** training function 
def train_epoch(model, optim, loss_fn, data_loader, params):
    """Train ``model`` for one pass over ``data_loader`` and return the mean loss.

    Args:
        model: network being trained; switched to train mode.
        optim: optimizer that updates ``model``.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        data_loader: iterable of ``(inputs, labels)`` batches.
        params: object whose ``cuda`` flag selects moving batches to the GPU.

    Returns:
        Running average of the per-batch loss over the epoch.
    """
    model.train()
    loss_avg = RunningAverage()

    with tqdm(total=len(data_loader)) as t:  # Use tqdm for progress bar
        for train_batch, labels_batch in data_loader:
            if params.cuda:
                train_batch = train_batch.cuda()
                labels_batch = labels_batch.cuda()

            # compute model output (logits, no softmax) and loss
            output_batch = model(train_batch)
            loss = loss_fn(output_batch, labels_batch)

            optim.zero_grad()
            loss.backward()
            optim.step()

            # update the average loss
            loss_avg.update(loss.item())

            # tqdm setting
            t.set_postfix(loss='{:05.3f}'.format(loss_avg()))
            t.update()
    return loss_avg()


def evaluate(model, loss_fn, data_loader, params):
    """Compute accuracy (in percent) and mean loss of ``model`` over ``data_loader``.

    Both metrics are weighted by batch size, so a smaller final batch does not
    skew the result the way an unweighted mean of per-batch values does.

    Args:
        model: network to evaluate; switched to eval mode.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
            Assumed to average over the batch (this file passes
            ``nn.CrossEntropyLoss()``).
        data_loader: iterable of ``(inputs, labels)`` batches.
        params: object whose ``cuda`` flag selects moving batches to the GPU.

    Returns:
        ``{'acc': <percent correct>, 'loss': <mean loss per sample>}``.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()
    n_samples = 0
    n_correct = 0
    loss_sum = 0.0

    with torch.no_grad():
        for data_batch, labels_batch in data_loader:
            if params.cuda:
                data_batch = data_batch.cuda()
                labels_batch = labels_batch.cuda()

            # compute model output
            output_batch = model(data_batch)
            loss = loss_fn(output_batch, labels_batch)

            # move to cpu / numpy and count correct predictions
            predictions = np.argmax(output_batch.cpu().numpy(), axis=1)
            labels_np = labels_batch.cpu().numpy()
            batch_size = int(labels_np.shape[0])

            n_correct += int(np.sum(predictions == labels_np))
            loss_sum += loss.item() * batch_size  # undo the per-batch mean
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError('evaluate() received an empty data_loader')

    return {'acc': 100.0 * n_correct / n_samples, 'loss': loss_sum / n_samples}


def train_and_eval(model, optim, loss_fn, train_loader, dev_loader, params):
    """Train ``model`` for ``params.num_epochs`` epochs with per-epoch validation.

    Every epoch applies the learning-rate schedule, trains with ``train_epoch``,
    evaluates on ``dev_loader`` with the same ``loss_fn`` and writes
    ``last_model.tar`` under ``args.save_path``. ``best_model.tar`` is rewritten
    whenever validation accuracy ties or beats the best so far.

    Args:
        model: network being trained.
        optim: optimizer that updates ``model``.
        loss_fn: loss used for both training and validation.
        train_loader: iterable of training batches.
        dev_loader: iterable of validation batches.
        params: hyper-parameters (``learning_rate``, ``num_epochs``, plus what
            the called helpers read).

    Note:
        Reads the module-level ``args`` (parsed command line) for ``save_path``.
    """
    best_val_acc = -1
    best_epo = -1
    lr = params.learning_rate
    last_path = os.path.join(args.save_path, 'last_model.tar')
    best_path = os.path.join(args.save_path, 'best_model.tar')

    for epoch in range(params.num_epochs):
        # LR schedule
        lr = adjust_learning_rate(optim, epoch, lr, params)

        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))
        logging.info('Learning Rate {}'.format(lr))

        # one full pass over the training set
        train_loss = train_epoch(model, optim, loss_fn, train_loader, params)
        logging.info("- Train loss : {:05.3f}".format(train_loss))

        # evaluate for one epoch on the validation set -> {'acc': acc, 'loss': loss}
        val_metrics = evaluate(model, loss_fn, dev_loader, params)
        metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in val_metrics.items())
        logging.info("- Eval metrics : " + metrics_string)

        # The same snapshot serves both files: nothing is trained between the two saves.
        snapshot = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optim_dict': optim.state_dict(),
        }
        torch.save(snapshot, last_path)

        val_acc = val_metrics['acc']
        if val_acc >= best_val_acc:  # ties favour the most recent epoch
            best_epo = epoch + 1
            best_val_acc = val_acc
            logging.info('- New best model ')
            torch.save(snapshot, best_path)

        logging.info('- So far best epoch: {}, best acc: {:05.3f}'.format(best_epo, best_val_acc))


def adjust_learning_rate(opt, epoch, lr, params):
    """Apply the step schedule and return the learning rate for ``epoch``.

    When ``epoch`` is listed in ``params.schedule`` the rate is multiplied by
    ``params.gamma``. The resulting rate is written into every parameter group
    of ``opt`` so the optimizer always matches the returned value.

    Args:
        opt: optimizer exposing ``param_groups`` (a list of dicts with an ``'lr'`` key).
        epoch: zero-based index of the epoch about to run.
        lr: learning rate used for the previous epoch.
        params: object providing ``schedule`` and ``gamma``.

    Returns:
        The learning rate to use for ``epoch``.
    """
    if epoch in params.schedule:
        lr = lr * params.gamma
    for param_group in opt.param_groups:
        param_group['lr'] = lr
    return lr


def _build_model(name, num_class, params):
    """Instantiate the network called ``name`` with ``num_class`` output classes.

    Names beginning with ``preresnet<depth>`` are matched by prefix; every
    other name must match exactly. An unsupported name prints a message and
    exits with status 1.
    """
    # Lambdas keep construction lazy: only the requested network is built.
    factories = {
        'resnet18': lambda: ResNet18(num_class=num_class),
        'resnet34': lambda: ResNet34(num_class=num_class),
        'resnet50': lambda: ResNet50(num_class=num_class),
        'densenet121': lambda: densenet121(num_class=num_class),
        'densenet161': lambda: densenet161(num_class=num_class),
        'densenet169': lambda: densenet169(num_class=num_class),
        'resnext29': lambda: CifarResNeXt(cardinality=8, depth=29, num_classes=num_class),
        'mobilenetv2': lambda: MobileNetV2(class_num=num_class),
        'shufflenetv2': lambda: shufflenetv2(class_num=num_class),
        'net': lambda: Net(num_class, params),
        'mlp': lambda: MLP(num_class=num_class),
    }
    factory = factories.get(name)
    if factory is not None:
        return factory()

    # PreResNet (ResNet for CIFAR) 20/32/44/56/110, selected by name prefix.
    for depth in (20, 32, 44, 56, 110):
        if name.startswith('preresnet{}'.format(depth)):
            return PreResNet(depth=depth, num_classes=num_class)

    print('Not support for model ' + str(name))
    raise SystemExit(1)  # non-zero status: an unsupported model is a failure


if __name__ == "__main__":
    # ************************** set log **************************
    set_logger(os.path.join(args.save_path, 'training.log'))

    # ************************** load the parameters from json file **************************
    json_path = os.path.join(args.save_path, 'params.json')
    assert os.path.isfile(json_path), "No json configuration file found at {}".format(json_path)
    params = Params(json_path)

    params.cuda = torch.cuda.is_available()  # use GPU if available

    for k, v in params.__dict__.items():
        logging.info('{}:{}'.format(k, v))

    # ************************** dataset **************************
    trainloader = fetch_dataloader('train', params)
    devloader = fetch_dataloader('dev', params)

    # ************************** model **************************
    # Unknown dataset names fall back to 10 classes.
    num_class = {'cifar10': 10, 'cifar100': 100, 'tiny_imagenet': 200}.get(params.dataset, 10)
    logging.info('Number of class: ' + str(num_class))

    logging.info('Create Model --- ' + params.model_name)
    model = _build_model(params.model_name, num_class, params)

    if params.cuda:
        model = model.cuda()

    if len(args.gpu_id) > 1:
        model = nn.DataParallel(model, device_ids=device_ids)

    # ************************** checkpoint **************************
    if args.resume:
        logging.info('- Load checkpoint model from {}'.format(args.resume))
        checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
    else:
        logging.info('- Train from scratch ')

    # ************************** optimizer **************************
    if params.model_name == 'net' or params.model_name == 'mlp':
        optimizer = Adam(model.parameters(), lr=params.learning_rate)
        logging.info('Optimizer: Adam')
    else:
        optimizer = SGD(model.parameters(), lr=params.learning_rate, momentum=0.9, weight_decay=5e-4)
        logging.info('Optimizer: SGD')

    # ************************** loss **************************
    criterion = nn.CrossEntropyLoss()

    # ************************** train and evaluate **************************
    train_and_eval(model, optimizer, criterion, trainloader, devloader, params)
== 'mlp': 245 | optimizer = Adam(model.parameters(), lr=params.learning_rate) 246 | logging.info('Optimizer: Adam') 247 | else: 248 | optimizer = SGD(model.parameters(), lr=params.learning_rate, momentum=0.9, weight_decay=5e-4) 249 | logging.info('Optimizer: SGD') 250 | 251 | # ************************** LOSS ************************** 252 | criterion = nn.CrossEntropyLoss() 253 | 254 | # ################################# train and evaluate ################################# 255 | train_and_eval(model, optimizer, criterion, trainloader, devloader, params) 256 | 257 | 258 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tensorboard logger code referenced from: 3 | https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/04-utils/ 4 | Other helper functions: 5 | https://github.com/cs230-stanford/cs230-stanford.github.io 6 | """ 7 | 8 | import json 9 | import logging 10 | import os 11 | import shutil 12 | import torch 13 | 14 | import numpy as np 15 | import scipy.misc 16 | 17 | try: 18 | from StringIO import StringIO # Python 2.7 19 | except ImportError: 20 | from io import BytesIO # Python 3.x 21 | 22 | 23 | class Params(): 24 | """Class that loads hyperparameters from a json file. 
class Params:
    """Hyper-parameter container populated from a JSON file.

    Every top-level key of the JSON object becomes an instance attribute.

    Example:
    ```
    params = Params(json_path)
    print(params.learning_rate)
    params.learning_rate = 0.5  # change the value of learning_rate in params
    ```
    """

    def __init__(self, json_path):
        """Load the attributes from the JSON file at ``json_path``."""
        self.update(json_path)

    def save(self, json_path):
        """Write all current attributes to ``json_path`` as indented JSON."""
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(self.__dict__, f, indent=4)

    def update(self, json_path):
        """Load parameters from ``json_path``, overwriting attributes with the same name."""
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(json_path, encoding='utf-8') as f:
            params = json.load(f)
        self.__dict__.update(params)

    @property
    def dict(self):
        """Give dict-like access to the instance, e.g. ``params.dict['learning_rate']``."""
        return self.__dict__


class RunningAverage:
    """Maintain the running average of a quantity.

    Example:
    ```
    loss_avg = RunningAverage()
    loss_avg.update(2)
    loss_avg.update(4)
    loss_avg() == 3
    ```
    """

    def __init__(self):
        self.steps = 0  # number of values added so far
        self.total = 0  # sum of the values added so far

    def update(self, val):
        """Add ``val`` to the running total."""
        self.total += val
        self.steps += 1

    def __call__(self):
        """Return the mean of all values added so far.

        Raises:
            ZeroDivisionError: if ``update`` has not been called yet.
        """
        return self.total / float(self.steps)


def set_logger(log_path):
    """Set the root logger to log INFO messages to the terminal and to ``log_path``.

    A file handler (timestamped format) and a console handler (message only)
    are attached only when the root logger has no handlers yet, so calling
    this function repeatedly does not duplicate output.

    Example:
    ```
    logging.info("Starting training...")
    ```

    Args:
        log_path: (string) where to log
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    if not logger.handlers:
        # Logging to a file
        file_handler = logging.FileHandler(log_path)
        file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
        logger.addHandler(file_handler)

        # Logging to console
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(stream_handler)