├── README.md ├── datasets.py ├── models ├── __init__.py ├── densenet.py ├── densenet3.py └── resnet.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Regularizing Class-wise Predictions via Self-knowledge Distillation (CS-KD) 2 | 3 | PyTorch implementation of ["Regularizing Class-wise Predictions via Self-knowledge Distillation"](https://arxiv.org/abs/2003.13964) (CVPR 2020). 4 | 5 | ## Requirements 6 | 7 | `torch==1.2.0`, `torchvision==0.4.0` 8 | 9 | ## Run experiments 10 | 11 | train cifar100 on resnet with class-wise regularization losses 12 | 13 | `python3 train.py --sgpu 0 --lr 0.1 --epoch 200 --model CIFAR_ResNet18 --name test_cifar --decay 1e-4 --dataset cifar100 --dataroot ~/data/ -cls --lamda 1` 14 | 15 | train fine-grained dataset on resnet with class-wise regularization losses 16 | 17 | `python3 train.py --sgpu 0 --lr 0.1 --epoch 200 --model resnet18 --name test_cub200 --batch-size 32 --decay 1e-4 --dataset CUB200 --dataroot ~/data/ -cls --lamda 3` 18 | 19 | ## Citation 20 | If you use this code for your research, please cite our papers. 
21 | ``` 22 | @InProceedings{Yun_2020_CVPR, 23 | author = {Yun, Sukmin and Park, Jongjin and Lee, Kimin and Shin, Jinwoo}, 24 | title = {Regularizing Class-Wise Predictions via Self-Knowledge Distillation}, 25 | booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 26 | month = {June}, 27 | year = {2020} 28 | } 29 | ``` 30 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import csv, torchvision, numpy as np, random, os 2 | from PIL import Image 3 | 4 | from torch.utils.data import Sampler, Dataset, DataLoader, BatchSampler, SequentialSampler, RandomSampler, Subset 5 | from torchvision import transforms, datasets 6 | from collections import defaultdict 7 | 8 | 9 | class PairBatchSampler(Sampler): 10 | def __init__(self, dataset, batch_size, num_iterations=None): 11 | self.dataset = dataset 12 | self.batch_size = batch_size 13 | self.num_iterations = num_iterations 14 | 15 | def __iter__(self): 16 | indices = list(range(len(self.dataset))) 17 | random.shuffle(indices) 18 | for k in range(len(self)): 19 | if self.num_iterations is None: 20 | offset = k*self.batch_size 21 | batch_indices = indices[offset:offset+self.batch_size] 22 | else: 23 | batch_indices = random.sample(range(len(self.dataset)), 24 | self.batch_size) 25 | 26 | pair_indices = [] 27 | for idx in batch_indices: 28 | y = self.dataset.get_class(idx) 29 | pair_indices.append(random.choice(self.dataset.classwise_indices[y])) 30 | 31 | yield batch_indices + pair_indices 32 | 33 | def __len__(self): 34 | if self.num_iterations is None: 35 | return (len(self.dataset)+self.batch_size-1) // self.batch_size 36 | else: 37 | return self.num_iterations 38 | 39 | 40 | class DatasetWrapper(Dataset): 41 | # Additinoal attributes 42 | # - indices 43 | # - classwise_indices 44 | # - num_classes 45 | # - get_class 46 | 47 | def __init__(self, dataset, 
indices=None): 48 | self.base_dataset = dataset 49 | if indices is None: 50 | self.indices = list(range(len(dataset))) 51 | else: 52 | self.indices = indices 53 | 54 | # torchvision 0.2.0 compatibility 55 | if torchvision.__version__.startswith('0.2'): 56 | if isinstance(self.base_dataset, datasets.ImageFolder): 57 | self.base_dataset.targets = [s[1] for s in self.base_dataset.imgs] 58 | else: 59 | if self.base_dataset.train: 60 | self.base_dataset.targets = self.base_dataset.train_labels 61 | else: 62 | self.base_dataset.targets = self.base_dataset.test_labels 63 | 64 | self.classwise_indices = defaultdict(list) 65 | for i in range(len(self)): 66 | y = self.base_dataset.targets[self.indices[i]] 67 | self.classwise_indices[y].append(i) 68 | self.num_classes = max(self.classwise_indices.keys())+1 69 | 70 | def __getitem__(self, i): 71 | return self.base_dataset[self.indices[i]] 72 | 73 | def __len__(self): 74 | return len(self.indices) 75 | 76 | def get_class(self, i): 77 | return self.base_dataset.targets[self.indices[i]] 78 | 79 | 80 | class ConcatWrapper(Dataset): # TODO: Naming 81 | @staticmethod 82 | def cumsum(sequence): 83 | r, s = [], 0 84 | for e in sequence: 85 | l = len(e) 86 | r.append(l + s) 87 | s += l 88 | return r 89 | 90 | @staticmethod 91 | def numcls(sequence): 92 | s = 0 93 | for e in sequence: 94 | l = e.num_classes 95 | s += l 96 | return s 97 | 98 | @staticmethod 99 | def clsidx(sequence): 100 | r, s, n = defaultdict(list), 0, 0 101 | for e in sequence: 102 | l = e.classwise_indices 103 | for c in range(s, s + e.num_classes): 104 | t = np.asarray(l[c-s]) + n 105 | r[c] = t.tolist() 106 | s += e.num_classes 107 | n += len(e) 108 | return r 109 | 110 | def __init__(self, datasets): 111 | super(ConcatWrapper, self).__init__() 112 | assert len(datasets) > 0, 'datasets should not be an empty iterable' 113 | self.datasets = list(datasets) 114 | # for d in self.datasets: 115 | # assert not isinstance(d, IterableDataset), "ConcatDataset does not 
support IterableDataset" 116 | self.cumulative_sizes = self.cumsum(self.datasets) 117 | 118 | self.num_classes = self.numcls(self.datasets) 119 | self.classwise_indices = self.clsidx(self.datasets) 120 | 121 | def __len__(self): 122 | return self.cumulative_sizes[-1] 123 | 124 | def __getitem__(self, idx): 125 | if idx < 0: 126 | if -idx > len(self): 127 | raise ValueError("absolute value of index should not exceed dataset length") 128 | idx = len(self) + idx 129 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 130 | if dataset_idx == 0: 131 | sample_idx = idx 132 | else: 133 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 134 | return self.datasets[dataset_idx][sample_idx] 135 | 136 | def get_class(self, idx): 137 | if idx < 0: 138 | if -idx > len(self): 139 | raise ValueError("absolute value of index should not exceed dataset length") 140 | idx = len(self) + idx 141 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 142 | if dataset_idx == 0: 143 | sample_idx = idx 144 | else: 145 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 146 | true_class = self.datasets[dataset_idx].base_dataset.targets[self.datasets[dataset_idx].indices[sample_idx]] 147 | return self.datasets[dataset_idx].base_dataset.target_transform(true_class) 148 | 149 | @property 150 | def cummulative_sizes(self): 151 | warnings.warn("cummulative_sizes attribute is renamed to " 152 | "cumulative_sizes", DeprecationWarning, stacklevel=2) 153 | return self.cumulative_sizes 154 | 155 | 156 | 157 | def load_dataset(name, root, sample='default', **kwargs): 158 | # Dataset 159 | if name in ['imagenet','tinyimagenet', 'CUB200', 'STANFORD120', 'MIT67']: 160 | # TODO 161 | if name == 'tinyimagenet': 162 | transform_train = transforms.Compose([ 163 | transforms.RandomResizedCrop(32), 164 | transforms.RandomHorizontalFlip(), 165 | transforms.ToTensor(), 166 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 167 | ]) 168 | transform_test = 
transforms.Compose([ 169 | transforms.Resize(32), 170 | transforms.ToTensor(), 171 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 172 | ]) 173 | 174 | train_val_dataset_dir = os.path.join(root, "train") 175 | test_dataset_dir = os.path.join(root, "val") 176 | 177 | trainset = DatasetWrapper(datasets.ImageFolder(root=train_val_dataset_dir, transform=transform_train)) 178 | valset = DatasetWrapper(datasets.ImageFolder(root=test_dataset_dir, transform=transform_test)) 179 | 180 | elif name == 'imagenet': 181 | transform_train = transforms.Compose([ 182 | transforms.RandomResizedCrop(224), 183 | transforms.RandomHorizontalFlip(), 184 | transforms.ToTensor(), 185 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 186 | ]) 187 | transform_test = transforms.Compose([ 188 | transforms.Resize(256), 189 | transforms.CenterCrop(224), 190 | transforms.ToTensor(), 191 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 192 | ]) 193 | train_val_dataset_dir = os.path.join(root, "train") 194 | test_dataset_dir = os.path.join(root, "val") 195 | 196 | trainset = DatasetWrapper(datasets.ImageFolder(root=train_val_dataset_dir, transform=transform_train)) 197 | valset = DatasetWrapper(datasets.ImageFolder(root=test_dataset_dir, transform=transform_test)) 198 | 199 | else: 200 | transform_train = transforms.Compose([ 201 | transforms.RandomResizedCrop(224), 202 | transforms.RandomHorizontalFlip(), 203 | transforms.ToTensor(), 204 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 205 | ]) 206 | transform_test = transforms.Compose([ 207 | transforms.Resize(256), 208 | transforms.CenterCrop(224), 209 | transforms.ToTensor(), 210 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 211 | ]) 212 | 213 | train_val_dataset_dir = os.path.join(root, name, "train") 214 | test_dataset_dir = os.path.join(root, name, "test") 215 | 216 | trainset = 
DatasetWrapper(datasets.ImageFolder(root=train_val_dataset_dir, transform=transform_train)) 217 | valset = DatasetWrapper(datasets.ImageFolder(root=test_dataset_dir, transform=transform_test)) 218 | 219 | elif name.startswith('cifar'): 220 | transform_train = transforms.Compose([ 221 | transforms.RandomCrop(32, padding=4), 222 | transforms.RandomHorizontalFlip(), 223 | transforms.ToTensor(), 224 | transforms.Normalize((0.4914, 0.4822, 0.4465), 225 | (0.2023, 0.1994, 0.2010)), 226 | ]) 227 | transform_test = transforms.Compose([ 228 | transforms.ToTensor(), 229 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 230 | ]) 231 | 232 | if name == 'cifar10': 233 | CIFAR = datasets.CIFAR10 234 | else: 235 | CIFAR = datasets.CIFAR100 236 | 237 | trainset = DatasetWrapper(CIFAR(root, train=True, download=True, transform=transform_train)) 238 | valset = DatasetWrapper(CIFAR(root, train=False, download=True, transform=transform_test)) 239 | else: 240 | raise Exception('Unknown dataset: {}'.format(name)) 241 | 242 | # Sampler 243 | if sample == 'default': 244 | get_train_sampler = lambda d: BatchSampler(RandomSampler(d), kwargs['batch_size'], False) 245 | get_test_sampler = lambda d: BatchSampler(SequentialSampler(d), kwargs['batch_size'], False) 246 | 247 | elif sample == 'pair': 248 | get_train_sampler = lambda d: PairBatchSampler(d, kwargs['batch_size']) 249 | get_test_sampler = lambda d: BatchSampler(SequentialSampler(d), kwargs['batch_size'], False) 250 | 251 | else: 252 | raise Exception('Unknown sampling: {}'.format(sampling)) 253 | 254 | trainloader = DataLoader(trainset, batch_sampler=get_train_sampler(trainset), num_workers=4) 255 | valloader = DataLoader(valset, batch_sampler=get_test_sampler(valset), num_workers=4) 256 | 257 | return trainloader, valloader 258 | 259 | -------------------------------------------------------------------------------- /models/__init__.py: 
--------------------------------------------------------------------------------
from .resnet import *
from .densenet import *
from .densenet3 import *

