├── __init__.py
├── requirements.txt
├── review-mechanism.jpg
├── params.py
├── test.py
├── teachers.py
├── students.py
├── data.py
├── utils
├── misc.py
└── resnets_for_cifar.py
├── README.md
├── train.py
├── experimental
├── hcl_experiments.py
├── abf_experiments.py
└── table7_experiments.py
└── framework.py
/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.22.0
2 | torch==1.10.2
3 | torchvision==0.11.3
4 |
--------------------------------------------------------------------------------
/review-mechanism.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DevPranjal/reproduction-review-kd/HEAD/review-mechanism.jpg
--------------------------------------------------------------------------------
/params.py:
--------------------------------------------------------------------------------
# Default settings read by train.py and the scripts under experimental/.
params = {
    # dataset / model selection (names are resolved by data.py,
    # students.py and teachers.py)
    "dataset": "cifar10",
    "student": "resnet20",
    "teacher": "resnet56",
    # plain string: there is nothing to interpolate, so no f-prefix
    "teacher_weight_path": "./pretrained/resnet56.pt",

    # optimisation schedule
    "batch_size": 64,
    "num_epochs": 20,
    "lr": 0.1,
    # after finishing an epoch listed here, train.py multiplies the
    # learning rate by `lr_decay_rate`
    "lr_decay_steps": [12, 17],
    "lr_decay_rate": 0.1,
    "weight_decay": 5e-4,
    # NOTE(review): "args" is not read by any code visible in this file
    # set -- confirm whether it is still needed
    "args": 0,

    # multiplier applied to the summed review (HCL) loss in train.py
    "kd_loss_weight": 0.6,

    # a value of 0 makes train.py draw a random seed
    "seed": 0,
}
19 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
def test(net, test_loader):
    """Return the top-1 accuracy (a float in [0, 1]) of ``net`` on ``test_loader``.

    Args:
        net: module whose forward pass returns a ``(features, logits)`` pair.
        test_loader: iterable yielding ``(images, labels)`` batches.

    Returns:
        ``correct / total`` over all batches, or ``0.0`` when the loader
        yields no samples.

    The model is switched to eval mode for the duration of the call and is
    always put back into train mode afterwards, even if evaluation raises.
    """
    # Evaluate on the device the model already lives on instead of assuming
    # CUDA; a parameter-less model falls back to CUDA when it is available.
    first_param = next(net.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    correct_preds = 0
    total_images = 0

    net.eval()
    try:
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                _, logits = net(images)
                preds = logits.argmax(dim=1)
                total_images += labels.size(0)
                correct_preds += (preds == labels).sum().item()
    finally:
        net.train()

    if total_images == 0:
        return 0.0
    return correct_preds / total_images


# The helper's name matches pytest's default ``test*`` pattern; mark it so a
# test module importing it does not try to collect it as a test case.
test.__test__ = False
21 |
--------------------------------------------------------------------------------
/teachers.py:
--------------------------------------------------------------------------------
1 | from utils.resnets_for_cifar import ResNet
2 |
3 |
def resnet44(**kwargs):
    """Build a ``ResNet`` with depth 44, the ``[16, 16, 32, 64]``
    configuration list and the ``'basicblock'`` block type.

    Extra keyword arguments are forwarded to ``ResNet`` unchanged.
    """
    depth = 44
    config = [16, 16, 32, 64]
    return ResNet(depth, config, 'basicblock', **kwargs)
6 |
7 |
def resnet56(**kwargs):
    """Build a ``ResNet`` with depth 56, the ``[16, 16, 32, 64]``
    configuration list and the ``'basicblock'`` block type.

    Extra keyword arguments are forwarded to ``ResNet`` unchanged.
    """
    depth = 56
    config = [16, 16, 32, 64]
    return ResNet(depth, config, 'basicblock', **kwargs)
10 |
11 |
def resnet110(**kwargs):
    """Build a ``ResNet`` with depth 110, the ``[16, 16, 32, 64]``
    configuration list and the ``'basicblock'`` block type.

    Extra keyword arguments are forwarded to ``ResNet`` unchanged.
    """
    depth = 110
    config = [16, 16, 32, 64]
    return ResNet(depth, config, 'basicblock', **kwargs)
14 |
15 |
def get_teacher(name, **kwargs):
    """Return the teacher network registered under ``name``.

    Args:
        name: one of ``'resnet44'``, ``'resnet56'`` or ``'resnet110'``.
        **kwargs: forwarded to the selected factory function.

    Raises:
        ValueError: if ``name`` is not a known teacher. Previously an unknown
            name silently produced ``None``, which only failed later when the
            caller tried to use the model.
    """
    if name == 'resnet44':
        return resnet44(**kwargs)
    if name == 'resnet56':
        return resnet56(**kwargs)
    if name == 'resnet110':
        return resnet110(**kwargs)
    raise ValueError(
        f"unknown teacher {name!r}; expected one of "
        "'resnet44', 'resnet56', 'resnet110'"
    )
23 |
--------------------------------------------------------------------------------
/students.py:
--------------------------------------------------------------------------------
1 | from utils.resnets_for_cifar import ResNet
2 |
3 |
def resnet8(**kwargs):
    """Build a ``ResNet`` with depth 8, the ``[16, 16, 32, 64]``
    configuration list and the ``'basicblock'`` block type.

    Extra keyword arguments are forwarded to ``ResNet`` unchanged.
    """
    depth = 8
    config = [16, 16, 32, 64]
    return ResNet(depth, config, 'basicblock', **kwargs)
6 |
7 |
def resnet14(**kwargs):
    """Build a ``ResNet`` with depth 14, the ``[16, 16, 32, 64]``
    configuration list and the ``'basicblock'`` block type.

    Extra keyword arguments are forwarded to ``ResNet`` unchanged.
    """
    depth = 14
    config = [16, 16, 32, 64]
    return ResNet(depth, config, 'basicblock', **kwargs)
10 |
11 |
def resnet20(**kwargs):
    """Build a ``ResNet`` with depth 20, the ``[16, 16, 32, 64]``
    configuration list and the ``'basicblock'`` block type.

    Extra keyword arguments are forwarded to ``ResNet`` unchanged.
    """
    depth = 20
    config = [16, 16, 32, 64]
    return ResNet(depth, config, 'basicblock', **kwargs)
14 |
15 |
def resnet32(**kwargs):
    """Build a ``ResNet`` with depth 32, the ``[16, 16, 32, 64]``
    configuration list and the ``'basicblock'`` block type.

    Extra keyword arguments are forwarded to ``ResNet`` unchanged.
    """
    depth = 32
    config = [16, 16, 32, 64]
    return ResNet(depth, config, 'basicblock', **kwargs)
18 |
19 |
def get_student(name, **kwargs):
    """Return the student network registered under ``name``.

    Args:
        name: one of ``'resnet8'``, ``'resnet14'``, ``'resnet20'`` or
            ``'resnet32'``.
        **kwargs: forwarded to the selected factory function.

    Raises:
        ValueError: if ``name`` is not a known student. Previously an unknown
            name silently produced ``None``, which only failed later when the
            caller tried to use the model.
    """
    if name == 'resnet8':
        return resnet8(**kwargs)
    if name == 'resnet14':
        return resnet14(**kwargs)
    if name == 'resnet20':
        return resnet20(**kwargs)
    if name == 'resnet32':
        return resnet32(**kwargs)
    raise ValueError(
        f"unknown student {name!r}; expected one of "
        "'resnet8', 'resnet14', 'resnet20', 'resnet32'"
    )
29 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import datasets, transforms
3 |
4 |
# Per-channel normalisation. The statistics are written on the 0-255 scale
# and divided by 255 before being handed to ``Normalize``.
normalize = transforms.Normalize(
    mean=[125.3 / 255.0, 123.0 / 255.0, 113.9 / 255.0],
    std=[63.0 / 255.0, 62.1 / 255.0, 66.7 / 255.0],
)

# Training pipeline: random 32x32 crop (padding 4) and random horizontal
# flip, followed by tensor conversion and normalisation.
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
)

