├── __init__.py
├── requirements.txt
├── review-mechanism.jpg
├── params.py
├── test.py
├── teachers.py
├── students.py
├── data.py
├── utils
│   ├── misc.py
│   └── resnets_for_cifar.py
├── README.md
├── train.py
├── experimental
│   ├── hcl_experiments.py
│   ├── abf_experiments.py
│   └── table7_experiments.py
└── framework.py

/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy==1.22.0
torch==1.10.2
torchvision==0.11.3
--------------------------------------------------------------------------------
/review-mechanism.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DevPranjal/reproduction-review-kd/HEAD/review-mechanism.jpg
--------------------------------------------------------------------------------
/params.py:
--------------------------------------------------------------------------------
params = {
    "dataset": "cifar10",
    "student": "resnet20",
    "teacher": "resnet56",
    "teacher_weight_path": "./pretrained/resnet56.pt",

    "batch_size": 64,
    "num_epochs": 20,
    "lr": 0.1,
    "lr_decay_steps": [12, 17],
    "lr_decay_rate": 0.1,
    "weight_decay": 5e-4,
    "args": 0,

    "kd_loss_weight": 0.6,

    "seed": 0,
}
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch


def test(net, test_loader):
    net.eval()
    correct_preds = 0.
    total_images = 0.
    for images, labels in test_loader:
        images, labels = images.cuda(), labels.cuda()

        with torch.no_grad():
            features, preds = net(images)

        preds = torch.max(preds.data, 1)[1]
        total_images += labels.size(0)
        correct_preds += (preds == labels).sum().item()

    test_acc = correct_preds / total_images
    net.train()
    return test_acc
--------------------------------------------------------------------------------
/teachers.py:
--------------------------------------------------------------------------------
from utils.resnets_for_cifar import ResNet


def resnet44(**kwargs):
    return ResNet(44, [16, 16, 32, 64], 'basicblock', **kwargs)


def resnet56(**kwargs):
    return ResNet(56, [16, 16, 32, 64], 'basicblock', **kwargs)


def resnet110(**kwargs):
    return ResNet(110, [16, 16, 32, 64], 'basicblock', **kwargs)


def get_teacher(name, **kwargs):
    if name == 'resnet44':
        return resnet44(**kwargs)
    elif name == 'resnet56':
        return resnet56(**kwargs)
    elif name == 'resnet110':
        return resnet110(**kwargs)
    else:
        # fail loudly instead of silently returning None
        raise ValueError(f'unknown teacher: {name}')
--------------------------------------------------------------------------------
/students.py:
--------------------------------------------------------------------------------
from utils.resnets_for_cifar import ResNet


def resnet8(**kwargs):
    return ResNet(8, [16, 16, 32, 64], 'basicblock', **kwargs)


def resnet14(**kwargs):
    return ResNet(14, [16, 16, 32, 64], 'basicblock', **kwargs)


def resnet20(**kwargs):
    return ResNet(20, [16, 16, 32, 64], 'basicblock', **kwargs)


def resnet32(**kwargs):
    return ResNet(32, [16, 16, 32, 64], 'basicblock', **kwargs)


def get_student(name, **kwargs):
    if name == 'resnet8':
        return resnet8(**kwargs)
    elif name == 'resnet14':
        return resnet14(**kwargs)
    elif name == 'resnet20':
        return resnet20(**kwargs)
    elif name == 'resnet32':
        return resnet32(**kwargs)
    else:
        # fail loudly instead of silently returning None
        raise ValueError(f'unknown student: {name}')
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import torch
from torchvision import datasets, transforms


normalize = transforms.Normalize(
    mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
    std=[x / 255.0 for x in [63.0, 62.1, 66.7]]
)
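# these are per-channel CIFAR statistics on the 0-255 scale; divided by 255
# they come to roughly mean = (0.491, 0.482, 0.447), std = (0.247, 0.244, 0.262)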

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize
])


def get_dataloaders(dataset, batch_size):
    if dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(
            root='data/', train=True, transform=train_transform, download=True)
        test_dataset = datasets.CIFAR10(
            root='data/', train=False, transform=test_transform, download=True)
    elif dataset == 'cifar100':
        train_dataset = datasets.CIFAR100(
            root='data/', train=True, transform=train_transform, download=True)
        test_dataset = datasets.CIFAR100(
            root='data/', train=False, transform=test_transform, download=True)
    else:
        # fail loudly instead of raising a confusing NameError below
        raise ValueError(f'unknown dataset: {dataset}')

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size,
                                               shuffle=True, pin_memory=True, num_workers=2)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size,
                                              shuffle=False, pin_memory=True, num_workers=1)

    return train_loader, test_loader
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
class AverageMeter(object):
    """Computes and stores the average and current value
    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def format_time(seconds):
    days = int(seconds / 3600 / 24)
    seconds = seconds - days * 3600 * 24
    hours = int(seconds / 3600)
    seconds = seconds - hours * 3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes * 60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds * 1000)

    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f
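
# example: format_time(3725.5) returns '1h2m'; at most two units are emitted,
# starting from the largest nonzero one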

class Logger():
    def __init__(self, params, filename='log.txt'):
        self.filename = filename
        self.file = open(filename, 'w')
        # Write model configuration at top of file
        for key, value in params.items():
            self.file.write(f'{key}: {value}\n')
        self.file.flush()

    def writerow(self, row):
        for k in row:
            self.file.write(k + ': ' + row[k] + ' ')
        self.file.write('\n')
        self.file.flush()

    def close(self):
        self.file.close()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ML Reproducibility Challenge 2021: Fall Edition

### Submission for [Distilling Knowledge via Knowledge Review](https://arxiv.org/abs/2104.09044) published in CVPR 2021

![review-mechanism](review-mechanism.jpg)

---

This effort aims to reproduce the results of experiments and analyse the robustness of the review framework for knowledge distillation introduced in the original paper. We verify the reported improvement in test accuracy consistently across student models, and we study the effectiveness of the novel modules introduced by the authors by conducting ablation studies and new experiments.

---

### Setting up environment

```bash
conda create -n reviewkd python=3.8
conda activate reviewkd

git clone https://github.com/DevPranjal/ml-repro-2021.git
cd ml-repro-2021

pip install -r requirements.txt
```

### Training baseline teachers

To train the teacher model, we use the code written by the authors as follows:

```bash
git clone https://github.com/dvlab-research/ReviewKD
cd ReviewKD/CIFAR100
python train.py --model resnet56
```

### Training student via review mechanism

To train the student model, we have collected all tunable settings in `params.py`. After setting the desired value for each key, run the following within the `ml-repro-2021` directory:

```bash
python train.py
```
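
The same run can also be launched programmatically; this is a minimal sketch of what `train.py` does in its `__main__` block:

```python
from params import params
from framework import hcl, ABF, RLF_for_Resnet
from train import train

train(params, hcl, ABF, RLF_for_Resnet)
```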

### Performing ablation studies and experiments

The ablation studies and experiments have been organized and implemented in `experimental/`. To execute any of them, run the corresponding script, for example:

```bash
cd experimental
python table7_experiments.py
```

### Reproduction Results

#### Classification results when student and teacher have architectures of the same style

| Student               | ResNet20 | ResNet32  | ResNet8x4  | WRN16-2 | WRN40-1 |
| --------------------- | -------- | --------- | ---------- | ------- | ------- |
| Teacher               | ResNet56 | ResNet110 | ResNet32x4 | WRN40-2 | WRN40-2 |
| Review KD (Paper)     | 71.89    | 73.89     | 75.63      | 76.12   | 75.09   |
| Review KD (Ours)      | 71.79    | 73.61     | 76.02      | 76.27   | 75.21   |
| Review KD Loss Weight | 0.7      | 1.0       | 5.0        | 5.0     | 5.0     |

#### Classification results when student and teacher have architectures of different styles

| Student               | ShuffleNetV1 | ShuffleNetV1 | ShuffleNetV2 |
| --------------------- | ------------ | ------------ | ------------ |
| Teacher               | ResNet32x4   | WRN40-2      | ResNet32x4   |
| Review KD (Paper)     | 77.45        | 77.14        | 77.78        |
| Review KD (Ours)      | 76.94        | 77.44        | 77.86        |
| Review KD Loss Weight | 5.0          | 5.0          | 8.0          |

#### Adding architectural components one by one

RM - Review Mechanism

RLF - Residual Learning Framework

ABF - Attention Based Fusion

HCL - Hierarchical Context Loss

| RM  | RLF | ABF | HCL | Test Accuracy |
| --- | --- | --- | --- | ------------- |
|     |     |     |     | 69.50         |
| Y   |     |     |     | 69.53         |
| Y   | Y   |     |     | 69.92         |
| Y   | Y   | Y   |     | 71.28         |
| Y   | Y   |     | Y   | 71.51         |
| Y   | Y   | Y   | Y   | 71.79         |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import copy
import time

import torch
from torch import nn
import torch.backends.cudnn as cudnn
import numpy as np

from params import params
from data import get_dataloaders
from teachers import get_teacher
from students import get_student
from framework import RLF_for_Resnet, ABF, hcl
from utils.misc import AverageMeter, format_time, Logger
from test import test


def train(params, hcl, abf, RLF, log_file_suffix=''):
    cudnn.deterministic = True
    cudnn.benchmark = False
    if params["seed"] == 0:
        params["seed"] = np.random.randint(1000)
    torch.manual_seed(params["seed"])
    np.random.seed(params["seed"])
    torch.cuda.manual_seed(params["seed"])

    train_loader, test_loader = get_dataloaders(params["dataset"], params["batch_size"])
    if params["dataset"] == 'cifar10':
        num_classes = 10
    elif params["dataset"] == 'cifar100':
        num_classes = 100

    teacher = get_teacher(params["teacher"], num_classes=num_classes)
    student = get_student(params["student"], num_classes=num_classes)

    # build the framework for student to be trained
    rlf = RLF(student, abf_to_use=abf)  # rlf => residual learning framework

    # load teacher weights from pretrained model
    weight = torch.load(params["teacher_weight_path"])
    teacher.load_state_dict(weight)
    for p in teacher.parameters():
        p.requires_grad = False
    # keep the teacher in eval mode so its batch norm statistics stay frozen
    teacher.eval()
    teacher.to(torch.device('cuda:0'))

    base_loss = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        rlf.parameters(),
        lr=params["lr"],
        momentum=0.9,
        nesterov=True,
        weight_decay=params["weight_decay"]
    )

    log_name = f"{params['dataset']}_{params['student']}_{params['teacher']}_{log_file_suffix}"
    logger = Logger(params=params, filename='logs/' + log_name + '.txt')
    best_accuracy = 0.0
    best_state = copy.deepcopy(rlf.state_dict())

    start_time = time.time()
    print("starting training with the following params:")
    print(params)
    print()

    for epoch in range(params["num_epochs"]):
        loss_avg = {
            'kd_loss': AverageMeter(),
            'base_loss': AverageMeter()
        }
        correct_preds = 0.0
        total_images = 0.0

        for i, (X, y) in enumerate(train_loader):
            X, y = X.cuda(), y.cuda()

            losses = {"kd_loss": 0, "base_loss": 0}

            # getting student and teacher features
            student_features, student_preds = rlf(X)
            teacher_features, teacher_preds = teacher(X, is_feat=True, preact=True)

            teacher_features = teacher_features[1:]

            # calculating review kd loss
            for sf, tf in zip(student_features, teacher_features):
                losses['kd_loss'] += hcl(sf, tf)
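
            # the total objective combines both terms:
            # loss = CE(student_preds, y) + kd_loss_weight * sum_i hcl(S_i, T_i)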

            # calculating cross entropy loss
            losses['base_loss'] = base_loss(student_preds, y)

            loss = losses['kd_loss'] * params['kd_loss_weight']
            loss += losses['base_loss']

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            for key in losses:
                loss_avg[key].update(losses[key])

            # calculate running average of accuracy
            student_preds = torch.max(student_preds.data, 1)[1]
            total_images += y.size(0)
            correct_preds += (student_preds == y.data).sum().item()
            train_accuracy = correct_preds / total_images

        # calculating test accuracy and storing best results
        test_accuracy = test(rlf, test_loader)
        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            # snapshot the weights; keeping a reference to the module alone
            # would silently save the final epoch instead of the best one
            best_state = copy.deepcopy(rlf.state_dict())

        # decaying lr at scheduled steps
        if epoch in params['lr_decay_steps']:
            params['lr'] *= params["lr_decay_rate"]
            for param_group in optimizer.param_groups:
                param_group['lr'] = params['lr']

        # logging results
        loss_avg = {k: loss_avg[k].val for k in loss_avg}
        log_row = {
            'epoch': str(epoch),
            'train_acc': '%.2f' % (train_accuracy*100),
            'test_acc': '%.2f' % (test_accuracy*100),
            'best_acc': '%.2f' % (best_accuracy*100),
            'lr': '%.5f' % (params['lr']),
            'loss': '%.5f' % (sum(loss_avg.values())),
            'kd_loss': '%.5f' % loss_avg['kd_loss'],
            'base_loss': '%.5f' % loss_avg['base_loss'],
            'time': format_time(time.time()-start_time),
            'eta': format_time((time.time()-start_time)/(epoch+1)*(params["num_epochs"]-epoch-1)),
        }
        print(log_row)
        logger.writerow(log_row)

    torch.save(best_state, 'pretrained/' + log_name + '.pt')
    logger.close()


if __name__ == '__main__':
    train(params, hcl, ABF, RLF_for_Resnet)
--------------------------------------------------------------------------------
/experimental/hcl_experiments.py:
--------------------------------------------------------------------------------
import sys
sys.path.insert(0, '..')


from train import train
from params import params
from framework import ABF, hcl, RLF_for_Resnet
import torch.nn.functional as F


def make_hcl(levels_for, level_weight):
    """Builds an hcl variant for the given pooling levels and level weights.

    `levels_for` maps the feature height h to the list of pooling levels,
    since some variants (e.g. [h, h//2, h//4]) depend on the feature size.
    """
    def hcl_variant(student_features, teacher_features):
        loss = 0.0
        n, c, h, w = student_features.shape

        levels = levels_for(h)
        total_weight = sum(level_weight)

        for lvl, lvl_weight in zip(levels, level_weight):
            if lvl > h:
                continue

            lvl_sf = F.adaptive_avg_pool2d(student_features, (lvl, lvl))
            lvl_tf = F.adaptive_avg_pool2d(teacher_features, (lvl, lvl))

            loss += F.mse_loss(lvl_sf, lvl_tf) * lvl_weight

        return loss / total_weight

    return hcl_variant


# varying the levels of pooling, keeping the corresponding weights the same
hcl_level_1 = make_hcl(lambda h: [h, 2, 1], [1.0, 0.5, 0.25])
hcl_level_2 = make_hcl(lambda h: [4, 1], [1.0, 0.5])
hcl_level_3 = make_hcl(lambda h: [h, h // 2, h // 4], [1.0, 0.5, 0.25])
hcl_level_4 = make_hcl(lambda h: [h, h - 1, h - 2, h - 3], [1.0, 0.5, 0.25, 0.125])

# varying the weights assigned to each level, keeping the levels the same
hcl_weight_1 = make_hcl(lambda h: [h, 4, 2, 1], [1.0, 1.0, 1.0, 1.0])
hcl_weight_2 = make_hcl(lambda h: [h, 4, 2, 1], [0.125, 0.25, 0.5, 1.0])

# plain L2 on the full-resolution features, with no extra pooling levels
hcl_no_levels_l2 = make_hcl(lambda h: [h], [1])


if __name__ == '__main__':
    # varying the levels of pooling
    train(params, hcl_level_1, ABF, RLF_for_Resnet, log_file_suffix='hcl_level_1')
    params["lr"] = 0.1; train(params, hcl_level_2, ABF, RLF_for_Resnet, log_file_suffix='hcl_level_2')
    params["lr"] = 0.1; train(params, hcl_level_3, ABF, RLF_for_Resnet, log_file_suffix='hcl_level_3')
    params["lr"] = 0.1; train(params, hcl_level_4, ABF, RLF_for_Resnet, log_file_suffix='hcl_level_4')

    # varying the weights assigned to each level
    params["lr"] = 0.1; train(params, hcl_weight_1, ABF, RLF_for_Resnet, log_file_suffix='hcl_weight_1')
    params["lr"] = 0.1; train(params, hcl_weight_2, ABF, RLF_for_Resnet, log_file_suffix='hcl_weight_2')
--------------------------------------------------------------------------------
/framework.py:
--------------------------------------------------------------------------------
# framework.py consists of implementations of the proposals made by the authors
# We largely refer to the authors' original implementation,
# mainly refactoring their code

# Distilling Knowledge via Knowledge Review
# |----> Uses Residual Learning Framework
# |----> Uses Hierarchical Context Loss
# |----> Uses Attention Based Fusion Module

# Let us start from the bottom going upwards

# We implement the framework for general ResNet architectures only


import torch
import torch.nn.functional as F
from torch import nn


########## Hierarchical Context Loss ##########

def hcl(student_features, teacher_features):
    loss = 0.0
    n, c, h, w = student_features.shape

    # the levels of hcl loss here are predefined, according to
    # the authors' implementation
    # ablation studies have been performed and these levels and their
    # weights have been changed (refer to experimental/hcl_experiments.py)
    levels = [h, 4, 2, 1]
    level_weight = [1.0, 0.5, 0.25, 0.125]
    total_weight = sum(level_weight)
    # keeping the corresponding weights the same, we change the levels to:
    # [h, 2, 1], [4, 1], [h, h//2, h//4], [h, h-1, h-2, h-3]
    # keeping the levels the same, we change the weights to:
    # [1.0, 1.0, 1.0, 1.0], [0.125, 0.25, 0.5, 1.0]

    for lvl, lvl_weight in zip(levels, level_weight):
        if lvl > h:
            continue

        lvl_sf = F.adaptive_avg_pool2d(student_features, (lvl, lvl))
        lvl_tf = F.adaptive_avg_pool2d(teacher_features, (lvl, lvl))

        lvl_loss = F.mse_loss(lvl_sf, lvl_tf) * lvl_weight
        loss += lvl_loss

    return loss / total_weight
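
# worked example: on 8x8 feature maps, student and teacher are compared at
# resolutions 8, 4, 2 and 1 with weights 1.0, 0.5, 0.25 and 0.125, and the
# weighted MSE sum is divided by the total weight 1.875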

########## Attention Based Fusion Module ##########

# The paper states that the output from the ABF module (a single output as
# presented in the ABF flow diagram, fig. 3(a)) is one of the inputs to
# the next ABF module.

# But the authors' code implementation provides two different outputs, one that
# proceeds to the next ABF module (`residual_output`) and one that
# is the output of the ABF module and which is involved in the loss
# function (`abf_output`)
# The `residual_output` differs from the `abf_output` in terms of the number
# of channels. The `residual_output` has `mid_channels` while the `abf_output`
# has `out_channels`

# In this implementation, we have taken the latter, two-output approach

# The single-output approach can be found in experimental/abf_experiments.py

class ABF(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(ABF, self).__init__()

        self.mid_channel = 64

        self.conv_to_mid_channel = nn.Sequential(
            nn.Conv2d(in_channel, self.mid_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(self.mid_channel),
        ).to(torch.device('cuda:0'))
        nn.init.kaiming_uniform_(self.conv_to_mid_channel[0].weight, a=1)

        self.conv_to_out_channel = nn.Sequential(
            nn.Conv2d(self.mid_channel, out_channel, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
        ).to(torch.device('cuda:0'))
        nn.init.kaiming_uniform_(self.conv_to_out_channel[0].weight, a=1)

        self.conv_to_att_maps = nn.Sequential(
            nn.Conv2d(self.mid_channel * 2, 2, kernel_size=1),
            nn.Sigmoid(),
        ).to(torch.device('cuda:0'))
        nn.init.kaiming_uniform_(self.conv_to_att_maps[0].weight, a=1)

    def forward(self, student_feature, prev_abf_output, teacher_shape):
        n, c, h, w = student_feature.shape
        student_feature = self.conv_to_mid_channel(student_feature)

        if prev_abf_output is None:
            residual_output = student_feature
        else:
            prev_abf_output = F.interpolate(prev_abf_output, size=(
                teacher_shape, teacher_shape), mode='nearest')

            concat_features = torch.cat(
                [student_feature, prev_abf_output], dim=1)
            attention_maps = self.conv_to_att_maps(concat_features)
            attention_map1 = attention_maps[:, 0].view(n, 1, h, w)
            attention_map2 = attention_maps[:, 1].view(n, 1, h, w)

            residual_output = student_feature * attention_map1 \
                + prev_abf_output * attention_map2

        # the output of the abf is obtained after the residual
        # output is convolved to have `out_channels` channels
        abf_output = self.conv_to_out_channel(residual_output)

        return abf_output, residual_output


########## Residual Learning Framework ##########

class RLF_for_Resnet(nn.Module):
    def __init__(self, student, abf_to_use):
        super(RLF_for_Resnet, self).__init__()

        self.student = student

        in_channels = [16, 32, 64, 64]
        out_channels = [16, 32, 64, 64]

        self.shapes = [1, 8, 16, 32, 32]

        ABFs = nn.ModuleList()

        for idx, in_channel in enumerate(in_channels):
            ABFs.append(abf_to_use(in_channel, out_channels[idx]))

        self.ABFs = ABFs[::-1]
        self.to('cuda')

    def forward(self, x):
        student_features = self.student(x, is_feat=True)

        student_preds = student_features[1]
        student_features = student_features[0][::-1]

        results = []

        abf_output, residual_output = self.ABFs[0](
            student_features[0], None, self.shapes[0])

        results.append(abf_output)

        for features, abf, shape in zip(student_features[1:], self.ABFs[1:], self.shapes[1:]):
            # here we use a recursive technique to obtain all the ABF
            # outputs and store them in a list
            abf_output, residual_output = abf(features, residual_output, shape)
            results.insert(0, abf_output)

        return results, student_preds
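
# note on the fusion order (this mirrors the code above): the student features
# are consumed deepest-first, so the first ABF sees the pooled 1x1 feature, and
# each residual_output is upsampled to the next spatial size in self.shapes
# (8, then 16, then 32) before being fused with the shallower feature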
--------------------------------------------------------------------------------
/experimental/abf_experiments.py:
--------------------------------------------------------------------------------
import sys
sys.path.insert(0, '..')


from train import train
from params import params
from framework import hcl, RLF_for_Resnet
import torch
import torch.nn.functional as F
from torch import nn


class ABF_without_mid_channels(nn.Module):
    def __init__(self, in_channel, out_channel, pabf_channel):
        super(ABF_without_mid_channels, self).__init__()

        self.conv_to_out_channel_sf = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
        ).to(torch.device('cuda:0'))
        nn.init.kaiming_uniform_(self.conv_to_out_channel_sf[0].weight, a=1)

        self.conv_to_out_channel_pabf = nn.Sequential(
            nn.Conv2d(pabf_channel, out_channel, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
        ).to(torch.device('cuda:0'))
        nn.init.kaiming_uniform_(self.conv_to_out_channel_pabf[0].weight, a=1)

        self.conv_to_att_maps = nn.Sequential(
            nn.Conv2d(out_channel * 2, 2, kernel_size=1),
            nn.Sigmoid(),
        ).to(torch.device('cuda:0'))
        nn.init.kaiming_uniform_(self.conv_to_att_maps[0].weight, a=1)

    def forward(self, student_feature, prev_abf_output, teacher_shape):
        n, c, h, w = student_feature.shape
        student_feature = self.conv_to_out_channel_sf(student_feature)

        if prev_abf_output is None:
            residual_output = student_feature
        else:
            prev_abf_output = self.conv_to_out_channel_pabf(prev_abf_output)
            prev_abf_output = F.interpolate(prev_abf_output, size=(
                teacher_shape, teacher_shape), mode='nearest')

            concat_features = torch.cat(
                [student_feature, prev_abf_output], dim=1)
            attention_maps = self.conv_to_att_maps(concat_features)
            attention_map1 = attention_maps[:, 0].view(n, 1, h, w)
            attention_map2 = attention_maps[:, 1].view(n, 1, h, w)

            residual_output = student_feature * attention_map1 \
                + prev_abf_output * attention_map2

        # here we just equate both the outputs instead of having
        # a single output, to keep the same training code for both
        # implementations
        abf_output = residual_output

        return abf_output, residual_output


class ABF_without_attention_maps(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(ABF_without_attention_maps, self).__init__()

        self.mid_channel = 64

        self.conv_to_mid_channel = nn.Sequential(
            nn.Conv2d(in_channel, self.mid_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(self.mid_channel),
        ).to(torch.device('cuda:0'))
        nn.init.kaiming_uniform_(self.conv_to_mid_channel[0].weight, a=1)

        self.conv_to_out_channel = nn.Sequential(
            nn.Conv2d(self.mid_channel, out_channel, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
        ).to(torch.device('cuda:0'))
        nn.init.kaiming_uniform_(self.conv_to_out_channel[0].weight, a=1)

    def forward(self, student_feature, prev_abf_output, teacher_shape):
        n, c, h, w = student_feature.shape
        student_feature = self.conv_to_mid_channel(student_feature)

        if prev_abf_output is None:
            residual_output = student_feature
        else:
            prev_abf_output = F.interpolate(prev_abf_output, size=(
                teacher_shape, teacher_shape), mode='nearest')

            # no attention branch here: the two inputs are summed directly
            residual_output = student_feature + prev_abf_output

        # the output of the abf is obtained after the residual
        # output is convolved to have `out_channels` channels
        abf_output = self.conv_to_out_channel(residual_output)

        return abf_output, residual_output


class RLF_for_Resnet_with_ABF_without_mid_channels(nn.Module):
    def __init__(self, student, abf_to_use):
        super(RLF_for_Resnet_with_ABF_without_mid_channels, self).__init__()

        self.student = student

        in_channels = [16, 32, 64, 64]
        out_channels = [16, 32, 64, 64]
        # channels of the residual coming from the previous (deeper) ABF,
        # i.e. the out_channel of the ABF one step deeper; the deepest ABF
        # receives no residual, so its entry is unused
        pabf_channels = [32, 64, 64, 64]

        self.shapes = [1, 8, 16, 32, 32]

        ABFs = nn.ModuleList()

        for idx, in_channel in enumerate(in_channels):
            ABFs.append(abf_to_use(in_channel, out_channels[idx], pabf_channels[idx]))

        self.ABFs = ABFs[::-1]
        self.to('cuda')

    def forward(self, x):
        student_features = self.student(x, is_feat=True)

        student_preds = student_features[1]
        student_features = student_features[0][::-1]

        results = []

        abf_output, residual_output = self.ABFs[0](
            student_features[0], None, self.shapes[0])

        results.append(abf_output)

        for features, abf, shape in zip(student_features[1:], self.ABFs[1:], self.shapes[1:]):
            # here we use a recursive technique to obtain all the ABF
            # outputs and store them in a list
            abf_output, residual_output = abf(features, residual_output, shape)
            results.insert(0, abf_output)

        return results, student_preds


if __name__ == '__main__':
    train(params, hcl, ABF_without_attention_maps, RLF_for_Resnet,
          log_file_suffix='abf_without_attention_maps')
    params['lr'] = 0.1
    train(params, hcl, ABF_without_mid_channels, RLF_for_Resnet_with_ABF_without_mid_channels,
          log_file_suffix='abf_without_mid_channels')
--------------------------------------------------------------------------------
/utils/resnets_for_cifar.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

'''Resnet for cifar dataset.
Ported from
https://github.com/facebook/fb.resnet.torch
and
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
(c) YANG, Wei
'''
import torch.nn as nn
import torch.nn.functional as F


__all__ = ['ResNet']


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False):
        super(BasicBlock, self).__init__()
        self.is_last = is_last
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        preact = out
        out = F.relu(out)
        if self.is_last:
            return out, preact
        else:
            return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False):
        super(Bottleneck, self).__init__()
        self.is_last = is_last
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        preact = out
        out = F.relu(out)
        if self.is_last:
            return out, preact
        else:
            return out

class ResNet(nn.Module):

    def __init__(self, depth, num_filters, block_name='BasicBlock', num_classes=10):
        super(ResNet, self).__init__()
        # Model type specifies number of layers for CIFAR-10 model
        if block_name.lower() == 'basicblock':
            assert (depth - 2) % 6 == 0, \
                'When using basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
            n = (depth - 2) // 6
            block = BasicBlock
        elif block_name.lower() == 'bottleneck':
            assert (depth - 2) % 9 == 0, \
                'When using bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
            n = (depth - 2) // 9
            block = Bottleneck
        else:
            raise ValueError('block_name should be Basicblock or Bottleneck')

        self.inplanes = num_filters[0]
        self.conv1 = nn.Conv2d(3, num_filters[0], kernel_size=3, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(num_filters[0])
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, num_filters[1], n)
        self.layer2 = self._make_layer(block, num_filters[2], n, stride=2)
        self.layer3 = self._make_layer(block, num_filters[3], n, stride=2)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(num_filters[3] * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        self.to('cuda')

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = list([])
        layers.append(block(self.inplanes, planes, stride,
                            downsample, is_last=(blocks == 1)))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes,
                                is_last=(i == blocks - 1)))

        return nn.Sequential(*layers)

    def get_feat_modules(self):
        feat_m = nn.ModuleList([])
        feat_m.append(self.conv1)
        feat_m.append(self.bn1)
        feat_m.append(self.relu)
        feat_m.append(self.layer1)
        feat_m.append(self.layer2)
        feat_m.append(self.layer3)
        return feat_m

    def get_bn_before_relu(self):
        if isinstance(self.layer1[0], Bottleneck):
            bn1 = self.layer1[-1].bn3
            bn2 = self.layer2[-1].bn3
            bn3 = self.layer3[-1].bn3
        elif isinstance(self.layer1[0], BasicBlock):
            bn1 = self.layer1[-1].bn2
            bn2 = self.layer2[-1].bn2
            bn3 = self.layer3[-1].bn2
        else:
            raise NotImplementedError('ResNet unknown block error !!!')

        return [bn1, bn2, bn3]

    def forward(self, x, is_feat=False, preact=False):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)  # 32x32
        f0 = x

        x, f1_pre = self.layer1(x)  # 32x32
        f1 = x
        x, f2_pre = self.layer2(x)  # 16x16
        f2 = x
        x, f3_pre = self.layer3(x)  # 8x8
        f3 = x

        x = self.avgpool(x)
        f4 = x
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        if is_feat:
            if preact:
                return [f0, f1_pre, f2_pre, f3_pre, f4], x
            else:
                return [f0, f1, f2, f3, f4], x
        else:
            return x
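
# with a 32x32 CIFAR input, is_feat=True yields the feature list
# [f0 (32x32), f1 (32x32), f2 (16x16), f3 (8x8), f4 (1x1 after avgpool)]
# along with the logits; train.py drops f0 on the teacher side and matches
# the remaining maps against the student's ABF outputs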
--------------------------------------------------------------------------------
/experimental/table7_experiments.py:
--------------------------------------------------------------------------------
import sys
sys.path.insert(0, '..')

import copy
import torch
from torch import nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np
import time

from params import params
from data import get_dataloaders
from teachers import get_teacher
from students import get_student
from abf_experiments import ABF_without_attention_maps
from framework import RLF_for_Resnet
from utils.misc import AverageMeter, format_time, Logger
from test import test


def train_general(params, framework, kd_loss, log_file_suffix=''):
    cudnn.deterministic = True
    cudnn.benchmark = False
    if params["seed"] == 0:
        params["seed"] = np.random.randint(1000)
    torch.manual_seed(params["seed"])
    np.random.seed(params["seed"])
    torch.cuda.manual_seed(params["seed"])

    train_loader, test_loader = get_dataloaders(
        params["dataset"], params["batch_size"])
    if params["dataset"] == 'cifar10':
        num_classes = 10
    elif params["dataset"] == 'cifar100':
        num_classes = 100

    teacher = get_teacher(params["teacher"], num_classes=num_classes)
    student = get_student(params["student"], num_classes=num_classes)

    # load teacher weights from pretrained model
    # weight = torch.load(params["teacher_weight_path"])
    # teacher.load_state_dict(weight)
    # for p in teacher.parameters():
    #     p.requires_grad = False
    teacher.to(torch.device('cuda:0'))

    base_loss = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        framework.parameters(),
        lr=params["lr"],
        momentum=0.9,
        nesterov=True,
        weight_decay=params["weight_decay"]
    )

    log_name = f"{params['dataset']}_{params['student']}_{params['teacher']}_{log_file_suffix}"
    logger = Logger(params=params, filename='logs/' + log_name + '.txt')
    best_accuracy = 0.0
    best_state = copy.deepcopy(framework.state_dict())

    start_time = time.time()
    print("starting training with the following params:")
    print(params)
    print()

    for epoch in range(params["num_epochs"]):
        loss_avg = {
            'kd_loss': AverageMeter(),
            'base_loss': AverageMeter()
        }
        correct_preds = 0.0
        total_images = 0.0

        for i, (X, y) in enumerate(train_loader):
            X, y = X.cuda(), y.cuda()

            losses = {"kd_loss": 0, "base_loss": 0}

            # getting student and teacher features
            # authors use features obtained **after** activation for the student
            # (see rlf implementation in framework.py)
            student_features, student_preds = framework(X)
            # authors use features obtained **before** activation for the teacher (preact=True)
            teacher_features, teacher_preds = teacher(X, is_feat=True, preact=True)

            # authors start from the second teacher feature
            if isinstance(framework, RLF_for_Resnet):
                teacher_features = teacher_features[1:]

            # calculating review kd loss
            for sf, tf in zip(student_features, teacher_features):
                losses['kd_loss'] += kd_loss(sf, tf)

            # calculating cross entropy loss
            losses['base_loss'] = base_loss(student_preds, y)

            loss = losses['kd_loss'] * params['kd_loss_weight']
            loss += losses['base_loss']

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            for key in losses:
                loss_avg[key].update(losses[key])

            # calculate running average of accuracy
            student_preds = torch.max(student_preds.data, 1)[1]
            total_images += y.size(0)
            correct_preds += (student_preds == y.data).sum().item()
            train_accuracy = correct_preds / total_images

        # calculating test accuracy and storing best results
        test_accuracy = test(framework, test_loader)
        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            best_state = copy.deepcopy(framework.state_dict())

        # decaying lr at scheduled steps
        if epoch in params['lr_decay_steps']:
            params['lr'] *= params["lr_decay_rate"]
            for param_group in optimizer.param_groups:
                param_group['lr'] = params['lr']

        # logging results
        loss_avg = {k: loss_avg[k].val for k in loss_avg}
        log_row = {
            'epoch': str(epoch),
            'train_acc': '%.2f' % (train_accuracy*100),
            'test_acc': '%.2f' % (test_accuracy*100),
            'best_acc': '%.2f' % (best_accuracy*100),
            'lr': '%.5f' % (params['lr']),
            'loss': '%.5f' % (sum(loss_avg.values())),
            'kd_loss': '%.5f' % loss_avg['kd_loss'],
            'base_loss': '%.5f' % loss_avg['base_loss'],
            'time': format_time(time.time()-start_time),
            'eta': format_time((time.time()-start_time)/(epoch+1)*(params["num_epochs"]-epoch-1)),
        }
        print(log_row)
        logger.writerow(log_row)

    torch.save(best_state, 'pretrained/' + log_name + '.pt')
    logger.close()

class GeneralFusionModule(nn.Module):
    def __init__(self, in_channels, out_channel, out_shape):
        super(GeneralFusionModule, self).__init__()
        self.out_shape = out_shape

        conv_layers = []
        for ch in in_channels:
            conv = nn.Sequential(
                nn.Conv2d(ch, out_channel, kernel_size=1),
                nn.BatchNorm2d(out_channel)
            ).to(torch.device('cuda:0'))
            nn.init.kaiming_uniform_(conv[0].weight, a=1)
            conv_layers.append(conv)

        # wrap in a ModuleList so the conv parameters are registered
        # (and hence picked up by the optimizer)
        self.conv_layers = nn.ModuleList(conv_layers)

    def forward(self, student_features):
        for i, sf in enumerate(student_features):
            sf = self.conv_layers[i](sf)
            student_features[i] = F.interpolate(
                sf, size=(self.out_shape, self.out_shape), mode='nearest')

        output = student_features[0]
        for sf in student_features[1:]:
            output += sf

        return output
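
# GeneralFusionModule is the attention-free stand-in for ABF used by the
# frameworks below: every incoming feature map is projected with a 1x1 conv,
# resized to out_shape, and the results are summed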

class BaselineFramework(nn.Module):
    def __init__(self, student):
        super(BaselineFramework, self).__init__()

        in_channels = [16, 16, 32, 64, 64]
        out_channels = [16, 16, 32, 64, 64]
        out_shapes = [32, 32, 16, 8, 1]

        fusion_modules = nn.ModuleList()

        for i in range(len(in_channels)):
            fusion_modules.append(
                GeneralFusionModule([in_channels[i], ], out_channels[i], out_shapes[i]))

        self.fusion_modules = fusion_modules

        self.student = student
        self.to('cuda')

    def forward(self, x):
        student_features = self.student(x, is_feat=True)

        student_preds = student_features[1]
        student_features = student_features[0]

        results = []

        for i, fm in enumerate(self.fusion_modules):
            results.append(fm([student_features[i], ]))

        return results, student_preds


def train_baseline(params, student):
    framework = BaselineFramework(student)
    return train_general(params, framework, F.mse_loss, log_file_suffix='table7_1')


class RMFramework(nn.Module):
    def __init__(self, student):
        super(RMFramework, self).__init__()

        self.student = student

        in_channels = [16, 16, 32, 64, 64]
        out_channels = [16, 16, 32, 64, 64]
        out_shapes = [32, 32, 16, 8, 1]

        fusion_modules = nn.ModuleList()

        for i in range(len(in_channels)):
            fusion_modules.append(
                GeneralFusionModule(in_channels[i:], out_channels[i], out_shapes[i]))

        self.fusion_modules = fusion_modules
        self.to('cuda')

    def forward(self, x):
        student_features = self.student(x, is_feat=True)

        student_preds = student_features[1]
        student_features = student_features[0]

        results = []

        for i, fm in enumerate(self.fusion_modules):
            results.append(fm(student_features[i:]))

        return results, student_preds


def train_rm_framework(params, student):
    framework = RMFramework(student)
    return train_general(params, framework, F.mse_loss, log_file_suffix='table7_2')


def train_rlf_framework(params, student):
    framework = RLF_for_Resnet(student, ABF_without_attention_maps)
    return train_general(params, framework, F.mse_loss, log_file_suffix='table7_3')


if __name__ == '__main__':
    # build a fresh student for every run so that weights trained in one
    # experiment do not leak into the next
    train_rlf_framework(params, get_student(params["student"], num_classes=10))
    params['lr'] = 0.1
    train_rm_framework(params, get_student(params["student"], num_classes=10))
    params['lr'] = 0.1
    train_baseline(params, get_student(params["student"], num_classes=10))
--------------------------------------------------------------------------------