def load_model(name, num_classes=10, pretrained=False, **kwargs):
    """Instantiate a model by constructor name (e.g. 'CIFAR_ResNet18').

    The name is looked up among the constructors star-imported above.
    Raises KeyError for an unknown name.
    """
    model_dict = globals()
    model = model_dict[name](pretrained=pretrained, num_classes=num_classes, **kwargs)
    return model
--------------------------------------------------------------------------------
/models/densenet.py:
--------------------------------------------------------------------------------
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict
# from .utils import load_state_dict_from_url
try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']

model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}


def _bn_function_factory(norm, relu, conv):
    # Returns a closure that computes the bottleneck output from the
    # concatenated input features; factored out so it can be wrapped in
    # torch.utils.checkpoint for memory-efficient (recomputed) backward.
    def bn_function(*inputs):
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = conv(relu(norm(concated_features)))
        return bottleneck_output

    return bn_function


class _DenseLayer(nn.Sequential):
    """One BN-ReLU-Conv1x1 / BN-ReLU-Conv3x3 dense layer producing
    `growth_rate` new feature maps from all preceding feature maps."""

    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
        self.add_module('relu1', nn.ReLU(inplace=True)),
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                                           growth_rate, kernel_size=1, stride=1,
                                           bias=False)),
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
        self.add_module('relu2', nn.ReLU(inplace=True)),
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1,
                                           bias=False)),
        self.drop_rate = drop_rate
        self.memory_efficient = memory_efficient

    def forward(self, *prev_features):
        bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
        if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
            # Trade compute for memory: recompute the bottleneck in backward.
            bottleneck_output = cp.checkpoint(bn_function, *prev_features)
        else:
            bottleneck_output = bn_function(*prev_features)
        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate,
                                     training=self.training)
        return new_features