# Evaluation pipeline: tensor conversion and normalisation only.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        normalize,
    ]
)
21 |
22 |
def get_dataloaders(dataset, batch_size):
    """Create the train and test ``DataLoader`` objects for a CIFAR dataset.

    Args:
        dataset: ``'cifar10'`` or ``'cifar100'``.
        batch_size: batch size used for both loaders.

    Returns:
        ``(train_loader, test_loader)``. The training loader shuffles and
        applies ``train_transform``; the test loader keeps the dataset order
        and applies ``test_transform``. Data is stored under ``data/`` and
        downloaded when missing (``download=True``).

    Raises:
        ValueError: if ``dataset`` is not a supported name. Previously this
            surfaced as an ``UnboundLocalError`` further down.
    """
    if dataset == 'cifar10':
        dataset_cls = datasets.CIFAR10
    elif dataset == 'cifar100':
        dataset_cls = datasets.CIFAR100
    else:
        raise ValueError(
            f"unsupported dataset {dataset!r}; expected 'cifar10' or 'cifar100'")

    train_dataset = dataset_cls(
        root='data/', train=True, transform=train_transform, download=True)
    test_dataset = dataset_cls(
        root='data/', train=False, transform=test_transform, download=True)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size,
                                               shuffle=True, pin_memory=True, num_workers=2)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size,
                                              shuffle=False, pin_memory=True, num_workers=1)

    return train_loader, test_loader