class _DenseBlock(nn.Module):
    """A stack of dense layers; each layer sees all earlier layers' outputs."""

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            self.add_module('denselayer%d' % (i + 1), layer)

    def forward(self, init_features):
        features = [init_features]
        for name, layer in self.named_children():
            new_features = layer(*features)
            features.append(new_features)
        return torch.cat(features, 1)


class _Transition(nn.Sequential):
    """BN-ReLU-Conv1x1 + 2x2 average pool between dense blocks
    (channel compression and spatial downsampling)."""

    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
        bias (bool) - whether the final linear classifier has a bias term
    """

    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False, bias=True):

        super(DenseNet, self).__init__()

        # First convolution (ImageNet-style 7x7 stem with max pooling)
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
                                padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock, followed by a compressing transition except the last
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient
            )
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes, bias=bias)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                # Classifier bias only exists when `bias=True`.
                if bias:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out



def _load_state_dict(model, model_url, progress):
    # '.'s are no longer allowed in module names, but previous _DenseLayer
    # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
    # They are also in the checkpoints in model_urls. This pattern is used
    # to find such keys.
    pattern = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

    state_dict = load_state_dict_from_url(model_url, progress=progress)
    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
            # Rename e.g. 'denselayer1.norm.1.weight' -> 'denselayer1.norm1.weight'.
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict)


def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress, num_classes, bias,
              **kwargs):
    # Common constructor used by the densenetXXX factory functions below.
    model = DenseNet(growth_rate, block_config, num_init_features, num_classes=num_classes, bias=bias, **kwargs)
    if pretrained:
        _load_state_dict(model, model_urls[arch], progress)
    return model


def densenet121(pretrained=False, progress=True, num_classes=10, bias=True, **kwargs):
    r"""Densenet-121 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress, num_classes=num_classes, bias=bias,
                     **kwargs)


def densenet161(pretrained=False, progress=True, **kwargs):
    r"""Densenet-161 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress,
                     **kwargs)


def densenet169(pretrained=False, progress=True, **kwargs):
    r"""Densenet-169 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress,
                     **kwargs)


def densenet201(pretrained=False, progress=True, **kwargs):
    r"""Densenet-201 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress,
                     **kwargs)
--------------------------------------------------------------------------------
/models/densenet3.py:
--------------------------------------------------------------------------------
'''DenseNet in PyTorch.'''
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Bottleneck(nn.Module):
    """BN-ReLU-Conv(1x1) then BN-ReLU-Conv(3x3); output is the new features
    concatenated with the input (dense connectivity)."""

    def __init__(self, in_planes, growth_rate):
        super(Bottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(4*growth_rate)
        self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        out = torch.cat([out, x], 1)
        return out


class Transition(nn.Module):
    """Compression between dense blocks: 1x1 conv then 2x2 average pooling."""

    def __init__(self, in_planes, out_planes):
        super(Transition, self).__init__()
        self.bn = nn.BatchNorm2d(in_planes)
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.conv(F.relu(self.bn(x)))
        out = F.avg_pool2d(out, 2)
        return out


class CIFAR_DenseNet(nn.Module):
    """DenseNet-BC variant for 32x32 CIFAR inputs (3x3 stem, no max pool).

    Args:
        block: dense layer class (Bottleneck).
        nblocks: number of layers in each of the four dense blocks.
        growth_rate: filters added by each dense layer.
        reduction: channel compression factor at each transition.
        num_classes: output dimension of the final linear layer.
        bias: whether the final linear layer has a bias term.
    """

    def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10, bias=True):
        super(CIFAR_DenseNet, self).__init__()
        self.growth_rate = growth_rate

        num_planes = 2*growth_rate
        self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False)

        self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0])
        num_planes += nblocks[0]*growth_rate
        out_planes = int(math.floor(num_planes*reduction))
        self.trans1 = Transition(num_planes, out_planes)
        num_planes = out_planes

        self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1])
        num_planes += nblocks[1]*growth_rate
        out_planes = int(math.floor(num_planes*reduction))
        self.trans2 = Transition(num_planes, out_planes)
        num_planes = out_planes

        self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2])
        num_planes += nblocks[2]*growth_rate
        out_planes = int(math.floor(num_planes*reduction))
        self.trans3 = Transition(num_planes, out_planes)
        num_planes = out_planes

        self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3])
        num_planes += nblocks[3]*growth_rate

        self.bn = nn.BatchNorm2d(num_planes)
        self.linear = nn.Linear(num_planes, num_classes, bias=bias)

    def _make_dense_layers(self, block, in_planes, nblock):
        layers = []
        for i in range(nblock):
            layers.append(block(in_planes, self.growth_rate))
            in_planes += self.growth_rate
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        out = self.trans3(self.dense3(out))
        out = self.dense4(out)
        # 4x4 average pool reduces the final 4x4 CIFAR feature map to 1x1.
        out = F.avg_pool2d(F.relu(self.bn(out)), 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

def CIFAR_DenseNet121(pretrained=False, num_classes=10, bias=True, **kwargs):
    # `pretrained` is accepted for interface consistency but ignored here.
    return CIFAR_DenseNet(Bottleneck, [6,12,24,16], growth_rate=32, num_classes=num_classes, bias=bias)

# def DenseNet169():
#     return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32)

# def DenseNet201():
#     return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32)

# def DenseNet161():
#     return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48)

# def densenet_cifar():
#     return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12)

# def test():
#     net = densenet_cifar()
#     x = torch.randn(1,3,32,32)
#     y = net(x)
#     print(y)

# test()
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F

__all__ = ['ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'CIFAR_ResNet', 'CIFAR_ResNet18', 'CIFAR_ResNet34', 'CIFAR_ResNet10']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, groups=groups, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    """Standard post-activation residual block (two 3x3 convolutions)."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Standard bottleneck residual block (1x1 -> 3x3 -> 1x1 convolutions)."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock.'''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            # Projection shortcut when the spatial size or channel count changes.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        out = F.relu(self.bn1(x))
        # Shortcut branches off AFTER the first pre-activation (standard PreAct form).
        shortcut = self.shortcut(out)
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out += shortcut
        return out

class ResNet(nn.Module):
    """ImageNet-style ResNet (7x7 stem, four stages, adaptive pooling)."""

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        self.inplanes = 64
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None):
        # Builds one stage; only the first block of a stage downsamples.
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

class CIFAR_ResNet(nn.Module):
    """ResNet variant for 32x32 CIFAR inputs (3x3 stem, no max pooling)."""

    def __init__(self, block, num_blocks, num_classes=10, bias=True):
        super(CIFAR_ResNet, self).__init__()
        self.in_planes = 64
        self.conv1 = conv3x3(3, 64)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes, bias=bias)


    def _make_layer(self, block, planes, num_blocks, stride):
        # First block of the stage may downsample; the rest use stride 1.
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, lin=0, lout=5):
        # NOTE(review): `lin`/`lout` are accepted but unused in this body --
        # presumably kept for interface compatibility with layer-wise variants.
        out = x
        out = self.conv1(out)
        out = self.bn1(out)
        out = F.relu(out)
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out = self.layer4(out3)
        out = F.avg_pool2d(out, 4)
        out4 = out.view(out.size(0), -1)
        out = self.linear(out4)

        return out


def CIFAR_ResNet10(pretrained=False, **kwargs):
    return CIFAR_ResNet(PreActBlock, [1,1,1,1], **kwargs)

def CIFAR_ResNet18(pretrained=False, **kwargs):
    return CIFAR_ResNet(PreActBlock, [2,2,2,2], **kwargs)

def CIFAR_ResNet34(pretrained=False, **kwargs):
    return CIFAR_ResNet(PreActBlock, [3,4,6,3], **kwargs)

def resnet10(pretrained=False, **kwargs):
    """Constructs a ResNet-10 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    # NOTE(review): `pretrained` is ignored here -- no checkpoint URL exists
    # for ResNet-10 in model_urls.
    model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs)
    return model

def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model


def my_resnet34(pretrained=False, **kwargs):
    # NOTE(review): `my_ResNet` is not defined anywhere in this file, so
    # calling this function raises NameError. The 5-element layer list
    # suggests a custom ResNet variant from another experiment -- confirm
    # and either restore `my_ResNet` or remove this function.
    model = my_ResNet(BasicBlock, [3, 4, 6, 3, 3], **kwargs)
    return model

def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
    return model


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.
314 | Args: 315 | pretrained (bool): If True, returns a model pre-trained on ImageNet 316 | """ 317 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 318 | if pretrained: 319 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 320 | return model 321 | 322 | 323 | def resnext50_32x4d(pretrained=False, **kwargs): 324 | model = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4, **kwargs) 325 | # if pretrained: 326 | # model.load_state_dict(model_zoo.load_url(model_urls['resnext50_32x4d'])) 327 | return model 328 | 329 | 330 | def resnext101_32x8d(pretrained=False, **kwargs): 331 | model = ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8, **kwargs) 332 | # if pretrained: 333 | # model.load_state_dict(model_zoo.load_url(model_urls['resnext101_32x8d'])) 334 | return model 335 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import csv 5 | import os, logging 6 | 7 | import numpy as np 8 | import torch 9 | from torch.autograd import Variable, grad 10 | import torch.backends.cudnn as cudnn 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | import torchvision.transforms as transforms 15 | 16 | import models 17 | from utils import progress_bar, set_logging_defaults 18 | from datasets import load_dataset 19 | 20 | parser = argparse.ArgumentParser(description='CS-KD Training') 21 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 22 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') 23 | parser.add_argument('--model', default="CIFAR_ResNet18", type=str, help='model type (32x32: CIFAR_ResNet18, CIFAR_DenseNet121, 224x224: resnet18, densenet121)') 24 | parser.add_argument('--name', default='0', type=str, help='name of run') 
25 | parser.add_argument('--batch-size', default=128, type=int, help='batch size') 26 | parser.add_argument('--epoch', default=200, type=int, help='total epochs to run') 27 | parser.add_argument('--decay', default=1e-4, type=float, help='weight decay') 28 | parser.add_argument('--ngpu', default=1, type=int, help='number of gpu') 29 | parser.add_argument('--sgpu', default=0, type=int, help='gpu index (start)') 30 | parser.add_argument('--dataset', default='cifar100', type=str, help='the name for dataset cifar100 | tinyimagenet | CUB200 | STANFORD120 | MIT67') 31 | parser.add_argument('--dataroot', default='~/data/', type=str, help='data directory') 32 | parser.add_argument('--saveroot', default='./results', type=str, help='save directory') 33 | parser.add_argument('--cls', '-cls', action='store_true', help='adding cls loss') 34 | parser.add_argument('--temp', default=4.0, type=float, help='temperature scaling') 35 | parser.add_argument('--lamda', default=1.0, type=float, help='cls loss weight ratio') 36 | 37 | args = parser.parse_args() 38 | use_cuda = torch.cuda.is_available() 39 | 40 | best_val = 0 # best validation accuracy 41 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 42 | 43 | cudnn.benchmark = True 44 | 45 | # Data 46 | print('==> Preparing dataset: {}'.format(args.dataset)) 47 | if not args.cls: 48 | trainloader, valloader = load_dataset(args.dataset, args.dataroot, batch_size=args.batch_size) 49 | else: 50 | trainloader, valloader = load_dataset(args.dataset, args.dataroot, 'pair', batch_size=args.batch_size) 51 | 52 | 53 | num_class = trainloader.dataset.num_classes 54 | print('Number of train dataset: ' ,len(trainloader.dataset)) 55 | print('Number of validation dataset: ' ,len(valloader.dataset)) 56 | 57 | # Model 58 | print('==> Building model: {}'.format(args.model)) 59 | 60 | net = models.load_model(args.model, num_class) 61 | # print(net) 62 | 63 | if use_cuda: 64 | torch.cuda.set_device(args.sgpu) 65 | net.cuda() 66 | 
print(torch.cuda.device_count()) 67 | print('Using CUDA..') 68 | 69 | if args.ngpu > 1: 70 | net = torch.nn.DataParallel(net, device_ids=list(range(args.sgpu, args.sgpu + args.ngpu))) 71 | 72 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.decay) 73 | 74 | logdir = os.path.join(args.saveroot, args.dataset, args.model, args.name) 75 | set_logging_defaults(logdir, args) 76 | logger = logging.getLogger('main') 77 | logname = os.path.join(logdir, 'log.csv') 78 | 79 | 80 | # Resume 81 | if args.resume: 82 | # Load checkpoint. 83 | print('==> Resuming from checkpoint..') 84 | checkpoint = torch.load(os.path.join(logdir, 'ckpt.t7')) 85 | net.load_state_dict(checkpoint['net']) 86 | optimizer.load_state_dict(checkpoint['optimizer']) 87 | best_acc = checkpoint['acc'] 88 | start_epoch = checkpoint['epoch'] + 1 89 | rng_state = checkpoint['rng_state'] 90 | torch.set_rng_state(rng_state) 91 | 92 | criterion = nn.CrossEntropyLoss() 93 | 94 | 95 | class KDLoss(nn.Module): 96 | def __init__(self, temp_factor): 97 | super(KDLoss, self).__init__() 98 | self.temp_factor = temp_factor 99 | self.kl_div = nn.KLDivLoss(reduction="sum") 100 | 101 | def forward(self, input, target): 102 | log_p = torch.log_softmax(input/self.temp_factor, dim=1) 103 | q = torch.softmax(target/self.temp_factor, dim=1) 104 | loss = self.kl_div(log_p, q)*(self.temp_factor**2)/input.size(0) 105 | return loss 106 | 107 | kdloss = KDLoss(args.temp) 108 | 109 | def train(epoch): 110 | print('\nEpoch: %d' % epoch) 111 | net.train() 112 | train_loss = 0 113 | correct = 0 114 | total = 0 115 | train_cls_loss = 0 116 | for batch_idx, (inputs, targets) in enumerate(trainloader): 117 | if use_cuda: 118 | inputs, targets = inputs.cuda(), targets.cuda() 119 | 120 | batch_size = inputs.size(0) 121 | 122 | if not args.cls: 123 | outputs = net(inputs) 124 | loss = torch.mean(criterion(outputs, targets)) 125 | train_loss += loss.item() 126 | 127 | _, predicted = torch.max(outputs, 1) 128 
| total += targets.size(0) 129 | correct += predicted.eq(targets.data).sum().float().cpu() 130 | else: 131 | targets_ = targets[:batch_size//2] 132 | outputs = net(inputs[:batch_size//2]) 133 | loss = torch.mean(criterion(outputs, targets_)) 134 | train_loss += loss.item() 135 | 136 | with torch.no_grad(): 137 | outputs_cls = net(inputs[batch_size//2:]) 138 | cls_loss = kdloss(outputs, outputs_cls.detach()) 139 | loss += args.lamda * cls_loss 140 | train_cls_loss += cls_loss.item() 141 | 142 | _, predicted = torch.max(outputs, 1) 143 | total += targets_.size(0) 144 | correct += predicted.eq(targets_.data).sum().float().cpu() 145 | 146 | optimizer.zero_grad() 147 | loss.backward() 148 | optimizer.step() 149 | progress_bar(batch_idx, len(trainloader), 150 | 'Loss: %.3f | Acc: %.3f%% (%d/%d) | Cls: %.3f ' 151 | % (train_loss/(batch_idx+1), 100.*correct/total, correct, total, train_cls_loss/(batch_idx+1))) 152 | 153 | logger = logging.getLogger('train') 154 | logger.info('[Epoch {}] [Loss {:.3f}] [cls {:.3f}] [Acc {:.3f}]'.format( 155 | epoch, 156 | train_loss/(batch_idx+1), 157 | train_cls_loss/(batch_idx+1), 158 | 100.*correct/total)) 159 | 160 | return train_loss/batch_idx, 100.*correct/total, train_cls_loss/batch_idx 161 | 162 | def val(epoch): 163 | global best_val 164 | net.eval() 165 | val_loss = 0.0 166 | correct = 0.0 167 | total = 0.0 168 | 169 | # Define a data loader for evaluating 170 | loader = valloader 171 | 172 | with torch.no_grad(): 173 | for batch_idx, (inputs, targets) in enumerate(loader): 174 | if use_cuda: 175 | inputs, targets = inputs.cuda(), targets.cuda() 176 | 177 | outputs = net(inputs) 178 | loss = torch.mean(criterion(outputs, targets)) 179 | 180 | val_loss += loss.item() 181 | _, predicted = torch.max(outputs, 1) 182 | total += targets.size(0) 183 | correct += predicted.eq(targets.data).cpu().sum().float() 184 | 185 | progress_bar(batch_idx, len(loader), 186 | 'Loss: %.3f | Acc: %.3f%% (%d/%d) ' 187 | % (val_loss/(batch_idx+1), 
100.*correct/total, correct, total)) 188 | 189 | acc = 100.*correct/total 190 | logger = logging.getLogger('val') 191 | logger.info('[Epoch {}] [Loss {:.3f}] [Acc {:.3f}]'.format( 192 | epoch, 193 | val_loss/(batch_idx+1), 194 | acc)) 195 | 196 | if acc > best_val: 197 | best_val = acc 198 | checkpoint(acc, epoch) 199 | 200 | return (val_loss/(batch_idx+1), acc) 201 | 202 | 203 | def checkpoint(acc, epoch): 204 | # Save checkpoint. 205 | print('Saving..') 206 | state = { 207 | 'net': net.state_dict(), 208 | 'optimizer': optimizer.state_dict(), 209 | 'acc': acc, 210 | 'epoch': epoch, 211 | 'rng_state': torch.get_rng_state() 212 | } 213 | torch.save(state, os.path.join(logdir, 'ckpt.t7')) 214 | 215 | 216 | def adjust_learning_rate(optimizer, epoch): 217 | """decrease the learning rate at 100 and 150 epoch""" 218 | lr = args.lr 219 | if epoch >= 0.5 * args.epoch: 220 | lr /= 10 221 | if epoch >= 0.75 * args.epoch: 222 | lr /= 10 223 | for param_group in optimizer.param_groups: 224 | param_group['lr'] = lr 225 | 226 | 227 | # Logs 228 | for epoch in range(start_epoch, args.epoch): 229 | train_loss, train_acc, train_cls_loss = train(epoch) 230 | val_loss, val_acc = val(epoch) 231 | adjust_learning_rate(optimizer, epoch) 232 | 233 | print("Best Accuracy : {}".format(best_val)) 234 | logger = logging.getLogger('best') 235 | logger.info('[Acc {:.3f}]'.format(best_val)) 236 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os, logging 2 | import sys 3 | import time 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.init as init 9 | 10 | 11 | def set_logging_defaults(logdir, args): 12 | if os.path.isdir(logdir): 13 | res = input('"{}" exists. Overwrite [Y/n]? 
'.format(logdir)) 14 | if res != 'Y': 15 | raise Exception('"{}" exists.'.format(logdir)) 16 | else: 17 | os.makedirs(logdir) 18 | 19 | # set basic configuration for logging 20 | logging.basicConfig(format="[%(asctime)s] [%(name)s] %(message)s", 21 | level=logging.INFO, 22 | handlers=[logging.FileHandler(os.path.join(logdir, 'log.txt')), 23 | logging.StreamHandler(os.sys.stdout)]) 24 | 25 | # log cmdline argumetns 26 | logger = logging.getLogger('main') 27 | logger.info(' '.join(os.sys.argv)) 28 | logger.info(args) 29 | 30 | _, term_width = os.popen('stty size', 'r').read().split() 31 | term_width = int(term_width) 32 | 33 | TOTAL_BAR_LENGTH = 86. 34 | last_time = time.time() 35 | begin_time = last_time 36 | 37 | def progress_bar(current, total, msg=None): 38 | global last_time, begin_time 39 | if current == 0: 40 | begin_time = time.time() # Reset for new bar. 41 | 42 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 43 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 44 | 45 | sys.stdout.write(' [') 46 | for i in range(cur_len): 47 | sys.stdout.write('=') 48 | sys.stdout.write('>') 49 | for i in range(rest_len): 50 | sys.stdout.write('.') 51 | sys.stdout.write(']') 52 | 53 | cur_time = time.time() 54 | step_time = cur_time - last_time 55 | last_time = cur_time 56 | tot_time = cur_time - begin_time 57 | 58 | L = [] 59 | L.append(' Step: %s' % format_time(step_time)) 60 | L.append(' | Tot: %s' % format_time(tot_time)) 61 | if msg: 62 | L.append(' | ' + msg) 63 | 64 | msg = ''.join(L) 65 | sys.stdout.write(msg) 66 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 67 | sys.stdout.write(' ') 68 | 69 | # Go back to the center of the bar. 
70 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)): 71 | sys.stdout.write('\b') 72 | sys.stdout.write(' %d/%d ' % (current+1, total)) 73 | 74 | if current < total-1: 75 | sys.stdout.write('\r') 76 | else: 77 | sys.stdout.write('\n') 78 | sys.stdout.flush() 79 | 80 | def format_time(seconds): 81 | days = int(seconds / 3600/24) 82 | seconds = seconds - days*3600*24 83 | hours = int(seconds / 3600) 84 | seconds = seconds - hours*3600 85 | minutes = int(seconds / 60) 86 | seconds = seconds - minutes*60 87 | secondsf = int(seconds) 88 | seconds = seconds - secondsf 89 | millis = int(seconds*1000) 90 | 91 | f = '' 92 | i = 1 93 | if days > 0: 94 | f += str(days) + 'D' 95 | i += 1 96 | if hours > 0 and i <= 2: 97 | f += str(hours) + 'h' 98 | i += 1 99 | if minutes > 0 and i <= 2: 100 | f += str(minutes) + 'm' 101 | i += 1 102 | if secondsf > 0 and i <= 2: 103 | f += str(secondsf) + 's' 104 | i += 1 105 | if millis > 0 and i <= 2: 106 | f += str(millis) + 'ms' 107 | i += 1 108 | if f == '': 109 | f = '0ms' 110 | return f 111 | --------------------------------------------------------------------------------