41 |
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
class AverageMeter(object):
    """Track the latest value and the running (weighted) mean of a quantity.

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262

    Attributes set by ``reset``/``update``:
        val: most recent value passed to ``update``.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated ``n`` over all updates.
        avg: ``sum / count`` after the latest update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
20 |
21 |
def format_time(seconds):
    """Format a duration given in seconds as a short human-readable string.

    The duration is split into days (``D``), hours (``h``), minutes (``m``),
    seconds (``s``) and milliseconds (``ms``); only the first two non-zero
    units are kept, e.g. ``3661`` -> ``'1h1m'``. ``'0ms'`` is returned when
    every unit is zero.
    """
    remaining = seconds

    days = int(remaining / 3600/24)
    remaining = remaining - days*3600*24
    hours = int(remaining / 3600)
    remaining = remaining - hours*3600
    minutes = int(remaining / 60)
    remaining = remaining - minutes*60
    whole_seconds = int(remaining)
    remaining = remaining - whole_seconds
    millis = int(remaining*1000)

    units = (
        (days, 'D'),
        (hours, 'h'),
        (minutes, 'm'),
        (whole_seconds, 's'),
        (millis, 'ms'),
    )
    # keep the two most significant units that are actually present
    shown = [str(amount) + suffix for amount, suffix in units if amount > 0]
    return ''.join(shown[:2]) or '0ms'
53 |
54 |
class Logger():
    """Minimal text logger: a header with the run configuration, then one
    line per call to ``writerow``.

    The file is opened in write mode (truncating any previous content) and
    flushed after every write. Instances can also be used as context
    managers, in which case the file is closed on exit.
    """

    def __init__(self, params, filename='log.txt'):
        """Open ``filename`` and write ``params`` as ``key: value`` lines.

        Args:
            params: mapping with the run configuration.
            filename: path of the log file; missing parent directories are
                created so that paths such as ``logs/run.txt`` work on a
                fresh checkout.
        """
        import os

        self.filename = filename
        parent = os.path.dirname(filename)
        if parent:
            os.makedirs(parent, exist_ok=True)
        self.file = open(filename, 'w')
        # Write model configuration at top of file
        for key, value in params.items():
            self.file.write(f'{key}: {value}\n')
        self.file.flush()

    def writerow(self, row):
        """Append one line of the form ``key: value key: value ...``.

        Values are formatted with ``str`` semantics, so numbers can be
        passed directly (previously only ``str`` values were accepted).
        """
        self.file.write(''.join(f'{k}: {row[k]} ' for k in row))
        self.file.write('\n')
        self.file.flush()

    def close(self):
        """Close the underlying file."""
        self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False
72 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ML Reproducibility Challenge 2021: Fall Edition
2 |
3 | ### Submission for [Distilling Knowledge via Knowledge Review](https://arxiv.org/abs/2104.09044) published in CVPR 2021
4 |
5 |
6 |
7 | ---
8 |
9 | This effort aims to reproduce the results of the experiments and analyse the robustness of the review framework for knowledge distillation introduced in the original paper. We verify that the reported improvement in test accuracy holds consistently across student models, and study the effectiveness of the novel modules introduced by the authors by conducting ablation studies and new experiments.
10 |
11 | ---
12 |
13 | ### Setting up environment
14 |
15 | ```bash
16 | conda create -n reviewkd python=3.8
17 | conda activate reviewkd
18 |
19 | git clone https://github.com/DevPranjal/ml-repro-2021.git
20 | cd ml-repro-2021
21 |
22 | pip install -r requirements.txt
23 | ```
24 |
25 | ### Training baseline teachers
26 |
27 | To train the teacher model, we use the code written by the authors as follows:
28 |
29 | ```bash
30 | git clone https://github.com/dvlab-research/ReviewKD
31 | cd ReviewKD/CIFAR100
32 | python train.py --model resnet56
33 | ```
34 |
35 | ### Training student via review mechanism
36 |
37 | To train the student model, we have designed `params.py` for all the settings that can be tuned. After setting the desired values for each key, run the following within the `ml-repro-2021` directory
38 |
39 | ```bash
40 | python train.py
41 | ```
42 |
43 | ### Performing ablation studies and experiments
44 |
45 | The ablation studies and experiments have been organized and implemented in `experimental/`. To execute any of them, run the following command:
46 |
47 | ```bash
48 | cd experimental
49 | python table7_experiments.py
50 | ```
51 |
52 | ### Reproduction Results
53 |
54 | #### Classification results when student and teacher have architectures of the same style
55 |
56 | | Student               | ResNet20 | ResNet32  | ResNet8x4 | WRN16-2 | WRN40-1 |
57 | | --------------------- | -------- | --------- | --------- | ------- | ------- |
58 | | Teacher               | ResNet56 | ResNet110 | ResNet32x4 | WRN40-2 | WRN40-2 |
59 | | Review KD (Paper) | 71.89 | 73.89 | 75.63 | 76.12 | 75.09 |
60 | | Review KD (Ours) | 71.79 | 73.61 | 76.02 | 76.27 | 75.21 |
61 | | Review KD Loss Weight | 0.7 | 1.0 | 5.0 | 5.0 | 5.0 |
62 |
63 | #### Classification results when student and teacher have architectures of different styles
64 |
65 | | Student | ShuffleNetV1 | ShuffleNetV1 | ShuffleNetV2 |
66 | | --------------------- | ------------ | ------------ | ------------ |
67 | | Teacher               | ResNet32x4   | WRN40-2      | ResNet32x4   |
68 | | Review KD (Paper) | 77.45 | 77.14 | 77.78 |
69 | | Review KD (Ours) | 76.94 | 77.44 | 77.86 |
70 | | Review KD Loss Weight | 5.0 | 5.0 | 8.0 |
71 |
72 | #### Adding architectural components one by one
73 |
74 | RM - Review Mechanism
75 |
76 | RLF - Residual Learning Framework
77 |
78 | ABF - Attention Based Fusion
79 |
80 | HCL - Hierarchical Context Loss
81 |
82 | | RM | RLF | ABF | HCL | Test Accuracy |
83 | | --- | --- | --- | --- | ------------- |
84 | | | | | | 69.50 |
85 | | Y | | | | 69.53 |
86 | | Y | Y | | | 69.92 |
87 | | Y | Y | Y | | 71.28 |
88 | | Y | Y | | Y | 71.51 |
89 | | Y | Y | Y | Y | 71.79 |
90 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.backends.cudnn as cudnn
4 | import numpy as np
5 | import time
6 |
7 | from params import params
8 | from data import get_dataloaders
9 | from teachers import get_teacher
10 | from students import get_student
11 | from framework import RLF_for_Resnet, ABF, hcl
12 | from utils.misc import AverageMeter, format_time, Logger
13 | from test import test
14 |
15 |
def train(params, hcl, abf, RLF, log_file_suffix=''):
    """Train a student with review-based distillation from a frozen teacher.

    Args:
        params: settings dictionary (see params.py). ``params["seed"]`` is
            overwritten with the drawn value when it is 0; no other key is
            modified.
        hcl: callable ``(student_feature, teacher_feature) -> loss`` applied
            to every student/teacher feature pair.
        abf: fusion-module class handed to ``RLF`` as ``abf_to_use``.
        RLF: class wrapping the student; called as ``RLF(student, abf_to_use=abf)``
            and expected to return ``(features, logits)`` from its forward pass.
        log_file_suffix: appended to the log / checkpoint file name.

    Side effects:
        Writes ``logs/<dataset>_<student>_<teacher>_<suffix>.txt`` and saves
        the state dict of the epoch with the highest test accuracy to
        ``pretrained/logs/<...>.pt`` (directories are created if missing).

    Raises:
        ValueError: if ``params["dataset"]`` is not ``'cifar10'``/``'cifar100'``.
    """
    import copy
    import os

    cudnn.deterministic = True
    cudnn.benchmark = False
    # seed 0 means "draw one"; it is stored back so it gets printed and logged
    if params["seed"] == 0:
        params["seed"] = np.random.randint(1000)
    torch.manual_seed(params["seed"])
    np.random.seed(params["seed"])
    torch.cuda.manual_seed(params["seed"])

    if params["dataset"] == 'cifar10':
        num_classes = 10
    elif params["dataset"] == 'cifar100':
        num_classes = 100
    else:
        raise ValueError(
            f"unsupported dataset {params['dataset']!r}; "
            "expected 'cifar10' or 'cifar100'")
    train_loader, test_loader = get_dataloaders(params["dataset"], params["batch_size"])

    teacher = get_teacher(params["teacher"], num_classes=num_classes)
    student = get_student(params["student"], num_classes=num_classes)

    # build the framework for student to be trained
    rlf = RLF(student, abf_to_use=abf)  # rlf => residual learning framework
    # load teacher weights from pretrained model
    # NOTE(review): torch.load unpickles the file -- only point
    # `teacher_weight_path` at checkpoints you trust.
    weight = torch.load(params["teacher_weight_path"])
    teacher.load_state_dict(weight)
    for p in teacher.parameters():
        p.requires_grad = False
    teacher.to(torch.device('cuda:0'))
    # NOTE(review): the teacher is never switched to eval mode here; left
    # unchanged so results stay comparable -- confirm this is intended.

    # Work on a local copy of the learning rate so the caller's `params`
    # keeps its configured value across consecutive runs.
    lr = params["lr"]

    base_loss = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        rlf.parameters(),
        lr=lr,
        momentum=0.9,
        nesterov=True,
        weight_decay=params["weight_decay"]
    )

    train_log_file = (
        f"logs/{params['dataset']}_{params['student']}_{params['teacher']}_{log_file_suffix}"
    )
    log_path = train_log_file + '.txt'
    checkpoint_path = 'pretrained/' + train_log_file + '.pt'
    for out_path in (log_path, checkpoint_path):
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

    logger = Logger(params=params, filename=log_path)
    best_accuracy = 0.0
    # snapshot of the best weights; None until some epoch beats 0% accuracy
    best_state = None

    start_time = time.time()
    print("starting training with the following params:")
    print(params)
    print()

    try:
        for epoch in range(params["num_epochs"]):
            loss_avg = {
                'kd_loss': AverageMeter(),
                'base_loss': AverageMeter()
            }
            correct_preds = 0.0
            total_images = 0.0

            for X, y in train_loader:
                X, y = X.cuda(), y.cuda()

                losses = {"kd_loss": 0, "base_loss": 0}

                # getting student and teacher features
                student_features, student_preds = rlf(X)
                # the teacher is frozen, so no autograd bookkeeping is needed
                with torch.no_grad():
                    teacher_features, teacher_preds = teacher(X, is_feat=True, preact=True)

                teacher_features = teacher_features[1:]

                # calculating review kd loss
                for sf, tf in zip(student_features, teacher_features):
                    losses['kd_loss'] += hcl(sf, tf)

                # calculating cross entropy loss
                losses['base_loss'] = base_loss(student_preds, y)

                loss = losses['kd_loss'] * params['kd_loss_weight'] + losses['base_loss']

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # store plain floats: keeping the tensors would hold on to
                # their autograd history for the whole epoch
                batch_size = y.size(0)
                for key in losses:
                    loss_avg[key].update(float(losses[key]), batch_size)

                # running count for the epoch's training accuracy
                predicted = torch.max(student_preds.data, 1)[1]
                total_images += batch_size
                correct_preds += (predicted == y.data).sum().item()

            train_accuracy = correct_preds / total_images if total_images else 0.0

            # calculating test accuracy and storing best results
            test_accuracy = test(rlf, test_loader)
            if test_accuracy > best_accuracy:
                best_accuracy = test_accuracy
                # deep copy: `rlf` keeps training, an alias would end up
                # holding the final-epoch weights instead of the best ones
                best_state = copy.deepcopy(rlf.state_dict())

            # decaying lr at scheduled steps
            if epoch in params['lr_decay_steps']:
                lr *= params["lr_decay_rate"]
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

            # logging results (losses are averaged over the epoch)
            epoch_losses = {k: meter.avg for k, meter in loss_avg.items()}
            log_row = {
                'epoch': str(epoch),
                'train_acc': '%.2f' % (train_accuracy*100),
                'test_acc': '%.2f' % (test_accuracy*100),
                'best_acc': '%.2f' % (best_accuracy*100),
                'lr': '%.5f' % lr,
                'loss': '%.5f' % (sum(epoch_losses.values())),
                'kd_loss': '%.5f' % epoch_losses['kd_loss'],
                'base_loss': '%.5f' % epoch_losses['base_loss'],
                'time': format_time(time.time()-start_time),
                'eta': format_time((time.time()-start_time)/(epoch+1)*(params["num_epochs"]-epoch-1)),
            }
            print(log_row)
            logger.writerow(log_row)

        if best_state is None:
            # no epoch improved on 0% accuracy: fall back to the final weights
            best_state = rlf.state_dict()
        torch.save(best_state, checkpoint_path)
    finally:
        logger.close()
142 |
143 |
if __name__ == '__main__':
    # Default run: standard HCL loss, ABF fusion and the ResNet residual
    # learning framework, configured by params.py.
    train(params, hcl=hcl, abf=ABF, RLF=RLF_for_Resnet)
--------------------------------------------------------------------------------
/experimental/hcl_experiments.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.insert(0, '..')
3 |
4 |
5 | from train import train
6 | from params import params
7 | from hcl_experiments import *
8 | from abf_experiments import *
9 | from framework import ABF, hcl, RLF_for_Resnet
10 | import torch
11 | import torch.nn.functional as F
12 | from torch import nn
13 |
14 |
def hcl_level_1(student_features, teacher_features):
    """HCL variant with pooling levels ``[h, 2, 1]`` and weights
    ``[1.0, 0.5, 0.25]``.

    Both 4-D feature maps are average-pooled to ``(level, level)`` and
    compared with MSE; levels larger than the student height ``h`` are
    skipped. The weighted sum is divided by the sum of all weights
    (including those of skipped levels).
    """
    _, _, h, _ = student_features.shape

    schedule = ((h, 1.0), (2, 0.5), (1, 0.25))
    total_weight = sum(weight for _, weight in schedule)

    loss = 0.0
    for size, weight in schedule:
        if size <= h:
            pooled_student = F.adaptive_avg_pool2d(student_features, (size, size))
            pooled_teacher = F.adaptive_avg_pool2d(teacher_features, (size, size))
            loss += F.mse_loss(pooled_student, pooled_teacher) * weight

    return loss / total_weight
34 |
35 |
def hcl_level_2(student_features, teacher_features):
    """HCL variant with pooling levels ``[4, 1]`` and weights ``[1.0, 0.5]``.

    Both 4-D feature maps are average-pooled to ``(level, level)`` and
    compared with MSE; levels larger than the student height ``h`` are
    skipped. The weighted sum is divided by the sum of all weights
    (including those of skipped levels).
    """
    _, _, h, _ = student_features.shape

    schedule = ((4, 1.0), (1, 0.5))
    total_weight = sum(weight for _, weight in schedule)

    loss = 0.0
    for size, weight in schedule:
        if size <= h:
            pooled_student = F.adaptive_avg_pool2d(student_features, (size, size))
            pooled_teacher = F.adaptive_avg_pool2d(teacher_features, (size, size))
            loss += F.mse_loss(pooled_student, pooled_teacher) * weight

    return loss / total_weight
55 |
56 |
def hcl_level_3(student_features, teacher_features):
    """HCL variant with pooling levels ``[h, h//2, h//4]`` and weights
    ``[1.0, 0.5, 0.25]``.

    Both 4-D feature maps are average-pooled to ``(level, level)`` and
    compared with MSE. The weighted sum is divided by the sum of all
    weights (including those of skipped levels).

    Levels that are larger than the student height ``h`` or smaller than 1
    are skipped. The lower bound matters for small maps: with ``h < 4`` the
    integer divisions yield 0, which is not a usable pooling size.
    """
    _, _, h, _ = student_features.shape

    levels = [h, h//2, h//4]
    level_weight = [1.0, 0.5, 0.25]
    total_weight = sum(level_weight)

    loss = 0.0
    for lvl, lvl_weight in zip(levels, level_weight):
        if lvl > h or lvl < 1:
            continue

        lvl_sf = F.adaptive_avg_pool2d(student_features, (lvl, lvl))
        lvl_tf = F.adaptive_avg_pool2d(teacher_features, (lvl, lvl))

        loss += F.mse_loss(lvl_sf, lvl_tf) * lvl_weight

    return loss / total_weight
76 |
77 |
def hcl_level_4(student_features, teacher_features):
    """HCL variant with pooling levels ``[h, h-1, h-2, h-3]`` and weights
    ``[1.0, 0.5, 0.25, 0.125]``.

    Both 4-D feature maps are average-pooled to ``(level, level)`` and
    compared with MSE. The weighted sum is divided by the sum of all
    weights (including those of skipped levels).

    Levels that are larger than the student height ``h`` or smaller than 1
    are skipped. The lower bound matters for small maps: with ``h <= 3``
    the subtractions yield zero or negative sizes, which cannot be used as
    a pooling target.
    """
    _, _, h, _ = student_features.shape

    levels = [h, h-1, h-2, h-3]
    level_weight = [1.0, 0.5, 0.25, 0.125]
    total_weight = sum(level_weight)

    loss = 0.0
    for lvl, lvl_weight in zip(levels, level_weight):
        if lvl > h or lvl < 1:
            continue

        lvl_sf = F.adaptive_avg_pool2d(student_features, (lvl, lvl))
        lvl_tf = F.adaptive_avg_pool2d(teacher_features, (lvl, lvl))

        loss += F.mse_loss(lvl_sf, lvl_tf) * lvl_weight

    return loss / total_weight
97 |
98 |
def hcl_weight_1(student_features, teacher_features):
    """HCL variant with pooling levels ``[h, 4, 2, 1]`` and uniform weights
    ``[1.0, 1.0, 1.0, 1.0]``.

    Both 4-D feature maps are average-pooled to ``(level, level)`` and
    compared with MSE; levels larger than the student height ``h`` are
    skipped. The weighted sum is divided by the sum of all weights
    (including those of skipped levels).
    """
    _, _, h, _ = student_features.shape

    schedule = ((h, 1.0), (4, 1.0), (2, 1.0), (1, 1.0))
    total_weight = sum(weight for _, weight in schedule)

    loss = 0.0
    for size, weight in schedule:
        if size <= h:
            pooled_student = F.adaptive_avg_pool2d(student_features, (size, size))
            pooled_teacher = F.adaptive_avg_pool2d(teacher_features, (size, size))
            loss += F.mse_loss(pooled_student, pooled_teacher) * weight

    return loss / total_weight
118 |
119 |
def hcl_weight_2(student_features, teacher_features):
    """HCL variant with pooling levels ``[h, 4, 2, 1]`` and increasing
    weights ``[0.125, 0.25, 0.5, 1.0]``.

    Both 4-D feature maps are average-pooled to ``(level, level)`` and
    compared with MSE; levels larger than the student height ``h`` are
    skipped. The weighted sum is divided by the sum of all weights
    (including those of skipped levels).
    """
    _, _, h, _ = student_features.shape

    schedule = ((h, 0.125), (4, 0.25), (2, 0.5), (1, 1.0))
    total_weight = sum(weight for _, weight in schedule)

    loss = 0.0
    for size, weight in schedule:
        if size <= h:
            pooled_student = F.adaptive_avg_pool2d(student_features, (size, size))
            pooled_teacher = F.adaptive_avg_pool2d(teacher_features, (size, size))
            loss += F.mse_loss(pooled_student, pooled_teacher) * weight

    return loss / total_weight
139 |
140 |
def hcl_no_levels_l2(student_features, teacher_features):
    """HCL variant with the single pooling level ``[h]`` and weight ``[1]``.

    Both 4-D feature maps are average-pooled to ``(h, h)`` (``h`` being the
    student height) and compared with MSE; the result is divided by the
    total weight, which is 1.
    """
    _, _, h, _ = student_features.shape

    schedule = ((h, 1),)
    total_weight = sum(weight for _, weight in schedule)

    loss = 0.0
    for size, weight in schedule:
        if size <= h:
            pooled_student = F.adaptive_avg_pool2d(student_features, (size, size))
            pooled_teacher = F.adaptive_avg_pool2d(teacher_features, (size, size))
            loss += F.mse_loss(pooled_student, pooled_teacher) * weight

    return loss / total_weight
160 |
161 |
if __name__ == '__main__':
    # varying the levels of pooling; the first run uses params["lr"] as
    # configured, every later run resets it to 0.1 beforehand
    train(params, hcl_level_1, ABF, RLF_for_Resnet, log_file_suffix='hcl_level_1')

    follow_up_runs = [
        (hcl_level_2, 'hcl_level_2'),
        (hcl_level_3, 'hcl_level_3'),
        (hcl_level_4, 'hcl_level_4'),
        # varying the weights assigned to each level
        (hcl_weight_1, 'hcl_weight_1'),
        (hcl_weight_2, 'hcl_weight_2'),
    ]
    for loss_fn, suffix in follow_up_runs:
        params["lr"] = 0.1
        train(params, loss_fn, ABF, RLF_for_Resnet, log_file_suffix=suffix)
172 |
--------------------------------------------------------------------------------
/framework.py:
--------------------------------------------------------------------------------
1 | # framework.py consists of implementations of the proposals made by the authors
2 | # We largely refer the original implementation of the paper by the authors,
3 | # mainly refactoring their code
4 |
5 | # Distilling Knowledge via Knowledge Review
6 | # |----> Uses Residual Learning Framework
7 | # |----> Uses Hierarchical Context Loss
8 | # |----> Uses Attention Based Fusion Module
9 |
10 | # Let us start from the bottom going upwards
11 |
12 | # We implement the framework for general ResNet architectures only
13 |
14 |
15 | import torch
16 | import torch.nn.functional as F
17 | from torch import nn
18 |
19 |
20 | ########## Hierarchical Context Loss ##########
21 |
def hcl(student_features, teacher_features):
    """Hierarchical context loss between one student/teacher feature pair.

    Both 4-D feature maps are average-pooled to ``(level, level)`` for each
    level in ``[h, 4, 2, 1]`` (``h`` = student height) and compared with
    MSE, weighted by ``[1.0, 0.5, 0.25, 0.125]``. Levels larger than ``h``
    are skipped; the weighted sum is divided by the sum of all weights
    (including those of skipped levels).
    """
    _, _, h, _ = student_features.shape

    # the levels of hcl loss here are predefined, according to
    # the authors' implementation
    # ablation studies have been performed and these levels and their
    # weights have been changed (refer experimental/hcl_experiments.py)
    schedule = ((h, 1.0), (4, 0.5), (2, 0.25), (1, 0.125))
    total_weight = sum(weight for _, weight in schedule)
    # keeping the corresponding weights same, we change the levels to:
    # [h, 2, 1], [4, 1], [h, h//2, h//4], [h, h-1, h-2, h-3]
    # keeping the levels same, we change the weights to:
    # [1.0, 1.0, 1.0, 1.0], [0.125, 0.25, 0.5, 1.0]

    loss = 0.0
    for size, weight in schedule:
        if size <= h:
            pooled_student = F.adaptive_avg_pool2d(student_features, (size, size))
            pooled_teacher = F.adaptive_avg_pool2d(teacher_features, (size, size))
            loss += F.mse_loss(pooled_student, pooled_teacher) * weight

    return loss / total_weight
49 |
50 |
51 | ########## Attention Based Fusion Module ##########
52 |
53 | # Paper states that the output from the ABF module (single output as
54 | # presented in the ABF flow diagram, fig. 3(a)) is one of the inputs to
55 | # the next ABF module.
56 |
57 | # But the authors' code implementation provides two different outputs, one that
58 | # proceeds to the next ABF module (`residual_output`) and one that
59 | # is the output of the ABF module and which is involved in the loss
60 | # function (`abf_output`)
61 | # The `residual_output` differs from the `abf_output` in terms of the number
62 | # of channels. The `residual_output` has `mid_channels` while the `abf_output`
63 | # has `out_channels`
64 |
65 | # In this implementation, we have taken the latter approach
66 |
67 | # The former (single-output) approach can be found in experimental/abf_experiments.py
68 |
class ABF(nn.Module):
    """Attention based fusion module.

    Projects a student feature map to ``mid_channel`` (64) channels, fuses
    it with the previous module's residual output through two learned
    spatial attention maps, and returns both the fused map and its
    projection to ``out_channel`` channels.

    Sub-modules are created on ``cuda:0`` when CUDA is available and on the
    CPU otherwise, so the module can also be built on machines without a GPU.
    """

    def __init__(self, in_channel, out_channel):
        """
        Args:
            in_channel: channels of the incoming student feature map.
            out_channel: channels of the returned ``abf_output``.
        """
        super(ABF, self).__init__()

        self.mid_channel = 64
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        self.conv_to_mid_channel = nn.Sequential(
            nn.Conv2d(in_channel, self.mid_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(self.mid_channel),
        ).to(device)
        nn.init.kaiming_uniform_(self.conv_to_mid_channel[0].weight, a=1)

        self.conv_to_out_channel = nn.Sequential(
            nn.Conv2d(self.mid_channel, out_channel, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
        ).to(device)
        nn.init.kaiming_uniform_(self.conv_to_out_channel[0].weight, a=1)

        # two output channels: one attention map per fused input
        self.conv_to_att_maps = nn.Sequential(
            nn.Conv2d(self.mid_channel * 2, 2, kernel_size=1),
            nn.Sigmoid(),
        ).to(device)
        nn.init.kaiming_uniform_(self.conv_to_att_maps[0].weight, a=1)

    def forward(self, student_feature, prev_abf_output, teacher_shape):
        """
        Args:
            student_feature: 4-D student feature map ``(n, c, h, w)``.
            prev_abf_output: ``residual_output`` of the previous ABF, or
                ``None`` for the first module in the chain.
            teacher_shape: spatial size the previous output is resized to;
                it has to equal ``h``/``w`` of ``student_feature`` for the
                fusion below to line up.

        Returns:
            ``(abf_output, residual_output)`` with ``out_channel`` and
            ``mid_channel`` channels respectively.
        """
        n, c, h, w = student_feature.shape
        student_feature = self.conv_to_mid_channel(student_feature)

        if prev_abf_output is None:
            residual_output = student_feature
        else:
            prev_abf_output = F.interpolate(prev_abf_output, size=(
                teacher_shape, teacher_shape), mode='nearest')

            concat_features = torch.cat(
                [student_feature, prev_abf_output], dim=1)
            attention_maps = self.conv_to_att_maps(concat_features)
            attention_map1 = attention_maps[:, 0].view(n, 1, h, w)
            attention_map2 = attention_maps[:, 1].view(n, 1, h, w)

            residual_output = student_feature * attention_map1 \
                + prev_abf_output * attention_map2

        # the output of the abf is obtained after the residual
        # output is convolved to have `out_channels` channels
        abf_output = self.conv_to_out_channel(residual_output)

        return abf_output, residual_output
118 |
119 |
120 | ########## Residual Learning Framework ##########
121 |
class RLF_for_Resnet(nn.Module):
    """Residual learning framework wrapped around a ResNet student.

    One fusion module (``abf_to_use``) is created per student feature with
    channel sizes ``[16, 32, 64, 64]``. At forward time the features are
    processed from the last to the first, each module receiving the
    residual output of the previous one.

    The wrapper is moved to CUDA when it is available and stays on the CPU
    otherwise.
    """

    def __init__(self, student, abf_to_use):
        """
        Args:
            student: module called as ``student(x, is_feat=True)`` and
                returning ``(features, predictions)``.
            abf_to_use: class called as ``abf_to_use(in_channel, out_channel)``.
        """
        super(RLF_for_Resnet, self).__init__()

        self.student = student

        in_channels = [16, 32, 64, 64]
        out_channels = [16, 32, 64, 64]

        # spatial sizes handed to the fusion modules as `teacher_shape`,
        # listed in processing order (last feature first)
        self.shapes = [1, 8, 16, 32, 32]

        ABFs = nn.ModuleList()

        for idx, in_channel in enumerate(in_channels):
            ABFs.append(abf_to_use(in_channel, out_channels[idx]))

        # reversed so that index 0 handles the last student feature
        self.ABFs = ABFs[::-1]
        self.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

    def forward(self, x):
        """Return ``(fused_features, student_predictions)``.

        ``fused_features`` is ordered like the student's own feature list
        (first feature first).
        """
        student_features = self.student(x, is_feat=True)

        student_preds = student_features[1]
        student_features = student_features[0][::-1]

        results = []

        abf_output, residual_output = self.ABFs[0](
            student_features[0], None, self.shapes[0])

        results.append(abf_output)

        for features, abf, shape in zip(student_features[1:], self.ABFs[1:], self.shapes[1:]):
            # each module consumes the residual output of the previous one;
            # outputs are prepended to restore the student's feature order
            abf_output, residual_output = abf(features, residual_output, shape)
            results.insert(0, abf_output)

        return results, student_preds
161 |
--------------------------------------------------------------------------------
/experimental/abf_experiments.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.insert(0, '..')
3 |
4 |
5 | from train import train
6 | from params import params
7 | from hcl_experiments import *
8 | from abf_experiments import *
9 | from framework import ABF, hcl
10 | import torch
11 | import torch.nn.functional as F
12 | from torch import nn
13 |
14 |
class ABF_without_mid_channels(nn.Module):
    """ABF variant with a single output and no intermediate channel width.

    The student feature and the previous module's output are each projected
    straight to ``out_channel`` channels, fused through two learned spatial
    attention maps, and the fused map is returned as both outputs.

    Sub-modules are created on ``cuda:0`` when CUDA is available and on the
    CPU otherwise.
    """

    def __init__(self, in_channel, out_channel, pabf_channel):
        """
        Args:
            in_channel: channels of the incoming student feature map.
            out_channel: channels of the fused output.
            pabf_channel: channels of the previous module's output.
        """
        super(ABF_without_mid_channels, self).__init__()

        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        self.conv_to_out_channel_sf = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
        ).to(device)
        nn.init.kaiming_uniform_(self.conv_to_out_channel_sf[0].weight, a=1)

        self.conv_to_out_channel_pabf = nn.Sequential(
            nn.Conv2d(pabf_channel, out_channel, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
        ).to(device)
        nn.init.kaiming_uniform_(self.conv_to_out_channel_pabf[0].weight, a=1)

        # two output channels: one attention map per fused input
        self.conv_to_att_maps = nn.Sequential(
            nn.Conv2d(out_channel * 2, 2, kernel_size=1),
            nn.Sigmoid(),
        ).to(device)
        nn.init.kaiming_uniform_(self.conv_to_att_maps[0].weight, a=1)

    def forward(self, student_feature, prev_abf_output, teacher_shape):
        """
        Args:
            student_feature: 4-D student feature map ``(n, c, h, w)``.
            prev_abf_output: output of the previous module, or ``None`` for
                the first module in the chain.
            teacher_shape: spatial size the previous output is resized to;
                it has to equal ``h``/``w`` of ``student_feature``.

        Returns:
            ``(abf_output, residual_output)``; both names refer to the same
            tensor in this variant.
        """
        n, c, h, w = student_feature.shape
        student_feature = self.conv_to_out_channel_sf(student_feature)

        if prev_abf_output is None:
            residual_output = student_feature
        else:
            # (a leftover debug print of the incoming shape was removed here)
            prev_abf_output = self.conv_to_out_channel_pabf(prev_abf_output)
            prev_abf_output = F.interpolate(prev_abf_output, size=(
                teacher_shape, teacher_shape), mode='nearest')

            concat_features = torch.cat(
                [student_feature, prev_abf_output], dim=1)
            attention_maps = self.conv_to_att_maps(concat_features)
            attention_map1 = attention_maps[:, 0].view(n, 1, h, w)
            attention_map2 = attention_maps[:, 1].view(n, 1, h, w)

            residual_output = student_feature * attention_map1 \
                + prev_abf_output * attention_map2

        # here we just equate both the outputs instead of having
        # a single output to have the same training code for both
        # the implementations
        abf_output = residual_output

        return abf_output, residual_output
66 |
67 |
class ABF_without_attention_maps(nn.Module):
    """ABF variant that fuses by plain addition instead of attention maps.

    The student feature is projected to ``mid_channel`` (64) channels and
    added to the resized output of the previous module; the sum is returned
    together with its projection to ``out_channel`` channels.

    Sub-modules are created on ``cuda:0`` when CUDA is available and on the
    CPU otherwise.
    """

    def __init__(self, in_channel, out_channel):
        """
        Args:
            in_channel: channels of the incoming student feature map.
            out_channel: channels of the returned ``abf_output``.
        """
        super(ABF_without_attention_maps, self).__init__()

        self.mid_channel = 64
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        self.conv_to_mid_channel = nn.Sequential(
            nn.Conv2d(in_channel, self.mid_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(self.mid_channel),
        ).to(device)
        nn.init.kaiming_uniform_(self.conv_to_mid_channel[0].weight, a=1)

        self.conv_to_out_channel = nn.Sequential(
            nn.Conv2d(self.mid_channel, out_channel, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
        ).to(device)
        nn.init.kaiming_uniform_(self.conv_to_out_channel[0].weight, a=1)

        # NOTE(review): this attention branch is never used by `forward`;
        # it is kept so parameter names and the initialisation sequence stay
        # as they were -- consider removing it.
        self.conv_to_att_maps = nn.Sequential(
            nn.Conv2d(self.mid_channel * 2, 2, kernel_size=1),
            nn.Sigmoid(),
        ).to(device)
        nn.init.kaiming_uniform_(self.conv_to_att_maps[0].weight, a=1)

    def forward(self, student_feature, prev_abf_output, teacher_shape):
        """
        Args:
            student_feature: 4-D student feature map.
            prev_abf_output: ``residual_output`` of the previous module, or
                ``None`` for the first module in the chain.
            teacher_shape: spatial size the previous output is resized to;
                it has to match the student feature's spatial size for the
                addition to work.

        Returns:
            ``(abf_output, residual_output)`` with ``out_channel`` and
            ``mid_channel`` channels respectively.
        """
        n, c, h, w = student_feature.shape
        student_feature = self.conv_to_mid_channel(student_feature)

        if prev_abf_output is None:
            residual_output = student_feature
        else:
            prev_abf_output = F.interpolate(prev_abf_output, size=(
                teacher_shape, teacher_shape), mode='nearest')

            residual_output = student_feature + prev_abf_output

        # the output of the abf is obtained after the residual
        # output is convolved to have `out_channels` channels
        abf_output = self.conv_to_out_channel(residual_output)

        return abf_output, residual_output
110 |
111 |
class RLF_for_Resnet_with_ABF_without_mid_channels(nn.Module):
    """Residual learning framework for fusion modules that take the channel
    count of the previous module's output as a third constructor argument.

    Features are processed from the last to the first, so the module for
    feature ``idx`` receives the output of the module for feature
    ``idx + 1`` and must be built for that module's ``out_channel``.

    The wrapper is moved to CUDA when it is available and stays on the CPU
    otherwise.
    """

    def __init__(self, student, abf_to_use):
        """
        Args:
            student: module called as ``student(x, is_feat=True)`` and
                returning ``(features, predictions)``.
            abf_to_use: class called as
                ``abf_to_use(in_channel, out_channel, pabf_channel)``.
        """
        super(RLF_for_Resnet_with_ABF_without_mid_channels, self).__init__()

        self.student = student

        in_channels = [16, 32, 64, 64]
        out_channels = [16, 32, 64, 64]
        # Channels of the tensor each module receives from its predecessor
        # in processing order, i.e. the out_channel of the next-deeper
        # feature: [32, 64, 64, 64]. The last entry belongs to the first
        # module processed, which gets no previous output. The former
        # values [1, 16, 32, 64] did not match what the modules produce.
        pabf_channels = out_channels[1:] + out_channels[-1:]

        # spatial sizes handed to the fusion modules as `teacher_shape`,
        # listed in processing order (last feature first)
        self.shapes = [1, 8, 16, 32, 32]

        ABFs = nn.ModuleList()

        for idx, in_channel in enumerate(in_channels):
            ABFs.append(abf_to_use(in_channel, out_channels[idx], pabf_channels[idx]))

        # reversed so that index 0 handles the last student feature
        self.ABFs = ABFs[::-1]
        self.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

    def forward(self, x):
        """Return ``(fused_features, student_predictions)``.

        ``fused_features`` is ordered like the student's own feature list
        (first feature first).
        """
        student_features = self.student(x, is_feat=True)

        student_preds = student_features[1]
        student_features = student_features[0][::-1]

        results = []

        abf_output, residual_output = self.ABFs[0](
            student_features[0], None, self.shapes[0])

        results.append(abf_output)

        for features, abf, shape in zip(student_features[1:], self.ABFs[1:], self.shapes[1:]):
            # each module consumes the output of the previous one; outputs
            # are prepended to restore the student's feature order
            abf_output, residual_output = abf(features, residual_output, shape)
            results.insert(0, abf_output)

        return results, student_preds
152 |
153 |
if __name__ == '__main__':
    # first ablation: fusion block without attention maps
    train(
        params,
        hcl,
        ABF_without_attention_maps,
        log_file_suffix='abf_without_attention_maps',
    )

    # reset the learning rate before the next run
    # (presumably train() decays params['lr'] in place -- TODO confirm)
    params['lr'] = 0.1

    # second ablation: fusion block without mid channels, which also needs
    # its matching framework class
    train(
        params,
        hcl,
        ABF_without_mid_channels,
        RLF_for_Resnet_with_ABF_without_mid_channels,
        log_file_suffix='abf_without_mid_channels',
    )
160 |
--------------------------------------------------------------------------------
/utils/resnets_for_cifar.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | '''Resnet for cifar dataset.
4 | Ported from
5 | https://github.com/facebook/fb.resnet.torch
6 | and
7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
8 | (c) YANG, Wei
9 | '''
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import math
13 |
14 |
15 | __all__ = ['resnet']
16 |
17 |
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with a padding of 1.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: stride of the convolution (default 1).
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
22 |
23 |
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions.

    Args:
        inplanes: number of input channels.
        planes: number of channels produced by both convolutions.
        stride: stride of the first convolution.
        downsample: optional callable applied to the input to build the
            shortcut branch; the raw input is used when it is ``None``.
        is_last: when true, ``forward`` also returns the pre-activation sum.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False):
        super().__init__()
        self.is_last = is_last
        # main branch: conv -> bn -> relu -> conv -> bn
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # shortcut branch
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return the activated output, plus the pre-activation if ``is_last``."""
        hidden = self.relu(self.bn1(self.conv1(x)))
        preact = self.bn2(self.conv2(hidden))

        shortcut = x if self.downsample is None else self.downsample(x)
        preact += shortcut

        out = F.relu(preact)
        return (out, preact) if self.is_last else out
58 |
59 |
class Bottleneck(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 convolution stack.

    The last convolution widens the channels to ``planes * 4``.

    Args:
        inplanes: number of input channels.
        planes: channel width of the first two convolutions.
        stride: stride of the 3x3 convolution.
        downsample: optional callable applied to the input to build the
            shortcut branch; the raw input is used when it is ``None``.
        is_last: when true, ``forward`` also returns the pre-activation sum.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False):
        super().__init__()
        widened = planes * 4
        self.is_last = is_last
        # 1x1 reduce
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # 3x3 (carries the stride)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # 1x1 expand
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return the activated output, plus the pre-activation if ``is_last``."""
        hidden = self.relu(self.bn1(self.conv1(x)))
        hidden = self.relu(self.bn2(self.conv2(hidden)))
        preact = self.bn3(self.conv3(hidden))

        shortcut = x if self.downsample is None else self.downsample(x)
        preact += shortcut

        out = F.relu(preact)
        return (out, preact) if self.is_last else out
101 |
102 |
class ResNet(nn.Module):
    """ResNet with a 3x3 stem and three residual stages.

    Args:
        depth: total depth; must satisfy ``6n+2`` for ``'BasicBlock'`` or
            ``9n+2`` for ``'Bottleneck'``. ``n`` blocks are stacked per stage.
        num_filters: four channel widths: the stem followed by the three
            stages (each stage width is multiplied by ``block.expansion``).
        block_name: ``'BasicBlock'`` or ``'Bottleneck'`` (case-insensitive).
        num_classes: output size of the final linear layer.

    Raises:
        AssertionError: if ``depth`` does not fit the chosen block type.
        ValueError: if ``block_name`` is not a known block type.
    """

    def __init__(self, depth, num_filters, block_name='BasicBlock', num_classes=10):
        super(ResNet, self).__init__()
        # Local import: this file only imports ``torch.nn`` at module level.
        import torch

        # Model type specifies number of layers for CIFAR-10 model
        kind = block_name.lower()
        if kind == 'basicblock':
            assert (
                depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
            n = (depth - 2) // 6
            block = BasicBlock
        elif kind == 'bottleneck':
            assert (
                depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
            n = (depth - 2) // 9
            block = Bottleneck
        else:
            raise ValueError('block_name should be Basicblock or Bottleneck')

        # ``self.inplanes`` tracks the channel count entering the next stage
        # and is advanced by ``_make_layer``.
        self.inplanes = num_filters[0]
        self.conv1 = nn.Conv2d(3, num_filters[0], kernel_size=3, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(num_filters[0])
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, num_filters[1], n)
        self.layer2 = self._make_layer(block, num_filters[2], n, stride=2)
        self.layer3 = self._make_layer(block, num_filters[3], n, stride=2)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(num_filters[3] * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # The model used to be moved to CUDA unconditionally, which raised on
        # hosts without a usable GPU. Placement is unchanged when CUDA is
        # available; otherwise the model simply stays on the CPU.
        if torch.cuda.is_available():
            self.to('cuda')

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        A 1x1 conv + batch-norm shortcut is added to the first block when the
        stride or the channel count changes. Only the final block of the
        stage is created with ``is_last=True``.
        """
        out_planes = planes * block.expansion

        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        layers = [block(self.inplanes, planes, stride,
                        downsample, is_last=(blocks == 1))]
        self.inplanes = out_planes
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes,
                                is_last=(i == blocks - 1)))

        return nn.Sequential(*layers)

    def get_feat_modules(self):
        """Return the stem and the three stages as an ``nn.ModuleList``."""
        feat_m = nn.ModuleList([])
        feat_m.append(self.conv1)
        feat_m.append(self.bn1)
        feat_m.append(self.relu)
        feat_m.append(self.layer1)
        feat_m.append(self.layer2)
        feat_m.append(self.layer3)
        return feat_m

    def get_bn_before_relu(self):
        """Return the last batch-norm layer of the final block of each stage.

        Raises:
            NotImplementedError: if the stages use an unknown block type.
        """
        if isinstance(self.layer1[0], Bottleneck):
            bn1 = self.layer1[-1].bn3
            bn2 = self.layer2[-1].bn3
            bn3 = self.layer3[-1].bn3
        elif isinstance(self.layer1[0], BasicBlock):
            bn1 = self.layer1[-1].bn2
            bn2 = self.layer2[-1].bn2
            bn3 = self.layer3[-1].bn2
        else:
            raise NotImplementedError('ResNet unknown block error !!!')

        return [bn1, bn2, bn3]

    def forward(self, x, is_feat=False, preact=False):
        """Compute the logits and, optionally, the intermediate features.

        Args:
            x: input batch (the size comments below assume 32x32 images).
            is_feat: when true, return ``(features, logits)`` instead of the
                logits alone.
            preact: with ``is_feat``, return the pre-activation outputs of the
                three stages instead of the activated ones; the stem feature
                and the pooled feature are the same in both cases.

        Returns:
            The logits, or ``([f0, f1, f2, f3, f4], logits)`` if ``is_feat``.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)  # 32x32
        f0 = x

        # Each stage ends with an ``is_last`` block, so the stage returns an
        # ``(activated, pre-activation)`` pair.
        x, f1_pre = self.layer1(x)  # 32x32
        f1 = x
        x, f2_pre = self.layer2(x)  # 16x16
        f2 = x
        x, f3_pre = self.layer3(x)  # 8x8
        f3 = x

        x = self.avgpool(x)
        f4 = x
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        if not is_feat:
            return x
        if preact:
            return [f0, f1_pre, f2_pre, f3_pre, f4], x
        return [f0, f1, f2, f3, f4], x
209 |
--------------------------------------------------------------------------------
/experimental/table7_experiments.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.insert(0, '..')
3 |
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | import torch.backends.cudnn as cudnn
8 | import numpy as np
9 | import time
10 |
11 | from params import params
12 | from data import get_dataloaders
13 | from teachers import get_teacher
14 | from students import get_student
15 | from abf_experiments import ABF_without_attention_maps
16 | from framework import RLF_for_Resnet
17 | from utils.misc import AverageMeter, format_time, Logger
18 | from test import test
19 |
20 |
def train_general(params, framework, kd_loss, log_file_suffix=''):
    """Train ``framework`` with cross-entropy plus a feature distillation loss.

    Args:
        params: configuration dict. Reads ``seed``, ``dataset``,
            ``batch_size``, ``teacher``, ``student``, ``lr``,
            ``weight_decay``, ``num_epochs``, ``kd_loss_weight``,
            ``lr_decay_steps`` and ``lr_decay_rate``. ``params['seed']`` is
            replaced by a random value when it is 0, and ``params['lr']`` is
            decayed in place, so callers must reset it between runs.
        framework: module called as ``framework(X)`` and returning
            ``(student_features, student_preds)``; its parameters are the
            ones optimised.
        kd_loss: callable ``kd_loss(student_feature, teacher_feature)``.
        log_file_suffix: suffix appended to the log / checkpoint file name.

    Raises:
        ValueError: if ``params['dataset']`` is neither ``'cifar10'`` nor
            ``'cifar100'``.

    Side effects:
        Writes a log file under ``logs/`` and saves the state dict of the
        best-performing epoch under ``pretrained/``.
    """
    # Local imports: only needed by this function.
    import copy
    import os

    cudnn.deterministic = True
    cudnn.benchmark = False
    if params["seed"] == 0:
        params["seed"] = np.random.randint(1000)
    torch.manual_seed(params["seed"])
    np.random.seed(params["seed"])
    torch.cuda.manual_seed(params["seed"])

    train_loader, test_loader = get_dataloaders(
        params["dataset"], params["batch_size"])
    if params["dataset"] == 'cifar10':
        num_classes = 10
    elif params["dataset"] == 'cifar100':
        num_classes = 100
    else:
        # previously this fell through and failed later on an unbound name
        raise ValueError(f"unsupported dataset: {params['dataset']!r}")

    # The student is already wrapped inside ``framework``; the extra student
    # network that used to be built here was never used and has been dropped.
    teacher = get_teacher(params["teacher"], num_classes=num_classes)

    # load teacher weights from pretrained model
    # weight = torch.load(params["teacher_weight_path"])
    # teacher.load_state_dict(weight)
    # for p in teacher.parameters():
    #     p.requires_grad = False
    # NOTE(review): with the loading above disabled the teacher keeps its
    # initial weights and stays in train mode -- confirm this is intended.
    teacher.to(torch.device('cuda:0'))

    base_loss = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        framework.parameters(),
        lr=params["lr"],
        momentum=0.9,
        nesterov=True,
        weight_decay=params["weight_decay"]
    )

    train_log_file = f"logs/{params['dataset'] + '_' + params['student'] + '_' + params['teacher'] + '_' + log_file_suffix}"
    checkpoint_path = 'pretrained/' + train_log_file + '.pt'
    # make sure both output directories exist before anything is written
    os.makedirs(os.path.dirname(train_log_file), exist_ok=True)
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    logger = Logger(params=params, filename=train_log_file+'.txt')
    best_accuracy = 0.0
    # Keep a *copy* of the weights: holding a reference to ``framework``
    # would always end up saving the final epoch, not the best one.
    best_state = copy.deepcopy(framework.state_dict())

    start_time = time.time()
    print("starting training with the following params:")
    print(params)
    print()

    for epoch in range(params["num_epochs"]):
        # test() switches the network to eval mode at the end of every epoch,
        # so training mode has to be restored here.
        framework.train()

        loss_avg = {
            'kd_loss': AverageMeter(),
            'base_loss': AverageMeter()
        }
        correct_preds = 0.0
        total_images = 0.0

        for i, (X, y) in enumerate(train_loader):
            X, y = X.cuda(), y.cuda()

            losses = {"kd_loss": 0, "base_loss": 0}

            # getting student and teacher features
            # authors use features obtained **after** activation for the student
            # (see rlf implementation in framework.py)
            student_features, student_preds = framework(X)
            # authors use features obtained **before** activation for the teacher (preact=True)
            # The teacher is not optimised, so no autograd graph is needed.
            with torch.no_grad():
                teacher_features, _ = teacher(X, is_feat=True, preact=True)

            # authors start from the second teacher features
            if isinstance(framework, RLF_for_Resnet):
                teacher_features = teacher_features[1:]

            # calculating review kd loss
            for sf, tf in zip(student_features, teacher_features):
                losses['kd_loss'] += kd_loss(sf, tf)

            # calculating cross entropy loss
            losses['base_loss'] = base_loss(student_preds, y)

            loss = losses['kd_loss'] * params['kd_loss_weight'] + losses['base_loss']

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # store plain floats so the meters do not keep tensors alive
            for key in losses:
                loss_avg[key].update(float(losses[key]))

            # calculate running average of accuracy
            student_preds = torch.max(student_preds.data, 1)[1]
            total_images += y.size(0)
            correct_preds += (student_preds == y.data).sum().item()
            train_accuracy = correct_preds / total_images

        # calculating test accuracy and storing best results
        test_accuracy = test(framework, test_loader)
        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            best_state = copy.deepcopy(framework.state_dict())

        # decaying lr at scheduled steps
        if epoch in params['lr_decay_steps']:
            params['lr'] *= params["lr_decay_rate"]
            for param_group in optimizer.param_groups:
                param_group['lr'] = params['lr']

        # logging results
        # NOTE(review): ``.val`` is kept from the original; if AverageMeter
        # follows the usual convention this is the last batch value rather
        # than the epoch average -- confirm against utils.misc.
        epoch_losses = {k: loss_avg[k].val for k in loss_avg}
        elapsed = time.time() - start_time
        log_row = {
            'epoch': str(epoch),
            'train_acc': '%.2f' % (train_accuracy*100),
            'test_acc': '%.2f' % (test_accuracy*100),
            'best_acc': '%.2f' % (best_accuracy*100),
            'lr': '%.5f' % (params['lr']),
            'loss': '%.5f' % (sum(epoch_losses.values())),
            'kd_loss': '%.5f' % epoch_losses['kd_loss'],
            'base_loss': '%.5f' % epoch_losses['base_loss'],
            'time': format_time(elapsed),
            'eta': format_time(elapsed/(epoch+1)*(params["num_epochs"]-epoch-1)),
        }
        print(log_row)
        logger.writerow(log_row)

    torch.save(best_state, checkpoint_path)
    logger.close()
144 |
145 |
class GeneralFusionModule(nn.Module):
    """Project several feature maps to a common shape and sum them.

    Args:
        in_channels: channel count of each incoming feature map; one
            1x1 conv + batch-norm projection is created per entry.
        out_channel: channel count every projection produces.
        out_shape: side length of the square spatial size every projected
            map is resized to (nearest neighbour).
    """

    def __init__(self, in_channels, out_channel, out_shape):
        super(GeneralFusionModule, self).__init__()
        self.out_shape = out_shape

        # Same placement as before when CUDA is available; fall back to the
        # CPU instead of failing on hosts without a GPU.
        device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')

        # An nn.ModuleList (not a plain list) is required so the projections
        # are registered: otherwise their parameters are missing from
        # ``parameters()`` / ``state_dict()``, are never optimised and do not
        # follow ``train()`` / ``eval()``.
        conv_layers = nn.ModuleList()
        for ch in in_channels:
            conv = nn.Sequential(
                nn.Conv2d(ch, out_channel, kernel_size=1),
                nn.BatchNorm2d(out_channel)
            ).to(device)
            nn.init.kaiming_uniform_(conv[0].weight, a=1)
            conv_layers.append(conv)

        self.conv_layers = conv_layers

    def forward(self, student_features):
        """Return the sum of all projected and resized feature maps.

        Args:
            student_features: sequence of feature maps, matched by position
                with the projections. The sequence itself is left untouched.

        Raises:
            ValueError: if ``student_features`` is empty.
            IndexError: if more feature maps than projections are given.
        """
        target_size = (self.out_shape, self.out_shape)

        # Build a new list rather than overwriting the caller's entries.
        resized = [
            F.interpolate(self.conv_layers[i](sf),
                          size=target_size, mode='nearest')
            for i, sf in enumerate(student_features)
        ]
        if not resized:
            raise ValueError('student_features must not be empty')

        # ``resized[0]`` is a fresh tensor, so accumulating in place is safe.
        output = resized[0]
        for extra in resized[1:]:
            output += extra

        return output
173 |
174 |
class BaselineFramework(nn.Module):
    """Student wrapper in which every fusion module sees a single feature.

    Args:
        student: module called as ``student(x, is_feat=True)``; its result is
            indexed as ``[0]`` (list of feature maps) and ``[1]``
            (predictions).
    """

    def __init__(self, student):
        super().__init__()

        # one entry per student feature
        in_channels = [16, 16, 32, 64, 64]
        out_channels = [16, 16, 32, 64, 64]
        out_shapes = [32, 32, 16, 8, 1]

        # each fusion module receives exactly one input channel count
        self.fusion_modules = nn.ModuleList([
            GeneralFusionModule([c_in], c_out, side)
            for c_in, c_out, side in zip(in_channels, out_channels, out_shapes)
        ])

        self.student = student
        self.to('cuda')

    def forward(self, x):
        """Return ``(results, student_preds)``, one result per fusion module."""
        student_out = self.student(x, is_feat=True)
        student_preds = student_out[1]
        feature_maps = student_out[0]

        # fusion module ``i`` is fed only the ``i``-th feature map
        results = [
            fusion([feature_maps[i]])
            for i, fusion in enumerate(self.fusion_modules)
        ]
        return results, student_preds
206 |
207 |
def train_baseline(params, student):
    """Train ``student`` wrapped in ``BaselineFramework`` with an MSE feature loss.

    Results are logged under the ``table7_1`` suffix.
    """
    baseline = BaselineFramework(student)
    return train_general(
        params,
        baseline,
        F.mse_loss,
        log_file_suffix='table7_1',
    )
211 |
212 |
class RMFramework(nn.Module):
    """Student wrapper in which fusion module ``i`` sees features ``i`` onward.

    Args:
        student: module called as ``student(x, is_feat=True)``; its result is
            indexed as ``[0]`` (list of feature maps) and ``[1]``
            (predictions).
    """

    def __init__(self, student):
        super().__init__()

        self.student = student

        # one entry per student feature
        in_channels = [16, 16, 32, 64, 64]
        out_channels = [16, 16, 32, 64, 64]
        out_shapes = [32, 32, 16, 8, 1]

        # module ``i`` gets the channel counts of feature ``i`` and all
        # features after it
        self.fusion_modules = nn.ModuleList([
            GeneralFusionModule(in_channels[i:], c_out, side)
            for i, (c_out, side) in enumerate(zip(out_channels, out_shapes))
        ])
        self.to('cuda')

    def forward(self, x):
        """Return ``(results, student_preds)``, one result per fusion module."""
        student_out = self.student(x, is_feat=True)
        student_preds = student_out[1]
        feature_maps = student_out[0]

        # the slice hands each module its own list of feature maps
        results = [
            fusion(feature_maps[i:])
            for i, fusion in enumerate(self.fusion_modules)
        ]
        return results, student_preds
244 |
245 |
def train_rm_framework(params, student):
    """Train ``student`` wrapped in ``RMFramework`` with an MSE feature loss.

    Results are logged under the ``table7_2`` suffix.
    """
    rm_framework = RMFramework(student)
    return train_general(
        params,
        rm_framework,
        F.mse_loss,
        log_file_suffix='table7_2',
    )
249 |
250 |
def train_rlf_framework(params, student):
    """Train ``student`` wrapped in ``RLF_for_Resnet`` with an MSE feature loss.

    The framework is built with ``ABF_without_attention_maps`` as its fusion
    block; results are logged under the ``table7_3`` suffix.
    """
    rlf_framework = RLF_for_Resnet(student, ABF_without_attention_maps)
    return train_general(
        params,
        rlf_framework,
        F.mse_loss,
        log_file_suffix='table7_3',
    )
254 |
255 |
if __name__ == '__main__':
    # train_general decays params['lr'] in place, so remember the starting
    # value and restore it before every run.
    initial_lr = params['lr']

    # Each experiment gets its own freshly created student. Sharing one
    # instance would let the later runs start from weights already trained
    # by the earlier ones, which makes the comparison meaningless.
    for run_experiment in (train_rlf_framework, train_rm_framework, train_baseline):
        params['lr'] = initial_lr
        student = get_student(params["student"], num_classes=10)
        run_experiment(params, student)
--------------------------------------------------------------------------------