├── fig
│   └── data_scale_bar.png
├── losses
│   ├── __init__.py
│   ├── bkd.py
│   ├── correlation.py
│   ├── crd.py
│   ├── dist.py
│   ├── dkd.py
│   ├── kd.py
│   ├── review.py
│   └── rkd.py
├── models
│   ├── __init__.py
│   ├── beit.py
│   └── convnextv2.py
├── object_detection
│   ├── README.md
│   ├── configs
│   │   ├── convnextv2
│   │   │   ├── cascade_mask_rcnn_convnextv2_3x_coco.py
│   │   │   └── mask_rcnn_convnextv2_3x_coco.py
│   │   └── resnet
│   │       └── mask_rcnn_r50_1x_coco.py
│   └── mmdet
│       └── models
│           └── backbones
│               ├── __init__.py
│               └── convnextv2.py
├── readme.md
├── register.py
├── requirements.txt
├── train-crd.py
├── train-fd.py
├── train-kd.py
├── utils.py
└── validate.py
/fig/data_scale_bar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hao840/vanillaKD/a27b029f693a13d67b85c61a52e25a85f9414e70/fig/data_scale_bar.png
--------------------------------------------------------------------------------
/losses/__init__.py:
--------------------------------------------------------------------------------
'''
The loss implementations are adapted from DKD (mdistiller) and DIST:
https://github.com/megvii-research/mdistiller
https://github.com/hunto/DIST_KD

Modifications by Zhiwei Hao (haozhw@bit.edu.cn) and Jianyuan Guo (jianyuan_guo@outlook.com)
'''

from .bkd import BinaryKLDiv
from .correlation import Correlation
from .crd import CRD
from .dist import DIST
from .dkd import DKD
from .kd import KLDiv
from .review import ReviewKD
from .rkd import RKD
--------------------------------------------------------------------------------
/losses/bkd.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F


class BinaryKLDiv(nn.Module):
    def __init__(self):
        super(BinaryKLDiv, self).__init__()

    def forward(self, z_s, z_t, **kwargs):
        # binary KD: use the teacher's softmax output as per-class soft
        # binary targets for the student logits
        kd_loss = F.binary_cross_entropy_with_logits(z_s, z_t.softmax(1))
        return kd_loss
--------------------------------------------------------------------------------
/losses/correlation.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class Correlation(nn.Module):
    scale = 0.02

    def __init__(self, feat_s_channel, feat_t_channel, feat_dim=128):
        super(Correlation, self).__init__()
        self.embed_s = LinearEmbed(feat_s_channel, feat_dim)
        self.embed_t = LinearEmbed(feat_t_channel, feat_dim)

    def forward(self, z_s, z_t, **kwargs):
        f_s = self.embed_s(kwargs['feature_student'][-1])
        f_t = self.embed_t(kwargs['feature_teacher'][-1])

        delta = torch.abs(f_s - f_t)
        kd_loss = self.scale * torch.mean((delta[:-1] * delta[1:]).sum(1))
        return kd_loss


class LinearEmbed(nn.Module):
    def __init__(self, dim_in=1024, dim_out=128):
        super(LinearEmbed, self).__init__()
        self.linear = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        x = x.view(x.shape[0], -1)
        x = self.linear(x)
        return x
--------------------------------------------------------------------------------
/losses/crd.py:
--------------------------------------------------------------------------------
import math

import torch
from torch import nn


class CRD(nn.Module):
    def __init__(self, feat_s_channel, feat_t_channel, feat_dim, num_data, k=16384, momentum=0.5, temperature=0.07):
        super(CRD, self).__init__()
        self.embed_s = Embed(feat_s_channel, feat_dim)
        self.embed_t = Embed(feat_t_channel, feat_dim)
        self.contrast = ContrastMemory(feat_dim, num_data, k, temperature, momentum)
        self.criterion_s = ContrastLoss(num_data)
        self.criterion_t = ContrastLoss(num_data)

    def forward(self, z_s, z_t, **kwargs):
        f_s = self.embed_s(kwargs['feature_student'][-1])
        f_t = self.embed_t(kwargs['feature_teacher'][-1])
        out_s, out_t = self.contrast(f_s, f_t, kwargs['index'], kwargs['contrastive_index'])
        s_loss = self.criterion_s(out_s)
        t_loss = self.criterion_t(out_t)
        kd_loss = s_loss + t_loss
        return kd_loss

    def get_learnable_parameters(self):
        # nn.Module provides no get_learnable_parameters(), so the embedding
        # parameters introduced by this loss are returned directly rather
        # than chaining to super()
        return list(self.embed_s.parameters()) + list(self.embed_t.parameters())

    def get_extra_parameters(self):
        params = (
            list(self.embed_s.parameters())
            + list(self.embed_t.parameters())
            + list(self.contrast.buffers())
        )
        num_p = 0
        for p in params:
            num_p += p.numel()
        return num_p


class Normalize(nn.Module):
    """normalization layer"""

    def __init__(self, power=2):
        super(Normalize, self).__init__()
        self.power = power

    def forward(self, x):
        norm = x.pow(self.power).sum(1, keepdim=True).pow(1.0 / self.power)
        out = x.div(norm)
        return out


class Embed(nn.Module):
    """Embedding module"""

    def __init__(self, dim_in=1024, dim_out=128):
        super(Embed, self).__init__()
        self.linear = nn.Linear(dim_in, dim_out)
        self.l2norm = Normalize(2)

    def forward(self, x):
        x = x.reshape(x.shape[0], -1)
        x = self.linear(x)
        x = self.l2norm(x)
        return x


class ContrastLoss(nn.Module):
    """contrastive loss"""

    def __init__(self, num_data):
        super(ContrastLoss, self).__init__()
        self.num_data = num_data

    def forward(self, x):
        eps = 1e-7
        bsz = x.shape[0]
        m = x.size(1) - 1

        # noise distribution
        Pn = 1 / float(self.num_data)

        # loss for the positive pair
        P_pos = x.select(1, 0)
        log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_()

        # loss for the K negative pairs
        P_neg = x.narrow(1, 1, m)
        log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_()

        loss = -(log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bsz

        return loss


class ContrastMemory(nn.Module):
    """memory buffer that supplies a large amount of negative samples."""

    def __init__(self, inputSize, output_size, K, T=0.07, momentum=0.5):
        super(ContrastMemory, self).__init__()
        self.n_lem = output_size
        self.unigrams = torch.ones(self.n_lem)
        self.multinomial = AliasMethod(self.unigrams)
        self.K = K

        self.register_buffer("params", torch.tensor([K, T, -1, -1, momentum]))
        stdv = 1.0 / math.sqrt(inputSize / 3)
        self.register_buffer(
            "memory_v1", torch.rand(output_size, inputSize).mul_(2 * stdv).add_(-stdv)
        )
        self.register_buffer(
            "memory_v2", torch.rand(output_size, inputSize).mul_(2 * stdv).add_(-stdv)
        )

    def cuda(self, *args, **kwargs):
        super(ContrastMemory, self).cuda(*args, **kwargs)
        self.multinomial.cuda()

    def forward(self, v1, v2, y, idx=None):
        K = int(self.params[0].item())
        T = self.params[1].item()
        Z_v1 = self.params[2].item()
        Z_v2 = self.params[3].item()
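        # (added note) Z_v1 / Z_v2 are the NCE normalization constants. They
        # are initialized to -1 and, on the first forward pass (see below),
        # estimated once as mean(exp(score / T)) * outputSize over the first
        # batch, then kept fixed as approximations of the partition function.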
128 | 129 | momentum = self.params[4].item() 130 | batchSize = v1.size(0) 131 | outputSize = self.memory_v1.size(0) 132 | inputSize = self.memory_v1.size(1) 133 | 134 | # original score computation 135 | if idx is None: 136 | idx = self.multinomial.draw(batchSize * (self.K + 1)).view(batchSize, -1) 137 | idx.select(1, 0).copy_(y.data) 138 | # sample 139 | weight_v1 = torch.index_select(self.memory_v1, 0, idx.view(-1)).detach() 140 | weight_v1 = weight_v1.view(batchSize, K + 1, inputSize) 141 | out_v2 = torch.bmm(weight_v1, v2.view(batchSize, inputSize, 1)) 142 | out_v2 = torch.exp(torch.div(out_v2, T)) 143 | # sample 144 | weight_v2 = torch.index_select(self.memory_v2, 0, idx.view(-1)).detach() 145 | weight_v2 = weight_v2.view(batchSize, K + 1, inputSize) 146 | out_v1 = torch.bmm(weight_v2, v1.view(batchSize, inputSize, 1)) 147 | out_v1 = torch.exp(torch.div(out_v1, T)) 148 | 149 | # set Z if haven't been set yet 150 | if Z_v1 < 0: 151 | self.params[2] = out_v1.mean() * outputSize 152 | Z_v1 = self.params[2].clone().detach().item() 153 | # print("normalization constant Z_v1 is set to {:.1f}".format(Z_v1)) 154 | if Z_v2 < 0: 155 | self.params[3] = out_v2.mean() * outputSize 156 | Z_v2 = self.params[3].clone().detach().item() 157 | # print("normalization constant Z_v2 is set to {:.1f}".format(Z_v2)) 158 | 159 | # compute out_v1, out_v2 160 | out_v1 = torch.div(out_v1, Z_v1).contiguous() 161 | out_v2 = torch.div(out_v2, Z_v2).contiguous() 162 | 163 | # update memory 164 | with torch.no_grad(): 165 | l_pos = torch.index_select(self.memory_v1, 0, y.view(-1)) 166 | l_pos.mul_(momentum) 167 | l_pos.add_(torch.mul(v1, 1 - momentum)) 168 | l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5) 169 | updated_v1 = l_pos.div(l_norm) 170 | self.memory_v1.index_copy_(0, y, updated_v1) 171 | 172 | ab_pos = torch.index_select(self.memory_v2, 0, y.view(-1)) 173 | ab_pos.mul_(momentum) 174 | ab_pos.add_(torch.mul(v2, 1 - momentum)) 175 | ab_norm = ab_pos.pow(2).sum(1, keepdim=True).pow(0.5) 176 | updated_v2 = ab_pos.div(ab_norm) 177 | self.memory_v2.index_copy_(0, y, updated_v2) 178 | 179 | return out_v1, out_v2 180 | 181 | 182 | class AliasMethod(object): 183 | """ 184 | From: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/ 185 | """ 186 | 187 | def __init__(self, probs): 188 | 189 | if probs.sum() > 1: 190 | probs.div_(probs.sum()) 191 | K = len(probs) 192 | self.prob = torch.zeros(K) 193 | self.alias = torch.LongTensor([0] * K) 194 | 195 | # Sort the data into the outcomes with probabilities 196 | # that are larger and smaller than 1/K. 197 | smaller = [] 198 | larger = [] 199 | for kk, prob in enumerate(probs): 200 | self.prob[kk] = K * prob 201 | if self.prob[kk] < 1.0: 202 | smaller.append(kk) 203 | else: 204 | larger.append(kk) 205 | 206 | # Loop though and create little binary mixtures that 207 | # appropriately allocate the larger outcomes over the 208 | # overall uniform mixture. 
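        # (added note) This is the alias-method table construction: each
        # under-full bucket (prob < 1) is paired with an over-full one, so
        # every bucket is topped up to exactly 1 by a single "alias" outcome.
        # draw() below then samples in O(1): pick a bucket uniformly, keep it
        # with probability prob[k], otherwise take its alias.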
        while len(smaller) > 0 and len(larger) > 0:
            small = smaller.pop()
            large = larger.pop()

            self.alias[small] = large
            self.prob[large] = (self.prob[large] - 1.0) + self.prob[small]

            if self.prob[large] < 1.0:
                smaller.append(large)
            else:
                larger.append(large)

        for last_one in smaller + larger:
            self.prob[last_one] = 1

    def cuda(self):
        self.prob = self.prob.cuda()
        self.alias = self.alias.cuda()

    def draw(self, N):
        """Draw N samples from multinomial"""
        K = self.alias.size(0)

        kk = torch.zeros(N, dtype=torch.long, device=self.prob.device).random_(0, K)
        prob = self.prob.index_select(0, kk)
        alias = self.alias.index_select(0, kk)
        # b is whether a random number is greater than q
        b = torch.bernoulli(prob)
        oq = kk.mul(b.long())
        oj = alias.mul((1 - b).long())

        return oq + oj
--------------------------------------------------------------------------------
/losses/dist.py:
--------------------------------------------------------------------------------
import torch.nn as nn


def cosine_similarity(a, b, eps=1e-8):
    return (a * b).sum(1) / (a.norm(dim=1) * b.norm(dim=1) + eps)


def pearson_correlation(a, b, eps=1e-8):
    return cosine_similarity(a - a.mean(1).unsqueeze(1),
                             b - b.mean(1).unsqueeze(1), eps)


def inter_class_relation(y_s, y_t):
    return 1 - pearson_correlation(y_s, y_t).mean()


def intra_class_relation(y_s, y_t):
    return inter_class_relation(y_s.transpose(0, 1), y_t.transpose(0, 1))


class DIST(nn.Module):
    def __init__(self, beta=1.0, gamma=1.0, tau=1.0):
        super(DIST, self).__init__()
        self.beta = beta
        self.gamma = gamma
        self.tau = tau

    def forward(self, z_s, z_t, **kwargs):
        y_s = (z_s / self.tau).softmax(dim=1)
        y_t = (z_t / self.tau).softmax(dim=1)
        inter_loss = self.tau ** 2 * inter_class_relation(y_s, y_t)
        intra_loss = self.tau ** 2 * intra_class_relation(y_s, y_t)
        kd_loss = self.beta * inter_loss + self.gamma * intra_loss
        return kd_loss
--------------------------------------------------------------------------------
/losses/dkd.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


def dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):
    gt_mask = _get_gt_mask(logits_student, target)
    other_mask = _get_other_mask(logits_student, target)
    pred_student = F.softmax(logits_student / temperature, dim=1)
    pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
    pred_student = cat_mask(pred_student, gt_mask, other_mask)
    pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)
    log_pred_student = torch.log(pred_student)
    tckd_loss = (
        F.kl_div(log_pred_student, pred_teacher, reduction='batchmean')
        * (temperature ** 2)
    )
    pred_teacher_part2 = F.softmax(
        logits_teacher / temperature - 1000.0 * gt_mask, dim=1
    )
    log_pred_student_part2 = F.log_softmax(
        logits_student / temperature - 1000.0 * gt_mask, dim=1
    )
    nckd_loss = (
        F.kl_div(log_pred_student_part2, pred_teacher_part2, reduction='batchmean')
        * (temperature ** 2)
    )
    return alpha * tckd_loss + beta * nckd_loss


def _get_gt_mask(logits, target):
    target = 
target.reshape(-1) 33 | mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool() 34 | return mask 35 | 36 | 37 | def _get_other_mask(logits, target): 38 | target = target.reshape(-1) 39 | mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool() 40 | return mask 41 | 42 | 43 | def cat_mask(t, mask1, mask2): 44 | t1 = (t * mask1).sum(dim=1, keepdims=True) 45 | t2 = (t * mask2).sum(1, keepdims=True) 46 | rt = torch.cat([t1, t2], dim=1) 47 | return rt 48 | 49 | 50 | class DKD(nn.Module): 51 | def __init__(self, alpha=1., beta=2., temperature=1.): 52 | super(DKD, self).__init__() 53 | self.alpha = alpha 54 | self.beta = beta 55 | self.temperature = temperature 56 | 57 | def forward(self, z_s, z_t, **kwargs): 58 | target = kwargs['target'] 59 | if len(target.shape) == 2: # mixup / smoothing 60 | target = target.max(1)[1] 61 | kd_loss = dkd_loss(z_s, z_t, target, self.alpha, self.beta, self.temperature) 62 | return kd_loss 63 | -------------------------------------------------------------------------------- /losses/kd.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class KLDiv(nn.Module): 6 | def __init__(self, temperature=1.0): 7 | super(KLDiv, self).__init__() 8 | self.temperature = temperature 9 | 10 | def forward(self, z_s, z_t, **kwargs): 11 | log_pred_student = F.log_softmax(z_s / self.temperature, dim=1) 12 | pred_teacher = F.softmax(z_t / self.temperature, dim=1) 13 | kd_loss = F.kl_div(log_pred_student, pred_teacher, reduction="none").sum(1).mean() 14 | kd_loss *= self.temperature ** 2 15 | return kd_loss 16 | -------------------------------------------------------------------------------- /losses/review.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ReviewKD(nn.Module): 7 | pre_act_feat = True 8 | def __init__(self, feat_index_s, feat_index_t, in_channels, out_channels, 9 | shapes=(1, 7, 14, 28, 56), out_shapes=(1, 7, 14, 28, 56), 10 | warmup_epochs=1, max_mid_channel=512): 11 | super(ReviewKD, self).__init__() 12 | self.feat_index_s = feat_index_s 13 | self.feat_index_t = feat_index_t 14 | self.shapes = shapes 15 | self.out_shapes = out_shapes 16 | self.warmup_epochs = warmup_epochs 17 | 18 | abfs = nn.ModuleList() 19 | mid_channel = min(max_mid_channel, in_channels[-1]) 20 | for idx, in_channel in enumerate(in_channels): 21 | abfs.append(ABF(in_channel, mid_channel, out_channels[idx], idx < len(in_channels) - 1)) 22 | 23 | self.abfs = abfs[::-1] 24 | 25 | def forward(self, z_s, z_t, **kwargs): 26 | f_s = [kwargs['feature_student'][i] for i in self.feat_index_s] 27 | pre_logit_feat_s = kwargs['feature_student'][-1] 28 | if len(pre_logit_feat_s.shape) == 2: 29 | pre_logit_feat_s = pre_logit_feat_s.unsqueeze(-1).unsqueeze(-1) 30 | f_s.append(pre_logit_feat_s) 31 | 32 | f_s = f_s[::-1] 33 | results = [] 34 | out_features, res_features = self.abfs[0](f_s[0], out_shape=self.out_shapes[0]) 35 | results.append(out_features) 36 | for features, abf, shape, out_shape in zip(f_s[1:], self.abfs[1:], self.shapes[1:], self.out_shapes[1:]): 37 | out_features, res_features = abf(features, res_features, shape, out_shape) 38 | results.insert(0, out_features) 39 | 40 | f_t = [kwargs['feature_teacher'][i] for i in self.feat_index_t] 41 | pre_logit_feat_t = kwargs['feature_teacher'][-1] 42 | if len(pre_logit_feat_t.shape) == 2: 43 
| pre_logit_feat_t = pre_logit_feat_t.unsqueeze(-1).unsqueeze(-1) 44 | f_t.append(pre_logit_feat_t) 45 | 46 | kd_loss = min(kwargs["epoch"] / self.warmup_epochs, 1.0) * hcl_loss(results, f_t) 47 | 48 | return kd_loss 49 | 50 | 51 | class ABF(nn.Module): 52 | def __init__(self, in_channel, mid_channel, out_channel, fuse): 53 | super(ABF, self).__init__() 54 | self.conv1 = nn.Sequential( 55 | nn.Conv2d(in_channel, mid_channel, kernel_size=1, bias=False), 56 | nn.BatchNorm2d(mid_channel), 57 | ) 58 | self.conv2 = nn.Sequential( 59 | nn.Conv2d( 60 | mid_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False 61 | ), 62 | nn.BatchNorm2d(out_channel), 63 | ) 64 | if fuse: 65 | self.att_conv = nn.Sequential( 66 | nn.Conv2d(mid_channel * 2, 2, kernel_size=1), 67 | nn.Sigmoid(), 68 | ) 69 | else: 70 | self.att_conv = None 71 | nn.init.kaiming_uniform_(self.conv1[0].weight, a=1) # pyre-ignore 72 | nn.init.kaiming_uniform_(self.conv2[0].weight, a=1) # pyre-ignore 73 | 74 | def forward(self, x, y=None, shape=None, out_shape=None): 75 | n, _, h, w = x.shape 76 | # transform student features 77 | x = self.conv1(x) 78 | if self.att_conv is not None: 79 | # upsample residual features 80 | y = F.interpolate(y, (shape, shape), mode="nearest") 81 | # fusion 82 | z = torch.cat([x, y], dim=1) 83 | z = self.att_conv(z) 84 | x = x * z[:, 0].view(n, 1, h, w) + y * z[:, 1].view(n, 1, h, w) 85 | # output 86 | if x.shape[-1] != out_shape: 87 | x = F.interpolate(x, (out_shape, out_shape), mode="nearest") 88 | y = self.conv2(x) 89 | return y, x 90 | 91 | 92 | def hcl_loss(fstudent, fteacher): 93 | loss_all = 0.0 94 | for fs, ft in zip(fstudent, fteacher): 95 | n, c, h, w = fs.shape 96 | loss = F.mse_loss(fs, ft, reduction="mean") 97 | cnt = 1.0 98 | tot = 1.0 99 | for l in [4, 2, 1]: 100 | if l >= h: 101 | continue 102 | tmpfs = F.adaptive_avg_pool2d(fs, (l, l)) 103 | tmpft = F.adaptive_avg_pool2d(ft, (l, l)) 104 | cnt /= 2.0 105 | loss += F.mse_loss(tmpfs, tmpft, reduction="mean") * cnt 106 | tot += cnt 107 | loss = loss / tot 108 | loss_all = loss_all + loss 109 | return loss_all 110 | -------------------------------------------------------------------------------- /losses/rkd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def _pdist(e, squared, eps): 7 | e_square = e.pow(2).sum(dim=1) 8 | prod = e @ e.t() 9 | res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps) 10 | 11 | if not squared: 12 | res = res.sqrt() 13 | 14 | res = res.clone() 15 | res[range(len(e)), range(len(e))] = 0 16 | return res 17 | 18 | 19 | class RKD(nn.Module): 20 | def __init__(self, distance_weight=25, angle_weight=50, eps=1e-12, squared=False): 21 | super(RKD, self).__init__() 22 | self.distance_weight = distance_weight 23 | self.angle_weight = angle_weight 24 | self.eps = eps 25 | self.squared = squared 26 | 27 | def forward(self, z_s, z_t, **kwargs): 28 | f_s = kwargs['feature_student'][-1] 29 | f_t = kwargs['feature_teacher'][-1] 30 | 31 | stu = f_s.view(f_s.shape[0], -1) 32 | tea = f_t.view(f_t.shape[0], -1) 33 | 34 | # RKD distance loss 35 | with torch.no_grad(): 36 | t_d = _pdist(tea, self.squared, self.eps) 37 | mean_td = t_d[t_d > 0].mean() 38 | t_d = t_d / mean_td 39 | 40 | d = _pdist(stu, self.squared, self.eps) 41 | mean_d = d[d > 0].mean() 42 | d = d / mean_d 43 | 44 | loss_d = F.smooth_l1_loss(d, t_d) 45 | 46 | # RKD Angle loss 47 | with torch.no_grad(): 48 | td = 
tea.unsqueeze(0) - tea.unsqueeze(1) 49 | norm_td = F.normalize(td, p=2, dim=2) 50 | t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1) 51 | 52 | sd = stu.unsqueeze(0) - stu.unsqueeze(1) 53 | norm_sd = F.normalize(sd, p=2, dim=2) 54 | s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1) 55 | 56 | loss_a = F.smooth_l1_loss(s_angle, t_angle) 57 | 58 | kd_loss = self.distance_weight * loss_d + self.angle_weight * loss_a 59 | return kd_loss 60 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .beit import * 2 | from .convnextv2 import * -------------------------------------------------------------------------------- /models/beit.py: -------------------------------------------------------------------------------- 1 | """ BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) 2 | 3 | Model from official source: https://github.com/microsoft/unilm/tree/master/beit 4 | 5 | At this point only the 1k fine-tuned classification weights and model configs have been added, 6 | see original source above for pre-training models and procedure. 7 | 8 | Modifications by / Copyright 2021 Ross Wightman, original copyrights below 9 | """ 10 | # -------------------------------------------------------- 11 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) 12 | # Github source: https://github.com/microsoft/unilm/tree/master/beit 13 | # Copyright (c) 2021 Microsoft 14 | # Licensed under The MIT License [see LICENSE for details] 15 | # By Hangbo Bao 16 | # Based on timm and DeiT code bases 17 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 18 | # https://github.com/facebookresearch/deit/ 19 | # https://github.com/facebookresearch/dino 20 | # --------------------------------------------------------' 21 | import math 22 | from functools import partial 23 | from typing import Optional, Tuple 24 | 25 | import torch 26 | import torch.nn as nn 27 | import torch.nn.functional as F 28 | from torch.utils.checkpoint import checkpoint 29 | 30 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 31 | from timm.models.helpers import build_model_with_cfg 32 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_ 33 | from timm.models.registry import register_model 34 | from timm.models.vision_transformer import checkpoint_filter_fn 35 | 36 | 37 | def _cfg(url='', **kwargs): 38 | return { 39 | 'url': url, 40 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 41 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 42 | 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 43 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 44 | **kwargs 45 | } 46 | 47 | 48 | default_cfgs = { 49 | 'beit_base_patch16_224': _cfg( 50 | url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'), 51 | 'beit_base_patch16_384': _cfg( 52 | url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth', 53 | input_size=(3, 384, 384), crop_pct=1.0, 54 | ), 55 | 'beit_base_patch16_224_in22k': _cfg( 56 | url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22k.pth', 57 | num_classes=21841, 58 | ), 59 | 'beit_large_patch16_224': _cfg( 60 | url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'), 61 | 'beit_large_patch16_384': _cfg( 
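        # (added note) BEiT checkpoint names encode the training recipe:
        # "pt22k" means pre-trained on ImageNet-22k, "ft22kto1k" means
        # fine-tuned from ImageNet-22k to ImageNet-1k labels.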
62 | url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth', 63 | input_size=(3, 384, 384), crop_pct=1.0, 64 | ), 65 | 'beit_large_patch16_512': _cfg( 66 | url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth', 67 | input_size=(3, 512, 512), crop_pct=1.0, 68 | ), 69 | 'beit_large_patch16_224_in22k': _cfg( 70 | url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_224_pt22k_ft22k.pth', 71 | num_classes=21841, 72 | ), 73 | 74 | 'beitv2_base_patch16_224': _cfg( 75 | url='', input_size=(3, 224, 224), crop_pct=0.9, 76 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD 77 | ), 78 | 'beitv2_large_patch16_224': _cfg( 79 | url='', input_size=(3, 224, 224), crop_pct=0.95, 80 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD 81 | ), 82 | } 83 | 84 | 85 | def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor: 86 | num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 87 | # cls to token & token 2 cls & cls to cls 88 | # get pair-wise relative position index for each token inside the window 89 | window_area = window_size[0] * window_size[1] 90 | coords = torch.stack(torch.meshgrid( 91 | [torch.arange(window_size[0]), 92 | torch.arange(window_size[1])])) # 2, Wh, Ww 93 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 94 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 95 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 96 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 97 | relative_coords[:, :, 1] += window_size[1] - 1 98 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1 99 | relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype) 100 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 101 | relative_position_index[0, 0:] = num_relative_distance - 3 102 | relative_position_index[0:, 0] = num_relative_distance - 2 103 | relative_position_index[0, 0] = num_relative_distance - 1 104 | return relative_position_index 105 | 106 | 107 | class Attention(nn.Module): 108 | def __init__( 109 | self, dim, num_heads=8, qkv_bias=False, attn_drop=0., 110 | proj_drop=0., window_size=None, attn_head_dim=None): 111 | super().__init__() 112 | self.num_heads = num_heads 113 | head_dim = dim // num_heads 114 | if attn_head_dim is not None: 115 | head_dim = attn_head_dim 116 | all_head_dim = head_dim * self.num_heads 117 | self.scale = head_dim ** -0.5 118 | 119 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) 120 | if qkv_bias: 121 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) 122 | self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False) 123 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) 124 | else: 125 | self.q_bias = None 126 | self.k_bias = None 127 | self.v_bias = None 128 | 129 | if window_size: 130 | self.window_size = window_size 131 | self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 132 | self.relative_position_bias_table = nn.Parameter( 133 | torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH 134 | self.register_buffer("relative_position_index", gen_relative_position_index(window_size)) 135 | else: 136 | self.window_size = None 137 | self.relative_position_bias_table = None 138 | self.relative_position_index = None 139 | 140 | self.attn_drop = nn.Dropout(attn_drop) 141 | self.proj = 
nn.Linear(all_head_dim, dim) 142 | self.proj_drop = nn.Dropout(proj_drop) 143 | 144 | def _get_rel_pos_bias(self): 145 | relative_position_bias = self.relative_position_bias_table[ 146 | self.relative_position_index.view(-1)].view( 147 | self.window_size[0] * self.window_size[1] + 1, 148 | self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH 149 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 150 | return relative_position_bias.unsqueeze(0) 151 | 152 | def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None): 153 | B, N, C = x.shape 154 | 155 | qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None 156 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) 157 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 158 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 159 | 160 | q = q * self.scale 161 | attn = (q @ k.transpose(-2, -1)) 162 | 163 | if self.relative_position_bias_table is not None: 164 | attn = attn + self._get_rel_pos_bias() 165 | if shared_rel_pos_bias is not None: 166 | attn = attn + shared_rel_pos_bias 167 | 168 | attn = attn.softmax(dim=-1) 169 | attn = self.attn_drop(attn) 170 | 171 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 172 | x = self.proj(x) 173 | x = self.proj_drop(x) 174 | return x 175 | 176 | 177 | class Block(nn.Module): 178 | 179 | def __init__( 180 | self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 181 | drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, 182 | window_size=None, attn_head_dim=None): 183 | super().__init__() 184 | self.norm1 = norm_layer(dim) 185 | self.attn = Attention( 186 | dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, 187 | window_size=window_size, attn_head_dim=attn_head_dim) 188 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 189 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 190 | self.norm2 = norm_layer(dim) 191 | mlp_hidden_dim = int(dim * mlp_ratio) 192 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 193 | 194 | if init_values: 195 | self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True) 196 | self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True) 197 | else: 198 | self.gamma_1, self.gamma_2 = None, None 199 | 200 | def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None): 201 | if self.gamma_1 is None: 202 | x = x + self.drop_path(self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias)) 203 | x = x + self.drop_path(self.mlp(self.norm2(x))) 204 | else: 205 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias)) 206 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 207 | return x 208 | 209 | 210 | class RelativePositionBias(nn.Module): 211 | 212 | def __init__(self, window_size, num_heads): 213 | super().__init__() 214 | self.window_size = window_size 215 | self.window_area = window_size[0] * window_size[1] 216 | num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 217 | self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads)) 218 | # trunc_normal_(self.relative_position_bias_table, std=.02) 219 | self.register_buffer("relative_position_index", gen_relative_position_index(window_size)) 220 | 221 | def forward(self): 222 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 223 | self.window_area + 1, self.window_area + 1, -1) # Wh*Ww,Wh*Ww,nH 224 | return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 225 | 226 | 227 | class Beit(nn.Module): 228 | """ Vision Transformer with support for patch or hybrid CNN input stage 229 | """ 230 | 231 | def __init__( 232 | self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg', 233 | embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., 234 | attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), 235 | init_values=None, use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, 236 | head_init_scale=0.001): 237 | super().__init__() 238 | self.num_classes = num_classes 239 | self.global_pool = global_pool 240 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 241 | self.grad_checkpointing = False 242 | 243 | self.patch_embed = PatchEmbed( 244 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 245 | num_patches = self.patch_embed.num_patches 246 | 247 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 248 | # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 249 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if use_abs_pos_emb else None 250 | self.pos_drop = nn.Dropout(p=drop_rate) 251 | 252 | if use_shared_rel_pos_bias: 253 | self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.grid_size, num_heads=num_heads) 254 | else: 255 | self.rel_pos_bias = None 256 | 257 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 258 | self.blocks = nn.ModuleList([ 259 | Block( 260 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, 261 | drop=drop_rate, attn_drop=attn_drop_rate, 
drop_path=dpr[i], norm_layer=norm_layer, 262 | init_values=init_values, window_size=self.patch_embed.grid_size if use_rel_pos_bias else None) 263 | for i in range(depth)]) 264 | use_fc_norm = self.global_pool == 'avg' 265 | self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim) 266 | self.fc_norm = norm_layer(embed_dim) if use_fc_norm else None 267 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 268 | 269 | self.apply(self._init_weights) 270 | if self.pos_embed is not None: 271 | trunc_normal_(self.pos_embed, std=.02) 272 | trunc_normal_(self.cls_token, std=.02) 273 | # trunc_normal_(self.mask_token, std=.02) 274 | self.fix_init_weight() 275 | if isinstance(self.head, nn.Linear): 276 | trunc_normal_(self.head.weight, std=.02) 277 | self.head.weight.data.mul_(head_init_scale) 278 | self.head.bias.data.mul_(head_init_scale) 279 | 280 | def fix_init_weight(self): 281 | def rescale(param, layer_id): 282 | param.div_(math.sqrt(2.0 * layer_id)) 283 | 284 | for layer_id, layer in enumerate(self.blocks): 285 | rescale(layer.attn.proj.weight.data, layer_id + 1) 286 | rescale(layer.mlp.fc2.weight.data, layer_id + 1) 287 | 288 | def _init_weights(self, m): 289 | if isinstance(m, nn.Linear): 290 | trunc_normal_(m.weight, std=.02) 291 | if isinstance(m, nn.Linear) and m.bias is not None: 292 | nn.init.constant_(m.bias, 0) 293 | elif isinstance(m, nn.LayerNorm): 294 | nn.init.constant_(m.bias, 0) 295 | nn.init.constant_(m.weight, 1.0) 296 | 297 | @torch.jit.ignore 298 | def no_weight_decay(self): 299 | nwd = {'pos_embed', 'cls_token'} 300 | for n, _ in self.named_parameters(): 301 | if 'relative_position_bias_table' in n: 302 | nwd.add(n) 303 | return nwd 304 | 305 | @torch.jit.ignore 306 | def set_grad_checkpointing(self, enable=True): 307 | self.grad_checkpointing = enable 308 | 309 | @torch.jit.ignore 310 | def group_matcher(self, coarse=False): 311 | matcher = dict( 312 | stem=r'^cls_token|pos_embed|patch_embed|rel_pos_bias', # stem and embed 313 | blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))], 314 | ) 315 | return matcher 316 | 317 | @torch.jit.ignore 318 | def get_classifier(self): 319 | return self.head 320 | 321 | def reset_classifier(self, num_classes, global_pool=None): 322 | self.num_classes = num_classes 323 | if global_pool is not None: 324 | self.global_pool = global_pool 325 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 326 | 327 | def forward_features(self, x): 328 | x = self.patch_embed(x) 329 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 330 | if self.pos_embed is not None: 331 | x = x + self.pos_embed 332 | x = self.pos_drop(x) 333 | 334 | rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None 335 | for blk in self.blocks: 336 | if self.grad_checkpointing and not torch.jit.is_scripting(): 337 | x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias) 338 | else: 339 | x = blk(x, shared_rel_pos_bias=rel_pos_bias) 340 | x = self.norm(x) 341 | return x 342 | 343 | def forward_head(self, x, pre_logits: bool = False): 344 | if self.fc_norm is not None: 345 | x = x[:, 1:].mean(dim=1) 346 | x = self.fc_norm(x) 347 | else: 348 | x = x[:, 0] 349 | return x if pre_logits else self.head(x) 350 | 351 | def forward(self, x): 352 | x = self.forward_features(x) 353 | x = self.forward_head(x) 354 | return x 355 | 356 | 357 | def _create_beit(variant, pretrained=False, **kwargs): 358 | if kwargs.get('features_only', None): 359 | raise 
RuntimeError('features_only not implemented for Beit models.') 360 | 361 | model = build_model_with_cfg( 362 | Beit, variant, pretrained, 363 | # FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes 364 | pretrained_filter_fn=checkpoint_filter_fn, 365 | **kwargs) 366 | return model 367 | 368 | 369 | @register_model 370 | def beit_base_patch16_224(pretrained=False, **kwargs): 371 | model_kwargs = dict( 372 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 373 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs) 374 | model = _create_beit('beit_base_patch16_224', pretrained=pretrained, **model_kwargs) 375 | return model 376 | 377 | 378 | @register_model 379 | def beit_base_patch16_384(pretrained=False, **kwargs): 380 | model_kwargs = dict( 381 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 382 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs) 383 | model = _create_beit('beit_base_patch16_384', pretrained=pretrained, **model_kwargs) 384 | return model 385 | 386 | 387 | @register_model 388 | def beit_base_patch16_224_in22k(pretrained=False, **kwargs): 389 | model_kwargs = dict( 390 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 391 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs) 392 | model = _create_beit('beit_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs) 393 | return model 394 | 395 | 396 | @register_model 397 | def beit_large_patch16_224(pretrained=False, **kwargs): 398 | model_kwargs = dict( 399 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 400 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) 401 | model = _create_beit('beit_large_patch16_224', pretrained=pretrained, **model_kwargs) 402 | return model 403 | 404 | 405 | @register_model 406 | def beit_large_patch16_384(pretrained=False, **kwargs): 407 | model_kwargs = dict( 408 | img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 409 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) 410 | model = _create_beit('beit_large_patch16_384', pretrained=pretrained, **model_kwargs) 411 | return model 412 | 413 | 414 | @register_model 415 | def beit_large_patch16_512(pretrained=False, **kwargs): 416 | model_kwargs = dict( 417 | img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 418 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) 419 | model = _create_beit('beit_large_patch16_512', pretrained=pretrained, **model_kwargs) 420 | return model 421 | 422 | 423 | @register_model 424 | def beit_large_patch16_224_in22k(pretrained=False, **kwargs): 425 | model_kwargs = dict( 426 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 427 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) 428 | model = _create_beit('beit_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs) 429 | return model 430 | 431 | 432 | # models with suffix "_1k" are duplicated to load different checkpoint 433 | @register_model 434 | def beitv2_base_patch16_224_1k(pretrained=False, **kwargs): 435 | model_kwargs = dict( 436 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 437 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) 438 | model = 
_create_beit('beitv2_base_patch16_224', pretrained=pretrained, **model_kwargs) 439 | return model 440 | 441 | 442 | @register_model 443 | def beitv2_large_patch16_224_1k(pretrained=False, **kwargs): 444 | model_kwargs = dict( 445 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 446 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) 447 | model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs) 448 | return model 449 | 450 | 451 | @register_model 452 | def beitv2_base_patch16_224(pretrained=False, **kwargs): 453 | model_kwargs = dict( 454 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 455 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) 456 | model = _create_beit('beitv2_base_patch16_224', pretrained=pretrained, **model_kwargs) 457 | return model 458 | 459 | 460 | @register_model 461 | def beitv2_large_patch16_224(pretrained=False, **kwargs): 462 | model_kwargs = dict( 463 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 464 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) 465 | model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs) 466 | return model -------------------------------------------------------------------------------- /models/convnextv2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | # from MinkowskiEngine import SparseTensor 10 | # Copyright (c) Meta Platforms, Inc. and affiliates. 11 | 12 | # All rights reserved. 13 | 14 | # This source code is licensed under the license found in the 15 | # LICENSE file in the root directory of this source tree. 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | from timm.models.layers import DropPath, trunc_normal_ 22 | from timm.models.registry import register_model 23 | 24 | 25 | class LayerNorm(nn.Module): 26 | """ LayerNorm that supports two data formats: channels_last (default) or channels_first. 27 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 28 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 29 | with shape (batch_size, channels, height, width). 
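
    A minimal usage sketch (an illustrative addition, not from the upstream file):

        >>> ln = LayerNorm(64, data_format="channels_first")
        >>> y = ln(torch.randn(2, 64, 56, 56))  # output keeps shape (2, 64, 56, 56)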
30 | """ 31 | 32 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 35 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 36 | self.eps = eps 37 | self.data_format = data_format 38 | if self.data_format not in ["channels_last", "channels_first"]: 39 | raise NotImplementedError 40 | self.normalized_shape = (normalized_shape,) 41 | 42 | def forward(self, x): 43 | if self.data_format == "channels_last": 44 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 45 | elif self.data_format == "channels_first": 46 | u = x.mean(1, keepdim=True) 47 | s = (x - u).pow(2).mean(1, keepdim=True) 48 | x = (x - u) / torch.sqrt(s + self.eps) 49 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 50 | return x 51 | 52 | 53 | class GRN(nn.Module): 54 | """ GRN (Global Response Normalization) layer 55 | """ 56 | 57 | def __init__(self, dim): 58 | super().__init__() 59 | self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) 60 | self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) 61 | 62 | def forward(self, x): 63 | Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) 64 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) 65 | return self.gamma * (x * Nx) + self.beta + x 66 | 67 | 68 | class Block(nn.Module): 69 | """ ConvNeXtV2 Block. 70 | 71 | Args: 72 | dim (int): Number of input channels. 73 | drop_path (float): Stochastic depth rate. Default: 0.0 74 | """ 75 | 76 | def __init__(self, dim, drop_path=0.): 77 | super().__init__() 78 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 79 | self.norm = LayerNorm(dim, eps=1e-6) 80 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 81 | self.act = nn.GELU() 82 | self.grn = GRN(4 * dim) 83 | self.pwconv2 = nn.Linear(4 * dim, dim) 84 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 85 | 86 | def forward(self, x): 87 | input = x 88 | x = self.dwconv(x) 89 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 90 | x = self.norm(x) 91 | x = self.pwconv1(x) 92 | x = self.act(x) 93 | x = self.grn(x) 94 | x = self.pwconv2(x) 95 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 96 | 97 | x = input + self.drop_path(x) 98 | return x 99 | 100 | 101 | class ConvNeXtV2(nn.Module): 102 | """ ConvNeXt V2 103 | 104 | Args: 105 | in_chans (int): Number of input image channels. Default: 3 106 | num_classes (int): Number of classes for classification head. Default: 1000 107 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 108 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 109 | drop_path_rate (float): Stochastic depth rate. Default: 0. 110 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 
111 | """ 112 | 113 | def __init__(self, in_chans=3, num_classes=1000, 114 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], 115 | drop_path_rate=0., head_init_scale=1., **kwargs 116 | ): 117 | super().__init__() 118 | self.depths = depths 119 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 120 | stem = nn.Sequential( 121 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 122 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 123 | ) 124 | self.downsample_layers.append(stem) 125 | for i in range(3): 126 | downsample_layer = nn.Sequential( 127 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 128 | nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2), 129 | ) 130 | self.downsample_layers.append(downsample_layer) 131 | 132 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 133 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 134 | cur = 0 135 | for i in range(4): 136 | stage = nn.Sequential( 137 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])] 138 | ) 139 | self.stages.append(stage) 140 | cur += depths[i] 141 | 142 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 143 | self.head = nn.Linear(dims[-1], num_classes) 144 | 145 | self.apply(self._init_weights) 146 | self.head.weight.data.mul_(head_init_scale) 147 | self.head.bias.data.mul_(head_init_scale) 148 | 149 | def _init_weights(self, m): 150 | if isinstance(m, (nn.Conv2d, nn.Linear)): 151 | trunc_normal_(m.weight, std=.02) 152 | nn.init.constant_(m.bias, 0) 153 | 154 | def forward_features(self, x): 155 | for i in range(4): 156 | x = self.downsample_layers[i](x) 157 | x = self.stages[i](x) 158 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) 159 | 160 | def forward(self, x): 161 | x = self.forward_features(x) 162 | x = self.head(x) 163 | return x 164 | 165 | 166 | @register_model 167 | def convnextv2_atto(**kwargs): 168 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs) 169 | return model 170 | 171 | 172 | @register_model 173 | def convnextv2_femto(**kwargs): 174 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs) 175 | return model 176 | 177 | 178 | @register_model 179 | def convnext_pico(**kwargs): 180 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs) 181 | return model 182 | 183 | 184 | @register_model 185 | def convnextv2_nano(**kwargs): 186 | model = ConvNeXtV2(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs) 187 | return model 188 | 189 | 190 | @register_model 191 | def convnextv2_tiny(**kwargs): 192 | model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 193 | return model 194 | 195 | 196 | @register_model 197 | def convnextv2_base(**kwargs): 198 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) 199 | return model 200 | 201 | 202 | @register_model 203 | def convnextv2_large(**kwargs): 204 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 205 | return model 206 | 207 | 208 | @register_model 209 | def convnextv2_huge(**kwargs): 210 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs) 211 | return model 212 | -------------------------------------------------------------------------------- /object_detection/README.md: 
--------------------------------------------------------------------------------
# COCO Object detection

## Getting started

We add ConvNeXtV2 and ResNet models and config files based on [mmdetection-2.x](https://github.com/open-mmlab/mmdetection/tree/2.x). Please refer to [get_started.md](https://github.com/open-mmlab/mmdetection/blob/2.x/docs/en/get_started.md) for mmdetection installation and dataset preparation instructions.

## Results and Fine-tuned Models

| Framework         | Backbone     | LR Schedule | AP^box | AP^mask |
| :---------------: | :----------: | :---------: | :----: | :-----: |
| Mask RCNN         | ResNet50     | 1x          | 41.8   | 37.7    |
|                   | ResNet50     | 2x          | 42.1   | 38.0    |
| Mask RCNN         | ConvNeXtV2-T | 1x          | 45.7   | 42.0    |
|                   | ConvNeXtV2-T | 3x          | 47.9   | 43.3    |
| Cascade Mask RCNN | ConvNeXtV2-T | 1x          | 50.6   | 44.3    |
|                   | ConvNeXtV2-T | 3x          | 52.1   | 45.4    |

### Training

To train a model with 8 GPUs, run:
```
python -m torch.distributed.launch --nproc_per_node=8 tools/train.py <CONFIG_FILE> --launcher pytorch --work-dir <WORK_DIR>
```

## Acknowledgment

This code is built on the [mmdetection](https://github.com/open-mmlab/mmdetection) and [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) repositories.
--------------------------------------------------------------------------------
/object_detection/configs/convnextv2/cascade_mask_rcnn_convnextv2_3x_coco.py:
--------------------------------------------------------------------------------
_base_ = [
    '../_base_/models/cascade_mask_rcnn_r50_fpn.py',
    '../_base_/datasets/coco_instance.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]

# you can download the checkpoint from:
# https://github.com/Hao840/vanillaKD/releases/download/checkpoint/convnextv2_tiny-85.030.pth
checkpoint_file = '/your_path_to/convnextv2_tiny-85.030.pth'

model = dict(
    backbone=dict(
        _delete_=True,
        type='ConvNeXtV2',
        dims=[96, 192, 384, 768],
        drop_path_rate=0.4,
        layer_scale_init_value=0.,
        init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file)),
    neck=dict(in_channels=[96, 192, 384, 768]),
    roi_head=dict(bbox_head=[
        dict(
            type='ConvFCBBoxHead',
            num_shared_convs=4,
            num_shared_fcs=1,
            in_channels=256,
            conv_out_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            reg_decoded_bbox=True,
            norm_cfg=dict(type='SyncBN', requires_grad=True),
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
        dict(
            type='ConvFCBBoxHead',
            num_shared_convs=4,
            num_shared_fcs=1,
            in_channels=256,
            conv_out_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.05, 0.05, 0.1, 0.1]),
            reg_class_agnostic=False,
            reg_decoded_bbox=True,
            norm_cfg=dict(type='SyncBN', requires_grad=True),
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
        dict(
            type='ConvFCBBoxHead',
            num_shared_convs=4,
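            # (added note) this third cascade stage uses the tightest
            # regression targets; across the three heads the target_stds
            # shrink from [0.1, 0.1, 0.2, 0.2] to [0.05, 0.05, 0.1, 0.1] to
            # [0.033, 0.033, 0.067, 0.067], the usual Cascade R-CNN schedule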
num_shared_fcs=1, 63 | in_channels=256, 64 | conv_out_channels=256, 65 | fc_out_channels=1024, 66 | roi_feat_size=7, 67 | num_classes=80, 68 | bbox_coder=dict( 69 | type='DeltaXYWHBBoxCoder', 70 | target_means=[0., 0., 0., 0.], 71 | target_stds=[0.033, 0.033, 0.067, 0.067]), 72 | reg_class_agnostic=False, 73 | reg_decoded_bbox=True, 74 | norm_cfg=dict(type='SyncBN', requires_grad=True), 75 | loss_cls=dict( 76 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 77 | loss_bbox=dict(type='GIoULoss', loss_weight=10.0)) 78 | ])) 79 | 80 | # dataset settings 81 | dataset_type = 'CocoDataset' 82 | data_root = 'data/coco/' 83 | img_norm_cfg = dict( 84 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 85 | train_pipeline = [ 86 | dict(type='LoadImageFromFile'), 87 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True), 88 | dict(type='RandomFlip', flip_ratio=0.5), 89 | dict( 90 | type='AutoAugment', 91 | policies=[[ 92 | dict( 93 | type='Resize', 94 | img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), 95 | (608, 1333), (640, 1333), (672, 1333), (704, 1333), 96 | (736, 1333), (768, 1333), (800, 1333)], 97 | multiscale_mode='value', 98 | keep_ratio=True) 99 | ], 100 | [ 101 | dict( 102 | type='Resize', 103 | img_scale=[(400, 1333), (500, 1333), (600, 1333)], 104 | multiscale_mode='value', 105 | keep_ratio=True), 106 | dict( 107 | type='RandomCrop', 108 | crop_type='absolute_range', 109 | crop_size=(384, 600), 110 | allow_negative_crop=True), 111 | dict( 112 | type='Resize', 113 | img_scale=[(480, 1333), (512, 1333), (544, 1333), 114 | (576, 1333), (608, 1333), (640, 1333), 115 | (672, 1333), (704, 1333), (736, 1333), 116 | (768, 1333), (800, 1333)], 117 | multiscale_mode='value', 118 | override=True, 119 | keep_ratio=True) 120 | ]]), 121 | dict(type='Normalize', **img_norm_cfg), 122 | dict(type='Pad', size_divisor=32), 123 | dict(type='DefaultFormatBundle'), 124 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), 125 | ] 126 | test_pipeline = [ 127 | dict(type='LoadImageFromFile'), 128 | dict( 129 | type='MultiScaleFlipAug', 130 | img_scale=(1333, 800), 131 | flip=False, 132 | transforms=[ 133 | dict(type='Resize', keep_ratio=True), 134 | dict(type='RandomFlip'), 135 | dict(type='Normalize', **img_norm_cfg), 136 | dict(type='Pad', size_divisor=32), 137 | dict(type='ImageToTensor', keys=['img']), 138 | dict(type='Collect', keys=['img']), 139 | ]) 140 | ] 141 | data = dict( 142 | samples_per_gpu=2, 143 | workers_per_gpu=2, 144 | train=dict( 145 | type=dataset_type, 146 | ann_file=data_root + 'annotations/instances_train2017.json', 147 | img_prefix=data_root + 'train2017/', 148 | pipeline=train_pipeline), 149 | val=dict( 150 | type=dataset_type, 151 | ann_file=data_root + 'annotations/instances_val2017.json', 152 | img_prefix=data_root + 'val2017/', 153 | pipeline=test_pipeline), 154 | test=dict( 155 | type=dataset_type, 156 | ann_file=data_root + 'annotations/instances_val2017.json', 157 | img_prefix=data_root + 'val2017/', 158 | pipeline=test_pipeline)) 159 | evaluation = dict(metric=['bbox', 'segm']) 160 | 161 | optimizer = dict( 162 | _delete_=True, 163 | constructor='LearningRateDecayOptimizerConstructor', 164 | type='AdamW', 165 | lr=0.0002, 166 | betas=(0.9, 0.999), 167 | weight_decay=0.05, 168 | paramwise_cfg={ 169 | 'decay_rate': 0.7, 170 | 'decay_type': 'layer_wise', 171 | 'num_layers': 6 172 | }) 173 | 174 | lr_config = dict(warmup_iters=1000, step=[27, 33]) 175 | runner = dict(max_epochs=36) 176 | 177 | 
log_config = dict( 178 | interval=200, 179 | hooks=[ 180 | dict(type='TextLoggerHook'), 181 | # dict(type='TensorboardLoggerHook') 182 | ]) 183 | 184 | # you need to set mode='dynamic' if you are using pytorch<=1.5.0 185 | fp16 = dict(loss_scale=dict(init_scale=512)) 186 | -------------------------------------------------------------------------------- /object_detection/configs/convnextv2/mask_rcnn_convnextv2_3x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/mask_rcnn_r50_fpn.py', 3 | '../_base_/datasets/coco_instance.py', 4 | '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' 5 | ] 6 | 7 | # you can download ckpt from: 8 | # https://github.com/Hao840/vanillaKD/releases/download/checkpoint/convnextv2_tiny-85.030.pth 9 | checkpoint_file = '/your_path_to/convnextv2_tiny-85.030.pth' 10 | 11 | model = dict( 12 | backbone=dict( 13 | _delete_=True, 14 | type='ConvNeXtV2', 15 | dims=[96, 192, 384, 768], 16 | drop_path_rate=0.4, 17 | layer_scale_init_value=0., 18 | init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file)), 19 | neck=dict(in_channels=[96, 192, 384, 768])) 20 | 21 | # dataset settings 22 | dataset_type = 'CocoDataset' 23 | data_root = 'data/coco/' 24 | img_norm_cfg = dict( 25 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 26 | train_pipeline = [ 27 | dict(type='LoadImageFromFile'), 28 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True), 29 | dict(type='RandomFlip', flip_ratio=0.5), 30 | dict( 31 | type='AutoAugment', 32 | policies=[[ 33 | dict( 34 | type='Resize', 35 | img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), 36 | (608, 1333), (640, 1333), (672, 1333), (704, 1333), 37 | (736, 1333), (768, 1333), (800, 1333)], 38 | multiscale_mode='value', 39 | keep_ratio=True) 40 | ], 41 | [ 42 | dict( 43 | type='Resize', 44 | img_scale=[(400, 1333), (500, 1333), (600, 1333)], 45 | multiscale_mode='value', 46 | keep_ratio=True), 47 | dict( 48 | type='RandomCrop', 49 | crop_type='absolute_range', 50 | crop_size=(384, 600), 51 | allow_negative_crop=True), 52 | dict( 53 | type='Resize', 54 | img_scale=[(480, 1333), (512, 1333), (544, 1333), 55 | (576, 1333), (608, 1333), (640, 1333), 56 | (672, 1333), (704, 1333), (736, 1333), 57 | (768, 1333), (800, 1333)], 58 | multiscale_mode='value', 59 | override=True, 60 | keep_ratio=True) 61 | ]]), 62 | dict(type='Normalize', **img_norm_cfg), 63 | dict(type='Pad', size_divisor=32), 64 | dict(type='DefaultFormatBundle'), 65 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), 66 | ] 67 | test_pipeline = [ 68 | dict(type='LoadImageFromFile'), 69 | dict( 70 | type='MultiScaleFlipAug', 71 | img_scale=(1333, 800), 72 | flip=False, 73 | transforms=[ 74 | dict(type='Resize', keep_ratio=True), 75 | dict(type='RandomFlip'), 76 | dict(type='Normalize', **img_norm_cfg), 77 | dict(type='Pad', size_divisor=32), 78 | dict(type='ImageToTensor', keys=['img']), 79 | dict(type='Collect', keys=['img']), 80 | ]) 81 | ] 82 | data = dict( 83 | samples_per_gpu=2, 84 | workers_per_gpu=2, 85 | train=dict( 86 | type=dataset_type, 87 | ann_file=data_root + 'annotations/instances_train2017.json', 88 | img_prefix=data_root + 'train2017/', 89 | pipeline=train_pipeline), 90 | val=dict( 91 | type=dataset_type, 92 | ann_file=data_root + 'annotations/instances_val2017.json', 93 | img_prefix=data_root + 'val2017/', 94 | pipeline=test_pipeline), 95 | test=dict( 96 | type=dataset_type, 97 | ann_file=data_root + 
'annotations/instances_val2017.json', 98 | img_prefix=data_root + 'val2017/', 99 | pipeline=test_pipeline)) 100 | evaluation = dict(metric=['bbox', 'segm']) 101 | 102 | optimizer = dict( 103 | _delete_=True, 104 | constructor='LearningRateDecayOptimizerConstructor', 105 | type='AdamW', 106 | lr=0.0001, 107 | betas=(0.9, 0.999), 108 | weight_decay=0.05, 109 | paramwise_cfg={ 110 | 'decay_rate': 0.95, 111 | 'decay_type': 'layer_wise', 112 | 'num_layers': 6 113 | }) 114 | 115 | lr_config = dict(warmup_iters=1000, step=[27, 33]) 116 | runner = dict(max_epochs=36) 117 | 118 | log_config = dict( 119 | interval=200, 120 | hooks=[ 121 | dict(type='TextLoggerHook'), 122 | # dict(type='TensorboardLoggerHook') 123 | ]) 124 | 125 | # you need to set mode='dynamic' if you are using pytorch<=1.5.0 126 | fp16 = dict(loss_scale=dict(init_scale=512)) 127 | -------------------------------------------------------------------------------- /object_detection/configs/resnet/mask_rcnn_r50_1x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/mask_rcnn_r50_fpn.py', 3 | '../_base_/datasets/coco_instance.py', 4 | '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' 5 | ] 6 | 7 | # you can download ckpt from: 8 | # https://github.com/Hao840/vanillaKD/releases/download/checkpoint/resnet50-83.078.pth 9 | checkpoint_file = '/your_path_to/resnet50-83.078.pth' 10 | 11 | model = dict( 12 | backbone=dict( 13 | _delete_=True, 14 | type='ResNet', 15 | depth=50, 16 | num_stages=4, 17 | out_indices=(0, 1, 2, 3), 18 | frozen_stages=1, 19 | norm_cfg=dict(type='SyncBN', requires_grad=True), 20 | norm_eval=False, 21 | style='pytorch', 22 | init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file)), 23 | ) 24 | 25 | # dataset settings 26 | dataset_type = 'CocoDataset' 27 | data_root = 'data/coco/' 28 | img_norm_cfg = dict( 29 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 30 | train_pipeline = [ 31 | dict(type='LoadImageFromFile'), 32 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True), 33 | dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), 34 | dict(type='RandomFlip', flip_ratio=0.5), 35 | dict(type='Normalize', **img_norm_cfg), 36 | dict(type='Pad', size_divisor=32), 37 | dict(type='DefaultFormatBundle'), 38 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), 39 | ] 40 | test_pipeline = [ 41 | dict(type='LoadImageFromFile'), 42 | dict( 43 | type='MultiScaleFlipAug', 44 | img_scale=(1333, 800), 45 | flip=False, 46 | transforms=[ 47 | dict(type='Resize', keep_ratio=True), 48 | dict(type='RandomFlip'), 49 | dict(type='Normalize', **img_norm_cfg), 50 | dict(type='Pad', size_divisor=32), 51 | dict(type='ImageToTensor', keys=['img']), 52 | dict(type='Collect', keys=['img']), 53 | ]) 54 | ] 55 | data = dict( 56 | samples_per_gpu=4, 57 | workers_per_gpu=4, 58 | train=dict( 59 | type=dataset_type, 60 | ann_file=data_root + 'annotations/instances_train2017.json', 61 | img_prefix=data_root + 'train2017/', 62 | pipeline=train_pipeline), 63 | val=dict( 64 | type=dataset_type, 65 | ann_file=data_root + 'annotations/instances_val2017.json', 66 | img_prefix=data_root + 'val2017/', 67 | pipeline=test_pipeline), 68 | test=dict( 69 | type=dataset_type, 70 | ann_file=data_root + 'annotations/instances_val2017.json', 71 | img_prefix=data_root + 'val2017/', 72 | pipeline=test_pipeline)) 73 | evaluation = dict(metric=['bbox', 'segm']) 74 | 75 | # optimizer 76 | optimizer = 
dict(type='SGD', lr=0.04, momentum=0.9, weight_decay=0.0001) 77 | optimizer_config = dict(grad_clip=None) 78 | # learning policy 79 | lr_config = dict( 80 | policy='step', 81 | warmup='linear', 82 | warmup_iters=500, 83 | warmup_ratio=0.001, 84 | step=[8, 11]) 85 | runner = dict(type='EpochBasedRunner', max_epochs=12) 86 | 87 | log_config = dict( 88 | interval=200, 89 | hooks=[ 90 | dict(type='TextLoggerHook'), 91 | # dict(type='TensorboardLoggerHook') 92 | ]) -------------------------------------------------------------------------------- /object_detection/mmdet/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .darknet import Darknet 2 | from .detectors_resnet import DetectoRS_ResNet 3 | from .detectors_resnext import DetectoRS_ResNeXt 4 | from .hourglass import HourglassNet 5 | from .hrnet import HRNet 6 | from .regnet import RegNet 7 | from .res2net import Res2Net 8 | from .resnest import ResNeSt 9 | from .resnet import ResNet, ResNetV1d 10 | from .resnext import ResNeXt 11 | from .ssd_vgg import SSDVGG 12 | from .trident_resnet import TridentResNet 13 | from .swin_transformer import SwinTransformer 14 | from .convnextv2 import ConvNeXtV2 15 | 16 | __all__ = [ 17 | 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'Res2Net', 18 | 'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt', 'Darknet', 19 | 'ResNeSt', 'TridentResNet', 'SwinTransformer', 'ConvNeXtV2' 20 | ] 21 | -------------------------------------------------------------------------------- /object_detection/mmdet/models/backbones/convnextv2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # Modified by Jianyuan Guo, Zhiwei Hao 4 | 5 | import os 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import math 10 | 11 | from timm.models.layers import DropPath 12 | 13 | from mmcv.cnn import (Conv2d, build_activation_layer, build_norm_layer, 14 | constant_init, normal_init, trunc_normal_init) 15 | from mmcv.cnn.bricks.drop import build_dropout 16 | from mmcv.cnn.utils.weight_init import trunc_normal_ 17 | from mmcv.runner import (BaseModule, ModuleList, Sequential, _load_checkpoint, 18 | load_state_dict) 19 | 20 | from ...utils import get_root_logger 21 | from ..builder import BACKBONES 22 | 23 | 24 | class LayerNorm2d(nn.LayerNorm): 25 | """LayerNorm on channels for 2d images. 26 | 27 | Args: 28 | num_channels (int): The number of channels of the input tensor. 29 | eps (float): a value added to the denominator for numerical stability. 30 | Defaults to 1e-5. 31 | elementwise_affine (bool): a boolean value that when set to ``True``, 32 | this module has learnable per-element affine parameters initialized 33 | to ones (for weights) and zeros (for biases). Defaults to True. 34 | """ 35 | 36 | def __init__(self, num_channels: int, **kwargs) -> None: 37 | super().__init__(num_channels, **kwargs) 38 | self.num_channels = self.normalized_shape[0] 39 | 40 | def forward(self, x, data_format='channel_first'): 41 | """Forward method. 42 | 43 | Args: 44 | x (torch.Tensor): The input tensor. 45 | data_format (str): The format of the input tensor. If 46 | ``"channel_first"``, the shape of the input tensor should be 47 | (B, C, H, W). If ``"channel_last"``, the shape of the input 48 | tensor should be (B, H, W, C). Defaults to "channel_first". 
49 | """ 50 | assert x.dim() == 4, 'LayerNorm2d only supports inputs with shape ' \ 51 | f'(N, C, H, W), but got tensor with shape {x.shape}' 52 | if data_format == 'channel_last': 53 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, 54 | self.eps) 55 | elif data_format == 'channel_first': 56 | x = x.permute(0, 2, 3, 1) 57 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, 58 | self.eps) 59 | # If the output is discontiguous, it may cause some unexpected 60 | # problem in the downstream tasks 61 | x = x.permute(0, 3, 1, 2).contiguous() 62 | return x 63 | 64 | class LayerNorm(nn.Module): 65 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 66 | super().__init__() 67 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 68 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 69 | self.eps = eps 70 | self.data_format = data_format 71 | if self.data_format not in ["channels_last", "channels_first"]: 72 | raise NotImplementedError 73 | self.normalized_shape = (normalized_shape, ) 74 | 75 | def forward(self, x): 76 | if self.data_format == "channels_last": 77 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 78 | elif self.data_format == "channels_first": 79 | u = x.mean(1, keepdim=True) 80 | s = (x - u).pow(2).mean(1, keepdim=True) 81 | x = (x - u) / torch.sqrt(s + self.eps) 82 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 83 | return x 84 | 85 | class GRN(nn.Module): 86 | def __init__(self, dim): 87 | super().__init__() 88 | self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) 89 | self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) 90 | 91 | def forward(self, x): 92 | Gx = torch.norm(x, p=2, dim=(1,2), keepdim=True) 93 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) 94 | return self.gamma * (x * Nx) + self.beta + x 95 | 96 | class Block(nn.Module): 97 | def __init__(self, dim, drop_path=0.): 98 | super().__init__() 99 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 100 | self.norm = LayerNorm(dim, eps=1e-6) 101 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 102 | self.act = nn.GELU() 103 | self.grn = GRN(4 * dim) 104 | self.pwconv2 = nn.Linear(4 * dim, dim) 105 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 106 | 107 | def forward(self, x): 108 | input = x 109 | x = self.dwconv(x) 110 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 111 | x = self.norm(x) 112 | x = self.pwconv1(x) 113 | x = self.act(x) 114 | x = self.grn(x) 115 | x = self.pwconv2(x) 116 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 117 | 118 | x = input + self.drop_path(x) 119 | return x 120 | 121 | 122 | @BACKBONES.register_module() 123 | class ConvNeXtV2(BaseModule): 124 | def __init__(self, in_chans=3, 125 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], 126 | drop_path_rate=0., head_init_scale=1., 127 | out_indices=[0,1,2,3], init_cfg=None, **kwargs 128 | ): 129 | super().__init__() 130 | self.init_cfg = init_cfg 131 | self.depths = depths 132 | self.out_indices = out_indices 133 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 134 | stem = nn.Sequential( 135 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 136 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 137 | ) 138 | self.downsample_layers.append(stem) 139 | for i in range(3): 140 | downsample_layer = nn.Sequential( 141 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 142 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), 143 | ) 144 | self.downsample_layers.append(downsample_layer) 145 | 146 | self.stages = nn.ModuleList() 147 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 148 | cur = 0 149 | for i in range(4): 150 | stage = nn.Sequential( 151 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])] 152 | ) 153 | self.stages.append(stage) 154 | cur += depths[i] 155 | 156 | for i in out_indices: 157 | layer = LayerNorm2d(dims[i]) 158 | layer_name = f'norm{i}' 159 | self.add_module(layer_name, layer) 160 | 161 | def init_weights(self): 162 | if self.init_cfg is None: 163 | logger = get_root_logger() 164 | logger.warning(f'No pre-trained weights for ' 165 | f'{self.__class__.__name__}, ' 166 | f'training starts from scratch') 167 | for m in self.modules(): 168 | if isinstance(m, nn.Linear): 169 | trunc_normal_init(m, std=.02, bias=0.) 170 | elif isinstance(m, nn.Conv2d): 171 | fan_out = m.kernel_size[0] * m.kernel_size[ 172 | 1] * m.out_channels 173 | fan_out //= m.groups 174 | normal_init(m, 0, math.sqrt(2.0 / fan_out)) 175 | else: 176 | state_dict = torch.load(self.init_cfg.checkpoint, map_location='cpu') 177 | if len(state_dict.keys()) <= 10 and 'model' in state_dict.keys(): 178 | msg = self.load_state_dict(state_dict['model'], strict=False) 179 | else: 180 | msg = self.load_state_dict(state_dict, strict=False) 181 | print(msg) 182 | print('Successfully loaded backbone ckpt.') 183 | 184 | def forward(self, x): 185 | 186 | outs = [] 187 | 188 | for i in range(4): 189 | x = self.downsample_layers[i](x) 190 | x = self.stages[i](x) 191 | if i in self.out_indices: 192 | norm_layer = getattr(self, f'norm{i}') 193 | outs.append(norm_layer(x)) 194 | 195 | return outs 196 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # VanillaKD: Revisit the Power of Vanilla Knowledge Distillation from Small Scale to Large Scale 2 |

3 | 4 | 5 |

6 | 7 | Official PyTorch implementation of **VanillaKD**, from the following paper: \ 8 | [VanillaKD: Revisit the Power of Vanilla Knowledge Distillation from Small Scale to Large Scale](https://arxiv.org/abs/2305.15781) \ 9 | Zhiwei Hao, Jianyuan Guo, Kai Han, Han Hu, Chang Xu, Yunhe Wang 10 | 11 | 12 | This paper emphasizes the importance of scale in achieving superior results. It reveals that previous KD methods, designed and evaluated solely on small-scale datasets, have underestimated the effectiveness of vanilla KD on large-scale datasets, a phenomenon referred to as the **small data pitfall**. By incorporating **stronger data augmentation** and **larger datasets**, the performance gap between vanilla KD and other approaches is narrowed: 13 | 14 | ![Performance of vanilla KD as data scale grows](fig/data_scale_bar.png) 15 | 16 | Without bells and whistles, state-of-the-art results are achieved for ResNet-50, ViT-S, and ConvNeXtV2-T models on ImageNet, showcasing that vanilla KD is elegantly simple yet astonishingly effective in large-scale scenarios. 17 | 18 | If you find this project useful in your research, please cite: 19 | 20 | ``` 21 | @article{hao2023vanillakd, 22 | title={VanillaKD: Revisit the Power of Vanilla Knowledge Distillation from Small Scale to Large Scale}, 23 | author={Hao, Zhiwei and Guo, Jianyuan and Han, Kai and Hu, Han and Xu, Chang and Wang, Yunhe}, 24 | journal={arXiv preprint arXiv:2305.15781}, 25 | year={2023} 26 | } 27 | ``` 28 | 29 | ## Model Zoo 30 | 31 | We provide models trained by vanilla KD on ImageNet. 32 | 33 | | name | acc@1 | acc@5 | model | 34 | |:---:|:---:|:---:|:---:| 35 | |resnet50|83.08|96.35|[model](https://github.com/Hao840/vanillaKD/releases/download/checkpoint/resnet50-83.078.pth)| 36 | |vit_tiny_patch16_224|78.11|94.26|[model](https://github.com/Hao840/vanillaKD/releases/download/checkpoint/vit_tiny_patch16_224-78.106.pth)| 37 | |vit_small_patch16_224|84.33|97.09|[model](https://github.com/Hao840/vanillaKD/releases/download/checkpoint/vit_small_patch16_224-84.328.pth)| 38 | |convnextv2_tiny|85.03|97.44|[model](https://github.com/Hao840/vanillaKD/releases/download/checkpoint/convnextv2_tiny-85.030.pth)| 39 | 40 | 41 | ## Usage 42 | First, clone the repository locally: 43 | 44 | ``` 45 | git clone https://github.com/Hao840/vanillaKD.git 46 | ``` 47 | 48 | Then, install PyTorch and [timm 0.6.5](https://github.com/huggingface/pytorch-image-models/tree/v0.6.5): 49 | 50 | ``` 51 | conda install -c pytorch pytorch torchvision 52 | pip install timm==0.6.5 53 | ``` 54 | 55 | Our results are produced with `torch==1.10.2+cu113 torchvision==0.11.3+cu113 timm==0.6.5`. Other versions might also work. 56 | 57 | ### Data preparation 58 | 59 | Download and extract ImageNet train and val images from http://image-net.org/. The directory structure is: 60 | 61 | ``` 62 | │path/to/imagenet/ 63 | ├──train/ 64 | │ ├── n01440764 65 | │ │ ├── n01440764_10026.JPEG 66 | │ │ ├── n01440764_10027.JPEG 67 | │ │ ├── ...... 68 | │ ├── ...... 69 | ├──val/ 70 | │ ├── n01440764 71 | │ │ ├── ILSVRC2012_val_00000293.JPEG 72 | │ │ ├── ILSVRC2012_val_00002138.JPEG 73 | │ │ ├── ...... 74 | │ ├── ......
75 | ``` 76 | 77 | ### Evaluation 78 | 79 | To evaluate a distilled model on ImageNet val with a single GPU, run: 80 | 81 | ``` 82 | python validate.py /path/to/imagenet --model <model_name> --checkpoint /path/to/checkpoint 83 | ``` 84 | 85 | 86 | ### Training 87 | 88 | To train a ResNet50 student with a BEiTv2-B teacher on ImageNet on a single node with 8 GPUs, run one of the following: 89 | 90 | Strategy A2: 91 | 92 | ``` 93 | python -m torch.distributed.launch --nproc_per_node=8 train-kd.py /path/to/imagenet --model resnet50 --teacher beitv2_base_patch16_224 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss kd --amp --epochs 300 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.02 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.0 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.1 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 1 94 | ``` 95 | 96 | Strategy A1: 97 | 98 | ``` 99 | python -m torch.distributed.launch --nproc_per_node=8 train-kd.py /path/to/imagenet --model resnet50 --teacher beitv2_base_patch16_224 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss kd --amp --epochs 600 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.01 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.1 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.2 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 1 100 | ``` 101 | 102 | A minimal sketch of the vanilla KD objective selected by `--kd-loss kd` is given after the baseline commands below. 103 | 104 | Commands for reproducing baseline results: 105 | 106 |
107 | 108 | **DKD** 109 | 110 | Training with ResNet50 student, BEiTv2-B teacher, and strategy A2 for 300 epochs: 111 | 112 | ``` 113 | python -m torch.distributed.launch --nproc_per_node=8 train-kd.py /path/to/imagenet --model resnet50 --teacher beitv2_base_patch16_224 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss dkd --amp --epochs 300 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.02 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.0 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.1 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 1 114 | ``` 115 |
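For orientation, a minimal sketch of the decoupled objective behind `--kd-loss dkd`, assuming the standard TCKD/NCKD split from the DKD paper (the repository's loss implementation is authoritative; the helper name and default weights below are illustrative):

```
import torch
import torch.nn.functional as F

def dkd_sketch(z_s, z_t, target, alpha=1.0, beta=1.0, T=1.0):
    gt = F.one_hot(target, z_s.size(1)).bool()  # mask of ground-truth classes
    p_s, p_t = F.softmax(z_s / T, dim=1), F.softmax(z_t / T, dim=1)
    # TCKD: binary KL over (target, non-target) probability mass
    b_s = torch.stack([(p_s * gt).sum(1), (p_s * ~gt).sum(1)], dim=1)
    b_t = torch.stack([(p_t * gt).sum(1), (p_t * ~gt).sum(1)], dim=1)
    tckd = F.kl_div(b_s.log(), b_t, reduction='batchmean')
    # NCKD: KL among non-target classes, with the target logit masked out
    nckd = F.kl_div(F.log_softmax(z_s / T - 1000.0 * gt, dim=1),
                    F.softmax(z_t / T - 1000.0 * gt, dim=1),
                    reduction='batchmean')
    return (alpha * tckd + beta * nckd) * T * T
```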
116 | 117 | 118 | 119 |
120 | 121 | **DIST** 122 | 123 | Training with ResNet50 student, BEiTv2-B teacher, and strategy A2 for 300 epochs: 124 | 125 | ``` 126 | python -m torch.distributed.launch --nproc_per_node=8 train-kd.py /path/to/imagenet --model resnet50 --teacher beitv2_base_patch16_224 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss dist --amp --epochs 300 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.02 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.0 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.1 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 1 127 | ``` 128 |
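Likewise, a hedged sketch of the relation-matching objective behind `--kd-loss dist`: inter-class and intra-class Pearson-correlation terms, following the DIST paper (function names here are illustrative):

```
import torch.nn.functional as F

def pearson_1m(a, b, eps=1e-8):
    a = a - a.mean(dim=-1, keepdim=True)
    b = b - b.mean(dim=-1, keepdim=True)
    corr = (a * b).sum(-1) / (a.norm(dim=-1) * b.norm(dim=-1) + eps)
    return (1.0 - corr).mean()

def dist_sketch(z_s, z_t, beta=1.0, gamma=1.0, T=1.0):
    p_s, p_t = F.softmax(z_s / T, dim=1), F.softmax(z_t / T, dim=1)
    inter = pearson_1m(p_s, p_t)          # each sample's distribution over classes
    intra = pearson_1m(p_s.t(), p_t.t())  # each class's scores across the batch
    return beta * inter + gamma * intra
```

Transposing the softened predictions is what turns per-sample class matching into per-class batch matching.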
129 | 130 | 131 | 132 |
133 | 134 | **Correlation** 135 | 136 | Training with ResNet50 student, ResNet152 teacher, and strategy A2 for 300 epochs: 137 | 138 | ``` 139 | python -m torch.distributed.launch --nproc_per_node=8 train-fd.py /path/to/imagenet --model resnet50 --teacher resnet152 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss correlation --amp --epochs 300 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.02 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.0 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.1 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 0 140 | ``` 141 |
142 | 143 | 144 | 145 |
146 | 147 | **RKD** 148 | 149 | Training with ResNet50 student, ResNet152 teacher, and strategy A2 for 300 epochs: 150 | 151 | ``` 152 | python -m torch.distributed.launch --nproc_per_node=8 train-fd.py /path/to/imagenet --model resnet50 --teacher resnet152 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss rkd --amp --epochs 300 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.02 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.0 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.1 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 0 153 | ``` 154 |
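RKD transfers relations between samples rather than individual predictions; a minimal sketch of its distance-wise term on penultimate features (the angle-wise term is analogous, and the normalization below follows the usual RKD recipe; exact weights live in the repository's loss implementation):

```
import torch
import torch.nn.functional as F

def rkd_distance_sketch(f_s, f_t, eps=1e-8):
    f_s, f_t = f_s.flatten(1), f_t.flatten(1)
    d_s = torch.cdist(f_s, f_s, p=2)  # student pairwise distances within the batch
    d_t = torch.cdist(f_t, f_t, p=2)  # teacher pairwise distances within the batch
    d_s = d_s / (d_s[d_s > 0].mean() + eps)  # normalize scales so the two
    d_t = d_t / (d_t[d_t > 0].mean() + eps)  # relational geometries are comparable
    return F.smooth_l1_loss(d_s, d_t)
```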
155 | 156 | 157 | 158 |
159 | 160 | **ReviewKD** 161 | 162 | Training with ResNet50 student, ResNet152 teacher, and strategy A2 for 300 epochs: 163 | 164 | ``` 165 | python -m torch.distributed.launch --nproc_per_node=8 train-fd.py /path/to/imagenet --model resnet50 --teacher resnet152 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss review --amp --epochs 300 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.02 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.0 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.1 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 0 166 | ``` 167 |
168 | 169 | 170 | 171 |
172 | 173 | **CRD** 174 | 175 | Training with ResNet50 student, ResNet152 teacher, and strategy A2 for 300 epochs: 176 | 177 | ``` 178 | python -m torch.distributed.launch --nproc_per_node=8 train-crd.py /path/to/imagenet --model resnet50 --teacher resnet152 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss crd --amp --epochs 300 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.02 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.0 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.1 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 0 179 | 180 | ``` 181 |
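All of the commands above swap only the distillation objective; the vanilla KD baseline itself (`--kd-loss kd` in the A1/A2 strategies) is plain temperature-scaled KL between student and teacher logits. A minimal sketch, assuming it follows the standard Hinton formulation:

```
import torch.nn.functional as F

def vanilla_kd_sketch(z_s, z_t, T=1.0):
    # KL(teacher || student) on temperature-softened distributions;
    # the T^2 factor keeps gradient magnitude roughly T-independent
    return F.kl_div(F.log_softmax(z_s / T, dim=1),
                    F.softmax(z_t / T, dim=1),
                    reduction='batchmean') * (T * T)
```

With `--ori-loss-weight` and `--kd-loss-weight` left at their 1.0 defaults in the training scripts, this term is simply added to the supervised training loss.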
182 | 183 | ## Acknowledgement 184 | 185 | This repository is built using the [timm](https://github.com/rwightman/pytorch-image-models) library, [DKD](https://github.com/megvii-research/mdistiller), [DIST](https://github.com/hunto/DIST_KD), [DeiT](https://github.com/facebookresearch/deit), [BEiT v2](https://github.com/microsoft/unilm/tree/master/beit2), and [ConvNeXt v2](https://github.com/facebookresearch/ConvNeXt-V2) repositories. 186 | -------------------------------------------------------------------------------- /register.py: -------------------------------------------------------------------------------- 1 | from types import MethodType 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from timm.models.beit import Beit 7 | from timm.models.resnet import ResNet 8 | 9 | 10 | class Config: 11 | _feat_dim = { 12 | 'resnet50': ( 13 | (64, 112, 112), (256, 56, 56), (512, 28, 28), (1024, 14, 14), (2048, 7, 7), (2048, None, None)), 14 | 'resnet152': ( 15 | (64, 112, 112), (256, 56, 56), (512, 28, 28), (1024, 14, 14), (2048, 7, 7), (2048, None, None)), 16 | 'swin_small_patch4_window7_224': ( 17 | (96, 56, 56), (192, 28, 28), (384, 14, 14), (768, 7, 7), (768, 7, 7), (768, None, None)), 18 | 'swin_base_patch4_window7_224': ( 19 | (128, 56, 56), (256, 28, 28), (512, 14, 14), (1024, 7, 7), (1024, 7, 7), (1024, None, None)), 20 | 'swin_large_patch4_window7_224': ( 21 | (192, 56, 56), (384, 28, 28), (768, 14, 14), (1536, 7, 7), (1536, 7, 7), (1536, None, None)), 22 | 'beitv2_large_patch16_224': ( 23 | (64, 56, 56), (64, 56, 56), (256, 28, 28), (1024, 14, 14), (1024, 7, 7), (1024, None, None)), 24 | 'bit_r152x2': ( 25 | (128, 112, 112), (512, 56, 56), (1024, 28, 28), (2048, 14, 14), (4096, 7, 7), (4096, 1, 1)), 26 | } 27 | 28 | _kd_feat_index = { 29 | 'resnet50': (1, 2, 3, 4), 30 | 'resnet152': (1, 2, 3, 4), 31 | 'swin_small_patch4_window7_224': (0, 1, 2, 4), 32 | 'swin_base_patch4_window7_224': (0, 1, 2, 4), 33 | 'swin_large_patch4_window7_224': (0, 1, 2, 4), 34 | 'beitv2_large_patch16_224': (1, 2, 3, 4), 35 | 'bit_r152x2': (1, 2, 3, 4), 36 | } 37 | 38 | def get_pre_logit_dim(self, model): 39 | feat_sizes = self._feat_dim[model] 40 | if isinstance(feat_sizes, tuple): 41 | return feat_sizes[-1][0] 42 | else: 43 | return feat_sizes 44 | 45 | def get_used_feature_index(self, model): 46 | index = self._kd_feat_index[model] 47 | if index is None: 48 | raise NotImplementedError(f'undefined feature kd for model {model}') 49 | return index 50 | 51 | def get_feature_size_by_index(self, model, index): 52 | valid_index = self.get_used_feature_index(model) 53 | feat_sizes = self._feat_dim[model] 54 | assert index in valid_index 55 | return feat_sizes[index] 56 | 57 | 58 | config = Config() 59 | 60 | 61 | def register_forward(model): # only ResNet has implemented pre-act features 62 | if isinstance(model, ResNet): # ResNet 63 | model.forward = MethodType(ResNet_forward, model) 64 | model.forward_features = MethodType(ResNet_forward_features, model) 65 | elif isinstance(model, Beit): # Beit 66 | model.forward = MethodType(Beitv2_forward, model) 67 | model.forward_features = MethodType(Beitv2_forward_features, model) 68 | else: 69 | raise NotImplementedError('undefined forward method to get features, check the exp setting carefully!') 70 | 71 | 72 | def _unpatchify(x, p, remove_token=0): 73 | """ 74 | x: (N, L, patch_size**2 * C) 75 | imgs: (N, C, H, W) 76 | """ 77 | # p = self.patch_embed.patch_size[0] 78 | x = x[:, remove_token:, :] 79 | h = w = int(x.shape[1] ** .5) 80 | assert h * w == x.shape[1]
81 | 82 | x = x.reshape(shape=(x.shape[0], h, w, p, p, -1)) 83 | x = torch.einsum('nhwpqc->nchpwq', x) 84 | imgs = x.reshape(shape=(x.shape[0], -1, h * p, w * p)) 85 | return imgs 86 | 87 | 88 | # ResNet 89 | def bottleneck_forward(self, x): 90 | shortcut = x 91 | 92 | x = self.conv1(x) 93 | x = self.bn1(x) 94 | x = self.act1(x) 95 | 96 | x = self.conv2(x) 97 | x = self.bn2(x) 98 | x = self.drop_block(x) 99 | x = self.act2(x) 100 | x = self.aa(x) 101 | 102 | x = self.conv3(x) 103 | x = self.bn3(x) 104 | 105 | if self.se is not None: 106 | x = self.se(x) 107 | 108 | if self.drop_path is not None: 109 | x = self.drop_path(x) 110 | 111 | if self.downsample is not None: 112 | shortcut = self.downsample(shortcut) 113 | x += shortcut 114 | pre_act_x = x 115 | x = self.act3(pre_act_x) 116 | 117 | return x, pre_act_x 118 | 119 | 120 | def ResNet_forward_features(self, x, requires_feat): 121 | pre_act_feat = [] 122 | feat = [] 123 | x = self.conv1(x) 124 | x = self.bn1(x) 125 | pre_act_feat.append(x) 126 | x = self.act1(x) 127 | feat.append(x) 128 | x = self.maxpool(x) 129 | 130 | for layer in [self.layer1, self.layer2, self.layer3, self.layer4]: 131 | for bottleneck in layer: 132 | x, pre_act_x = bottleneck_forward(bottleneck, x) 133 | 134 | pre_act_feat.append(pre_act_x) 135 | feat.append(x) 136 | 137 | return (x, (pre_act_feat, feat)) if requires_feat else x 138 | 139 | 140 | def ResNet_forward(self, x, requires_feat=False): 141 | if requires_feat: 142 | x, (pre_act_feat, feat) = self.forward_features(x, requires_feat=True) 143 | x = self.forward_head(x, pre_logits=True) 144 | feat.append(x) 145 | pre_act_feat.append(x) 146 | x = self.fc(x) 147 | return x, (pre_act_feat, feat) 148 | else: 149 | x = self.forward_features(x, requires_feat=False) 150 | x = self.forward_head(x) 151 | return x 152 | 153 | 154 | def Beitv2_forward_features(self, x, requires_feat): 155 | x = self.patch_embed(x) 156 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 157 | if self.pos_embed is not None: 158 | x = x + self.pos_embed 159 | x = self.pos_drop(x) 160 | 161 | pre_act_feat = [_unpatchify(x, 4, 1)] # stem 162 | feat = [_unpatchify(x, 4, 1)] # stem 163 | 164 | rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None 165 | for i, blk in enumerate(self.blocks): # FIXME: currently for BEiT v2-L only 166 | x = blk(x, shared_rel_pos_bias=rel_pos_bias) 167 | f = None 168 | if i == 1: 169 | f = _unpatchify(x, 4, 1) 170 | elif i == 3: 171 | f = _unpatchify(x, 2, 1) 172 | elif i == 21: 173 | f = _unpatchify(x, 1, 1) 174 | elif i == 23: 175 | f = F.adaptive_avg_pool2d(_unpatchify(x, 1, 1), (7, 7)) 176 | if f is not None: 177 | pre_act_feat.append(f) 178 | feat.append(f) 179 | x = self.norm(x) 180 | return (x, (pre_act_feat, feat)) if requires_feat else x 181 | 182 | 183 | def Beitv2_forward(self, x, requires_feat=False): 184 | if requires_feat: 185 | x, (pre_act_feat, feat) = self.forward_features(x, requires_feat=True) 186 | x = self.forward_head(x, pre_logits=True) 187 | feat.append(x) 188 | pre_act_feat.append(x) 189 | x = self.head(x) 190 | return x, (pre_act_feat, feat) 191 | else: 192 | x = self.forward_features(x, requires_feat=False) 193 | x = self.forward_head(x) 194 | return x 195 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.10.2+cu113 2 | torchvision==0.11.3+cu113 3 | timm==0.6.5
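As a usage note for `register.py` above, a hedged sketch of how the patched `forward` exposes intermediate features for feature-based distillation (the printed shapes follow the `resnet50` entry in `Config._feat_dim`):

```
import torch
from timm.models import create_model

from register import config, register_forward

model = create_model('resnet50', num_classes=1000)
register_forward(model)  # patches forward/forward_features in place
model.eval()

with torch.no_grad():
    logits, (pre_act_feat, feat) = model(torch.randn(2, 3, 224, 224), requires_feat=True)

print(logits.shape)    # torch.Size([2, 1000])
print(feat[-1].shape)  # torch.Size([2, 2048]) -- the appended pre-logit embedding
print(config.get_pre_logit_dim('resnet50'))  # 2048
```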
-------------------------------------------------------------------------------- /train-crd.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ ImageNet Training Script 3 | 4 | This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet 5 | training results with some of the latest networks and training techniques. It favours canonical PyTorch 6 | and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed 7 | and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit. 8 | 9 | This script was started from an early version of the PyTorch ImageNet example 10 | (https://github.com/pytorch/examples/tree/master/imagenet) 11 | 12 | NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples 13 | (https://github.com/NVIDIA/apex/tree/master/examples/imagenet) 14 | 15 | Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) 16 | 17 | Modifications by Zhiwei Hao (haozhw@bit.edu.cn) and Jianyuan Guo (jianyuan_guo@outlook.com) 18 | """ 19 | import argparse 20 | import logging 21 | import os 22 | import time 23 | from collections import OrderedDict 24 | from contextlib import suppress 25 | from datetime import datetime 26 | 27 | import numpy as np 28 | import torch 29 | import torch.nn as nn 30 | import torchvision.utils 31 | import yaml 32 | from torch.nn.parallel import DistributedDataParallel as NativeDDP 33 | from torchvision import transforms 34 | 35 | from losses import CRD 36 | from register import config, register_forward 37 | from timm.data import AugMixDataset, create_dataset, create_loader, FastCollateMixup, Mixup, \ 38 | resolve_data_config 39 | from timm.loss import * 40 | from timm.models import convert_splitbn_model, create_model, load_checkpoint, model_parameters, resume_checkpoint, \ 41 | safe_model_name 42 | from timm.optim import create_optimizer_v2, optimizer_kwargs 43 | from timm.scheduler import create_scheduler 44 | from timm.utils import * 45 | from timm.utils import ApexScaler, NativeScaler 46 | from utils import ImageNetInstanceSample, process_feat, setup_default_logging, TimePredictor 47 | import models 48 | 49 | try: 50 | from apex import amp 51 | from apex.parallel import DistributedDataParallel as ApexDDP 52 | from apex.parallel import convert_syncbn_model 53 | 54 | has_apex = True 55 | except ImportError: 56 | has_apex = False 57 | 58 | has_native_amp = False 59 | try: 60 | if getattr(torch.cuda.amp, 'autocast') is not None: 61 | has_native_amp = True 62 | except AttributeError: 63 | pass 64 | 65 | try: 66 | import wandb 67 | 68 | has_wandb = True 69 | except ImportError: 70 | has_wandb = False 71 | 72 | torch.backends.cudnn.benchmark = True 73 | _logger = logging.getLogger('train') 74 | 75 | # The first arg parser parses out only the --config argument, this argument is used to 76 | # load a yaml file containing key-values that override the defaults for the main parser below 77 | config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False) 78 | parser.add_argument('-c', '--config', default='', type=str, metavar='FILE', 79 | help='YAML config file specifying default arguments') 80 | 81 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 82 | 83 | # -------------------------------------- Modified --------------------------------------- 84 | # KD parameters: DIST, DKD, and vanilla KD 85 | 
parser.add_argument('--kd-loss', default='crd', type=str) 86 | parser.add_argument('--teacher', default='beitv2_base_patch16_224', type=str) 87 | parser.add_argument('--teacher-pretrained', default=None, type=str) # teacher checkpoint path 88 | parser.add_argument('--ori-loss-weight', default=1., type=float) 89 | parser.add_argument('--kd-loss-weight', default=1., type=float) 90 | parser.add_argument('--teacher-resize', default=None, type=int) 91 | parser.add_argument('--student-resize', default=None, type=int) 92 | parser.add_argument('--input-size', default=None, nargs=3, type=int) 93 | 94 | # use torch.cuda.empty_cache() to save GPU memory 95 | parser.add_argument('--economic', action='store_true') 96 | 97 | # eval every 'eval-interval' epochs up to epochs * eval_interval_end, then every epoch 98 | parser.add_argument('--eval-interval', type=int, default=1) 99 | parser.add_argument('--eval-interval-end', type=float, default=0.75) 100 | # --------------------------------------------------------------------------------------- 101 | 102 | # Dataset parameters 103 | parser.add_argument('data_dir', metavar='DIR', 104 | help='path to dataset') 105 | parser.add_argument('--dataset', '-d', metavar='NAME', default='', 106 | help='dataset type (default: ImageFolder/ImageTar if empty)') 107 | parser.add_argument('--train-split', metavar='NAME', default='train', 108 | help='dataset train split (default: train)') 109 | parser.add_argument('--val-split', metavar='NAME', default='validation', 110 | help='dataset validation split (default: validation)') 111 | parser.add_argument('--dataset-download', action='store_true', default=False, 112 | help='Allow download of dataset for torch/ and tfds/ datasets that support it.') 113 | parser.add_argument('--class-map', default='', type=str, metavar='FILENAME', 114 | help='path to class to idx mapping file (default: "")') 115 | 116 | # Model parameters 117 | parser.add_argument('--model', default='resnet50', type=str, metavar='MODEL', 118 | help='Name of model to train (default: "resnet50")') 119 | parser.add_argument('--pretrained', action='store_true', default=False, 120 | help='Start with pretrained version of specified network (if avail)') 121 | parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', 122 | help='Initialize model from this checkpoint (default: none)') 123 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 124 | help='Resume full model and optimizer state from checkpoint (default: none)') 125 | parser.add_argument('--no-resume-opt', action='store_true', default=False, 126 | help='prevent resume of optimizer state when resuming model') 127 | parser.add_argument('--num-classes', type=int, default=1000, metavar='N', 128 | help='number of label classes (Model default if None)') 129 | parser.add_argument('--gp', default=None, type=str, metavar='POOL', 130 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc).
Model default if None.') 131 | parser.add_argument('--img-size', type=int, default=None, metavar='N', 132 | help='Image patch size (default: None => model default)') 133 | parser.add_argument('--crop-pct', default=None, type=float, 134 | metavar='N', help='Input image center crop percent (for validation only)') 135 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', 136 | help='Override mean pixel value of dataset') 137 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', 138 | help='Override std deviation of dataset') 139 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME', 140 | help='Image resize interpolation type (overrides model)') 141 | parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N', 142 | help='Input batch size for training (default: 128)') 143 | parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N', 144 | help='Validation batch size override (default: None)') 145 | parser.add_argument('--channels-last', action='store_true', default=False, 146 | help='Use channels_last memory layout') 147 | parser.add_argument('--torchscript', dest='torchscript', action='store_true', 148 | help='torch.jit.script the full model') 149 | parser.add_argument('--fuser', default='', type=str, 150 | help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") 151 | parser.add_argument('--grad-checkpointing', action='store_true', default=False, 152 | help='Enable gradient checkpointing through model blocks/stages') 153 | 154 | # Optimizer parameters 155 | parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', 156 | help='Optimizer (default: "sgd")') 157 | parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', 158 | help='Optimizer Epsilon (default: None, use opt default)') 159 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 160 | help='Optimizer Betas (default: None, use opt default)') 161 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 162 | help='Optimizer momentum (default: 0.9)') 163 | parser.add_argument('--weight-decay', type=float, default=2e-5, 164 | help='weight decay (default: 2e-5)') 165 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 166 | help='Clip gradient norm (default: None, no clipping)') 167 | parser.add_argument('--clip-mode', type=str, default='norm', 168 | help='Gradient clipping mode.
One of ("norm", "value", "agc")') 169 | parser.add_argument('--layer-decay', type=float, default=None, 170 | help='layer-wise learning rate decay (default: None)') 171 | 172 | # Learning rate schedule parameters 173 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 174 | help='LR scheduler (default: "step"') 175 | parser.add_argument('--lr', type=float, default=0.05, metavar='LR', 176 | help='learning rate (default: 0.05)') 177 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 178 | help='learning rate noise on/off epoch percentages') 179 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 180 | help='learning rate noise limit percent (default: 0.67)') 181 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 182 | help='learning rate noise std-dev (default: 1.0)') 183 | parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT', 184 | help='learning rate cycle len multiplier (default: 1.0)') 185 | parser.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT', 186 | help='amount to decay each learning rate cycle (default: 0.5)') 187 | parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', 188 | help='learning rate cycle limit, cycles enabled if > 1') 189 | parser.add_argument('--lr-k-decay', type=float, default=1.0, 190 | help='learning rate k-decay for cosine/poly (default: 1.0)') 191 | parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', 192 | help='warmup learning rate (default: 0.0001)') 193 | parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR', 194 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 195 | parser.add_argument('--epochs', type=int, default=300, metavar='N', 196 | help='number of epochs to train (default: 300)') 197 | parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N', 198 | help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).') 199 | parser.add_argument('--start-epoch', default=None, type=int, metavar='N', 200 | help='manual epoch number (useful on restarts)') 201 | parser.add_argument('--decay-epochs', type=float, default=100, metavar='N', 202 | help='epoch interval to decay LR') 203 | parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N', 204 | help='epochs to warmup LR, if scheduler supports') 205 | parser.add_argument('--cooldown-epochs', type=int, default=0, metavar='N', 206 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 207 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 208 | help='patience epochs for Plateau LR scheduler (default: 10') 209 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 210 | help='LR decay rate (default: 0.1)') 211 | 212 | # Augmentation & regularization parameters 213 | parser.add_argument('--no-aug', action='store_true', default=False, 214 | help='Disable all training augmentation, override other train aug args') 215 | parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', 216 | help='Random resize scale (default: 0.08 1.0)') 217 | parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. 
/ 3.], metavar='RATIO', 218 | help='Random resize aspect ratio (default: 0.75 1.33)') 219 | parser.add_argument('--hflip', type=float, default=0.5, 220 | help='Horizontal flip training aug probability') 221 | parser.add_argument('--vflip', type=float, default=0., 222 | help='Vertical flip training aug probability') 223 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 224 | help='Color jitter factor (default: 0.4)') 225 | parser.add_argument('--aa', type=str, default=None, metavar='NAME', 226 | help='Use AutoAugment policy. "v0" or "original". (default: None)'), 227 | parser.add_argument('--aug-repeats', type=float, default=0, 228 | help='Number of augmentation repetitions (distributed training only) (default: 0)') 229 | parser.add_argument('--aug-splits', type=int, default=0, 230 | help='Number of augmentation splits (default: 0, valid: 0 or >=2)') 231 | parser.add_argument('--jsd-loss', action='store_true', default=False, 232 | help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') 233 | parser.add_argument('--bce-loss', action='store_true', default=False, 234 | help='Enable BCE loss w/ Mixup/CutMix use.') 235 | parser.add_argument('--bce-target-thresh', type=float, default=None, 236 | help='Threshold for binarizing softened BCE targets (default: None, disabled)') 237 | parser.add_argument('--reprob', type=float, default=0., metavar='PCT', 238 | help='Random erase prob (default: 0.)') 239 | parser.add_argument('--remode', type=str, default='pixel', 240 | help='Random erase mode (default: "pixel")') 241 | parser.add_argument('--recount', type=int, default=1, 242 | help='Random erase count (default: 1)') 243 | parser.add_argument('--resplit', action='store_true', default=False, 244 | help='Do not random erase first (clean) augmentation split') 245 | parser.add_argument('--mixup', type=float, default=0.0, 246 | help='mixup alpha, mixup enabled if > 0. (default: 0.)') 247 | parser.add_argument('--cutmix', type=float, default=0.0, 248 | help='cutmix alpha, cutmix enabled if > 0. (default: 0.)') 249 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 250 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 251 | parser.add_argument('--mixup-prob', type=float, default=1.0, 252 | help='Probability of performing mixup or cutmix when either/both is enabled') 253 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 254 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 255 | parser.add_argument('--mixup-mode', type=str, default='batch', 256 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 257 | parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', 258 | help='Turn off mixup after this epoch, disabled if 0 (default: 0)') 259 | parser.add_argument('--smoothing', type=float, default=0.1, 260 | help='Label smoothing (default: 0.1)') 261 | parser.add_argument('--train-interpolation', type=str, default='random', 262 | help='Training interpolation (random, bilinear, bicubic default: "random")') 263 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 264 | help='Dropout rate (default: 0.)') 265 | parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT', 266 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)') 267 | parser.add_argument('--drop-path', type=float, default=None, metavar='PCT', 268 | help='Drop path rate (default: None)') 269 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', 270 | help='Drop block rate (default: None)') 271 | 272 | # Batch norm parameters (only works with gen_efficientnet based models currently) 273 | parser.add_argument('--bn-momentum', type=float, default=None, 274 | help='BatchNorm momentum override (if not None)') 275 | parser.add_argument('--bn-eps', type=float, default=None, 276 | help='BatchNorm epsilon override (if not None)') 277 | parser.add_argument('--sync-bn', action='store_true', 278 | help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') 279 | parser.add_argument('--dist-bn', type=str, default='reduce', 280 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') 281 | parser.add_argument('--split-bn', action='store_true', 282 | help='Enable separate BN layers per augmentation split.') 283 | 284 | # Model Exponential Moving Average 285 | parser.add_argument('--model-ema', action='store_true', default=False, 286 | help='Enable tracking moving average of model weights') 287 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, 288 | help='Force ema to be tracked on CPU, rank=0 node only. 
Disables EMA validation.') 289 | parser.add_argument('--model-ema-decay', type=float, default=0.9998, 290 | help='decay factor for model weights moving average (default: 0.9998)') 291 | 292 | # Misc 293 | parser.add_argument('--seed', type=int, default=42, metavar='S', 294 | help='random seed (default: 42)') 295 | parser.add_argument('--worker-seeding', type=str, default='all', 296 | help='worker seed mode (default: all)') 297 | parser.add_argument('--log-interval', type=int, default=200, metavar='N', 298 | help='how many batches to wait before logging training status') 299 | parser.add_argument('--recovery-interval', type=int, default=0, metavar='N', 300 | help='how many batches to wait before writing recovery checkpoint') 301 | parser.add_argument('--checkpoint-hist', type=int, default=10, metavar='N', 302 | help='number of checkpoints to keep (default: 10)') 303 | parser.add_argument('-j', '--workers', type=int, default=8, metavar='N', 304 | help='how many training processes to use (default: 8)') 305 | parser.add_argument('--save-images', action='store_true', default=False, 306 | help='save images of input batches every log interval for debugging') 307 | parser.add_argument('--amp', action='store_true', default=False, 308 | help='use NVIDIA Apex AMP or Native AMP for mixed precision training') 309 | parser.add_argument('--apex-amp', action='store_true', default=False, 310 | help='Use NVIDIA Apex AMP mixed precision') 311 | parser.add_argument('--native-amp', action='store_true', default=False, 312 | help='Use Native Torch AMP mixed precision') 313 | parser.add_argument('--no-ddp-bb', action='store_true', default=False, 314 | help='Force broadcast buffers for native DDP to off.') 315 | parser.add_argument('--pin-mem', action='store_true', default=False, 316 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 317 | parser.add_argument('--no-prefetcher', action='store_true', default=False, 318 | help='disable fast prefetcher') 319 | parser.set_defaults(no_prefetcher=True) 320 | parser.add_argument('--output', default='', type=str, metavar='PATH', 321 | help='path to output folder (default: none, current dir)') 322 | parser.add_argument('--experiment', default='', type=str, metavar='NAME', 323 | help='name of train experiment, name of sub-folder for output') 324 | parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC', 325 | help='Best metric (default: "top1")') 326 | parser.add_argument('--tta', type=int, default=0, metavar='N', 327 | help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') 328 | parser.add_argument("--local_rank", default=0, type=int) 329 | parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False, 330 | help='use the multi-epochs-loader to save time at the beginning of every epoch') 331 | parser.add_argument('--log-wandb', action='store_true', default=False, 332 | help='log training and validation metrics to wandb') 333 | 334 | 335 | def _parse_args(): 336 | # Do we have a config file to parse? 337 | args_config, remaining = config_parser.parse_known_args() 338 | if args_config.config: 339 | with open(args_config.config, 'r') as f: 340 | cfg = yaml.safe_load(f) 341 | parser.set_defaults(**cfg) 342 | 343 | # The main arg parser parses the rest of the args, the usual 344 | # defaults will have been overridden if config file specified.
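    # (Added, illustrative) example YAML for `-c/--config`; keys mirror the argparse
    # destinations above and override their defaults before the final parse below:
    #
    #     model: resnet50
    #     teacher: beitv2_base_patch16_224
    #     kd_loss: crd
    #     batch_size: 256
    #     amp: true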
345 | args = parser.parse_args(remaining) 346 | 347 | # Cache the args as a text string to save them in the output dir later 348 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) 349 | return args, args_text 350 | 351 | 352 | def main(): 353 | setup_default_logging(_logger, log_path='train.log') 354 | args, args_text = _parse_args() 355 | 356 | if args.log_wandb: 357 | if has_wandb: 358 | wandb.init(project=args.experiment, config=args) 359 | else: 360 | _logger.warning("You've requested to log metrics to wandb but package not found. " 361 | "Metrics not being logged to wandb, try `pip install wandb`") 362 | 363 | args.prefetcher = not args.no_prefetcher 364 | args.distributed = False 365 | if 'WORLD_SIZE' in os.environ: 366 | args.distributed = int(os.environ['WORLD_SIZE']) > 1 367 | args.device = 'cuda:0' 368 | args.world_size = 1 369 | args.rank = 0 # global rank 370 | if args.distributed: 371 | args.device = 'cuda:%d' % args.local_rank 372 | torch.cuda.set_device(args.local_rank) 373 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 374 | args.world_size = torch.distributed.get_world_size() 375 | args.rank = torch.distributed.get_rank() 376 | _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' 377 | % (args.rank, args.world_size)) 378 | else: 379 | _logger.info('Training with a single process on 1 GPU.') 380 | assert args.rank >= 0 381 | 382 | # resolve AMP arguments based on PyTorch / Apex availability 383 | use_amp = None 384 | if args.amp: 385 | # `--amp` chooses native amp before apex (APEX ver not actively maintained) 386 | if has_native_amp: 387 | args.native_amp = True 388 | elif has_apex: 389 | args.apex_amp = True 390 | if args.apex_amp and has_apex: 391 | use_amp = 'apex' 392 | elif args.native_amp and has_native_amp: 393 | use_amp = 'native' 394 | elif args.apex_amp or args.native_amp: 395 | _logger.warning("Neither APEX nor native Torch AMP is available, using float32. " 396 | "Install NVIDIA apex or upgrade to PyTorch 1.6") 397 | 398 | random_seed(args.seed, args.rank) 399 | 400 | if args.fuser: 401 | set_jit_fuser(args.fuser) 402 | 403 | # ------------------------------------ Modified ------------------------------------- 404 | teacher = create_model( 405 | args.teacher, 406 | checkpoint_path=args.teacher_pretrained, 407 | num_classes=args.num_classes) 408 | register_forward(teacher) 409 | teacher = teacher.cuda() 410 | teacher.eval() 411 | # ------------------------------------------------------------------------------------ 412 | 413 | model = create_model( 414 | args.model, 415 | pretrained=args.pretrained, 416 | num_classes=args.num_classes, 417 | drop_rate=args.drop, 418 | drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path 419 | drop_path_rate=args.drop_path, 420 | drop_block_rate=args.drop_block, 421 | global_pool=args.gp, 422 | bn_momentum=args.bn_momentum, 423 | bn_eps=args.bn_eps, 424 | scriptable=args.torchscript, 425 | checkpoint_path=args.initial_checkpoint) 426 | register_forward(model) 427 | 428 | if args.num_classes is None: 429 | assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
430 | args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly 431 | 432 | if args.grad_checkpointing: 433 | model.set_grad_checkpointing(enable=True) 434 | 435 | if args.local_rank == 0: 436 | _logger.info( 437 | f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}') 438 | 439 | data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) 440 | 441 | # setup augmentation batch splits for contrastive loss or split bn 442 | num_aug_splits = 0 443 | if args.aug_splits > 0: 444 | assert args.aug_splits > 1, 'A split of 1 makes no sense' 445 | num_aug_splits = args.aug_splits 446 | 447 | # enable split bn (separate bn stats per batch-portion) 448 | if args.split_bn: 449 | assert num_aug_splits > 1 or args.resplit 450 | model = convert_splitbn_model(model, max(num_aug_splits, 2)) 451 | 452 | # create the train and eval datasets 453 | dataset_train = ImageNetInstanceSample(root=f'{args.data_dir}/train', name=args.dataset, class_map=args.class_map, 454 | load_bytes=False, is_sample=True, k=16384) 455 | dataset_eval = create_dataset( 456 | args.dataset, root=args.data_dir, split=args.val_split, is_training=False, 457 | class_map=args.class_map, 458 | download=args.dataset_download, 459 | batch_size=args.batch_size) 460 | 461 | # move model to GPU, enable channels last layout if set 462 | if args.kd_loss == 'crd': 463 | kd_loss_fn = CRD(feat_s_channel=config.get_pre_logit_dim(args.model), 464 | feat_t_channel=config.get_pre_logit_dim(args.teacher), 465 | feat_dim=128, num_data=len(dataset_train), k=16384, 466 | momentum=0.5, temperature=0.07) 467 | requires_feat = True 468 | else: 469 | raise NotImplementedError(f'this script only supports crd loss. for other kd loss, please refer to train-kd.py') 470 | 471 | model.kd_loss_fn = kd_loss_fn 472 | model.cuda() 473 | 474 | if args.channels_last: 475 | model = model.to(memory_format=torch.channels_last) 476 | 477 | # setup synchronized BatchNorm for distributed training 478 | if args.distributed and args.sync_bn: 479 | assert not args.split_bn 480 | if has_apex and use_amp == 'apex': 481 | # Apex SyncBN preferred unless native amp is activated 482 | model = convert_syncbn_model(model) 483 | else: 484 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 485 | if args.local_rank == 0: 486 | _logger.info( 487 | 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 488 | 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') 489 | 490 | if args.torchscript: 491 | assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' 492 | assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' 493 | model = torch.jit.script(model) 494 | 495 | optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args)) 496 | 497 | # setup automatic mixed-precision (AMP) loss scaling and op casting 498 | amp_autocast = suppress # do nothing 499 | loss_scaler = None 500 | if use_amp == 'apex': 501 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1') 502 | loss_scaler = ApexScaler() 503 | if args.local_rank == 0: 504 | _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') 505 | elif use_amp == 'native': 506 | amp_autocast = torch.cuda.amp.autocast 507 | loss_scaler = NativeScaler() 508 | if args.local_rank == 0: 509 | _logger.info('Using native Torch AMP. 
Training in mixed precision.') 510 | else: 511 | if args.local_rank == 0: 512 | _logger.info('AMP not enabled. Training in float32.') 513 | 514 | # optionally resume from a checkpoint 515 | resume_epoch = None 516 | if args.resume: 517 | resume_epoch = resume_checkpoint( 518 | model, args.resume, 519 | optimizer=None if args.no_resume_opt else optimizer, 520 | loss_scaler=None if args.no_resume_opt else loss_scaler, 521 | log_info=args.local_rank == 0) 522 | 523 | # setup exponential moving average of model weights, SWA could be used here too 524 | model_ema = None 525 | if args.model_ema: 526 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper 527 | model_ema = ModelEmaV2( 528 | model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None) 529 | if args.resume: 530 | load_checkpoint(model_ema.module, args.resume, use_ema=True) 531 | 532 | # setup distributed training 533 | if args.distributed: 534 | if has_apex and use_amp == 'apex': 535 | # Apex DDP preferred unless native amp is activated 536 | if args.local_rank == 0: 537 | _logger.info("Using NVIDIA APEX DistributedDataParallel.") 538 | model = ApexDDP(model, delay_allreduce=True) 539 | else: 540 | if args.local_rank == 0: 541 | _logger.info("Using native Torch DistributedDataParallel.") 542 | model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb) 543 | # NOTE: EMA model does not need to be wrapped by DDP 544 | 545 | # setup learning rate schedule and starting epoch 546 | lr_scheduler, num_epochs = create_scheduler(args, optimizer) 547 | start_epoch = 0 548 | if args.start_epoch is not None: 549 | # a specified start_epoch will always override the resume epoch 550 | start_epoch = args.start_epoch 551 | elif resume_epoch is not None: 552 | start_epoch = resume_epoch 553 | if lr_scheduler is not None and start_epoch > 0: 554 | lr_scheduler.step(start_epoch) 555 | 556 | if args.local_rank == 0: 557 | _logger.info('Scheduled epochs: {}'.format(num_epochs)) 558 | 559 | # setup mixup / cutmix 560 | collate_fn = None 561 | mixup_fn = None 562 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 563 | if mixup_active: 564 | mixup_args = dict( 565 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 566 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 567 | label_smoothing=args.smoothing, num_classes=args.num_classes) 568 | if args.prefetcher: 569 | assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) 570 | collate_fn = FastCollateMixup(**mixup_args) 571 | else: 572 | mixup_fn = Mixup(**mixup_args) 573 | 574 | # wrap dataset in AugMix helper 575 | if num_aug_splits > 1: 576 | dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) 577 | 578 | # create data loaders w/ augmentation pipeiine 579 | train_interpolation = args.train_interpolation 580 | if args.no_aug or not train_interpolation: 581 | train_interpolation = data_config['interpolation'] 582 | loader_train = create_loader( 583 | dataset_train, 584 | input_size=data_config['input_size'], 585 | batch_size=args.batch_size, 586 | is_training=True, 587 | use_prefetcher=args.prefetcher, 588 | no_aug=args.no_aug, 589 | re_prob=args.reprob, 590 | re_mode=args.remode, 591 | re_count=args.recount, 592 | re_split=args.resplit, 593 | scale=args.scale, 594 | ratio=args.ratio, 595 | hflip=args.hflip, 596 | vflip=args.vflip, 597 | color_jitter=args.color_jitter, 598 | auto_augment=args.aa, 599 | num_aug_repeats=args.aug_repeats, 600 | num_aug_splits=num_aug_splits, 601 | interpolation=train_interpolation, 602 | mean=data_config['mean'], 603 | std=data_config['std'], 604 | num_workers=args.workers, 605 | distributed=args.distributed, 606 | collate_fn=collate_fn, 607 | pin_memory=args.pin_mem, 608 | use_multi_epochs_loader=args.use_multi_epochs_loader, 609 | worker_seeding=args.worker_seeding, 610 | ) 611 | 612 | loader_eval = create_loader( 613 | dataset_eval, 614 | input_size=(3, 224, 224), 615 | batch_size=args.validation_batch_size or args.batch_size, 616 | is_training=False, 617 | use_prefetcher=args.prefetcher, 618 | interpolation=data_config['interpolation'], 619 | mean=data_config['mean'], 620 | std=data_config['std'], 621 | num_workers=args.workers, 622 | distributed=args.distributed, 623 | crop_pct=data_config['crop_pct'], 624 | pin_memory=args.pin_mem, 625 | ) 626 | 627 | # setup loss function 628 | if args.jsd_loss: 629 | assert num_aug_splits > 1 # JSD only valid with aug splits set 630 | train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing) 631 | elif mixup_active: 632 | # smoothing is handled with mixup target transform which outputs sparse, soft targets 633 | if args.bce_loss: 634 | train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh) 635 | else: 636 | train_loss_fn = SoftTargetCrossEntropy() 637 | elif args.smoothing: 638 | if args.bce_loss: 639 | train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh) 640 | else: 641 | train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 642 | else: 643 | train_loss_fn = nn.CrossEntropyLoss() 644 | train_loss_fn = train_loss_fn.cuda() 645 | validate_loss_fn = nn.CrossEntropyLoss().cuda() 646 | 647 | teacher_resizer = student_resizer = None 648 | if args.teacher_resize is not None: 649 | teacher_resizer = transforms.Resize(args.teacher_resize).cuda() 650 | if args.student_resize is not None: 651 | student_resizer = transforms.Resize(args.student_resize).cuda() 652 | 653 | # setup checkpoint saver and eval metric tracking 
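For orientation before the code below: timm's CheckpointSaver ranks saved checkpoints by the chosen eval metric and prunes beyond max_history, and its save_checkpoint method returns the best (metric, epoch) pair seen so far, which is how best_metric/best_epoch are tracked in the epoch loop. A minimal sketch of the call pattern, with illustrative values rather than this script's actual configuration:

from timm.utils import CheckpointSaver

# decreasing=True would rank by ascending loss instead of descending top-1
saver = CheckpointSaver(model=model, optimizer=optimizer, args=args,
                        checkpoint_dir='./output/checkpoint',
                        decreasing=False, max_history=10)
best_metric, best_epoch = saver.save_checkpoint(epoch, metric=eval_metrics['top1'])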
654 | eval_metric = args.eval_metric 655 | best_metric = None 656 | best_epoch = None 657 | saver = None 658 | ema_saver = None 659 | output_dir = None 660 | if args.rank == 0: 661 | if args.experiment: 662 | exp_name = args.experiment 663 | else: 664 | exp_name = '-'.join([ 665 | datetime.now().strftime("%Y%m%d-%H%M%S"), 666 | safe_model_name(args.model), 667 | str(data_config['input_size'][-1]) 668 | ]) 669 | output_dir = get_outdir(args.output if args.output else './output/train', exp_name) 670 | decreasing = True if eval_metric == 'loss' else False 671 | saver_dir = os.path.join(output_dir, 'checkpoint') 672 | os.makedirs(saver_dir) 673 | saver = CheckpointSaver( 674 | model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, 675 | checkpoint_dir=saver_dir, recovery_dir=saver_dir, decreasing=decreasing, 676 | max_history=args.checkpoint_hist) 677 | if model_ema is not None: 678 | ema_saver_dir = os.path.join(output_dir, 'ema_checkpoint') 679 | os.makedirs(ema_saver_dir) 680 | ema_saver = CheckpointSaver( 681 | model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, 682 | checkpoint_dir=ema_saver_dir, recovery_dir=ema_saver_dir, decreasing=decreasing, 683 | max_history=args.checkpoint_hist) 684 | with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: 685 | f.write(args_text) 686 | 687 | try: 688 | tp = TimePredictor(num_epochs - start_epoch) 689 | for epoch in range(start_epoch, num_epochs): 690 | if args.distributed and hasattr(loader_train.sampler, 'set_epoch'): 691 | loader_train.sampler.set_epoch(epoch) 692 | 693 | train_metrics = train_one_epoch( 694 | epoch, model, teacher, loader_train, optimizer, train_loss_fn, args, 695 | requires_feat=requires_feat, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, 696 | amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn, 697 | teacher_resizer=teacher_resizer, student_resizer=student_resizer) 698 | 699 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'): 700 | if args.local_rank == 0: 701 | _logger.info("Distributing BatchNorm running means and vars") 702 | distribute_bn(model, args.world_size, args.dist_bn == 'reduce') 703 | 704 | is_eval = epoch > int(args.eval_interval_end * args.epochs) or epoch % args.eval_interval == 0 705 | if is_eval: 706 | eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) 707 | 708 | if saver is not None: 709 | # save proper checkpoint with eval metric 710 | save_metric = eval_metrics[eval_metric] 711 | best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric) 712 | 713 | if model_ema is not None and not args.model_ema_force_cpu: 714 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'): 715 | distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') 716 | ema_eval_metrics = validate( 717 | model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, 718 | log_suffix=' (EMA)') 719 | eval_metrics = ema_eval_metrics 720 | 721 | if ema_saver is not None: 722 | # save proper checkpoint with eval metric 723 | save_metric = eval_metrics[eval_metric] 724 | best_metric, best_epoch = ema_saver.save_checkpoint(epoch, metric=save_metric) 725 | 726 | if output_dir is not None: 727 | update_summary( 728 | epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), 729 | write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb) 730 | 731 | metrics = 
eval_metrics[eval_metric] 732 | else: 733 | metrics = None 734 | 735 | if lr_scheduler is not None: 736 | # step LR for next epoch 737 | lr_scheduler.step(epoch + 1, metrics) 738 | 739 | tp.update() 740 | if args.rank == 0: 741 | print(f'Will finish at {tp.get_pred_text()}') 742 | print(f'Avg running time of latest {len(tp.time_list)} epochs: {np.mean(tp.time_list):.2f}s/ep.') 743 | 744 | except KeyboardInterrupt: 745 | pass 746 | if best_metric is not None: 747 | _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) 748 | 749 | if args.rank == 0: 750 | os.system(f'mv train.log {output_dir}') 751 | 752 | 753 | def train_one_epoch( 754 | epoch, model, teacher, loader, optimizer, loss_fn, args, requires_feat=False, 755 | lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress, loss_scaler=None, 756 | model_ema=None, mixup_fn=None, teacher_resizer=None, student_resizer=None): 757 | if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: 758 | if args.prefetcher and loader.mixup_enabled: 759 | loader.mixup_enabled = False 760 | elif mixup_fn is not None: 761 | mixup_fn.mixup_enabled = False 762 | 763 | second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 764 | batch_time_m = AverageMeter() 765 | data_time_m = AverageMeter() 766 | losses_m = AverageMeter() 767 | losses_ori_m = AverageMeter() 768 | losses_kd_m = AverageMeter() 769 | 770 | model.train() 771 | 772 | end = time.time() 773 | last_idx = len(loader) - 1 774 | num_updates = epoch * len(loader) 775 | for batch_idx, (input, target, index, contrastive_index) in enumerate(loader): 776 | last_batch = batch_idx == last_idx 777 | data_time_m.update(time.time() - end) 778 | if not args.prefetcher: 779 | input, target = input.cuda(), target.cuda() 780 | if mixup_fn is not None: 781 | input, target = mixup_fn(input, target) 782 | if args.channels_last: 783 | input = input.contiguous(memory_format=torch.channels_last) 784 | index = index.cuda() 785 | contrastive_index = contrastive_index.cuda() 786 | 787 | with amp_autocast(): 788 | 789 | # --------------------------------- Modified -------------------------------- 790 | if student_resizer is not None: 791 | student_input = student_resizer(input) 792 | else: 793 | student_input = input 794 | if teacher_resizer is not None: 795 | teacher_input = teacher_resizer(input) 796 | else: 797 | teacher_input = input 798 | 799 | if args.economic: 800 | torch.cuda.empty_cache() 801 | 802 | with torch.no_grad(): 803 | output_t, feat_t = teacher(teacher_input, requires_feat=requires_feat) 804 | 805 | if args.economic: 806 | torch.cuda.empty_cache() 807 | 808 | output, feat = model(student_input, requires_feat=requires_feat) 809 | 810 | loss_ori = args.ori_loss_weight * loss_fn(output, target) 811 | 812 | try: 813 | kd_loss_fn = model.module.kd_loss_fn 814 | except AttributeError: 815 | kd_loss_fn = model.kd_loss_fn 816 | loss_kd = args.kd_loss_weight * kd_loss_fn(z_s=output, z_t=output_t, 817 | feature_student=process_feat(kd_loss_fn, feat), 818 | feature_teacher=process_feat(kd_loss_fn, feat_t), 819 | index=index, contrastive_index=contrastive_index) 820 | loss = loss_ori + loss_kd 821 | # ---------------------------------------------------------------------------- 822 | 823 | if not args.distributed: 824 | # --------------------------------- Modified -------------------------------- 825 | losses_m.update(loss.item(), input.size(0)) 826 | losses_ori_m.update(loss_ori.item(), input.size(0)) 827 | losses_kd_m.update(loss_kd.item(), 
input.size(0)) 828 | # ---------------------------------------------------------------------------- 829 | 830 | optimizer.zero_grad() 831 | if loss_scaler is not None: 832 | loss_scaler( 833 | loss, optimizer, 834 | clip_grad=args.clip_grad, clip_mode=args.clip_mode, 835 | parameters=model_parameters(model, exclude_head='agc' in args.clip_mode), 836 | create_graph=second_order) 837 | else: 838 | loss.backward(create_graph=second_order) 839 | if args.clip_grad is not None: 840 | dispatch_clip_grad( 841 | model_parameters(model, exclude_head='agc' in args.clip_mode), 842 | value=args.clip_grad, mode=args.clip_mode) 843 | optimizer.step() 844 | 845 | if model_ema is not None: 846 | model_ema.update(model) 847 | 848 | torch.cuda.synchronize() 849 | num_updates += 1 850 | batch_time_m.update(time.time() - end) 851 | if last_batch or batch_idx % args.log_interval == 0: 852 | lrl = [param_group['lr'] for param_group in optimizer.param_groups] 853 | lr = sum(lrl) / len(lrl) 854 | 855 | # ----------------------------------- Modified ---------------------------------- 856 | if args.distributed: 857 | reduced_loss = reduce_tensor(loss.data, args.world_size) 858 | reduced_loss_ori = reduce_tensor(loss_ori.data, args.world_size) 859 | reduced_loss_kd = reduce_tensor(loss_kd.data, args.world_size) 860 | losses_m.update(reduced_loss.item(), input.size(0)) 861 | losses_ori_m.update(reduced_loss_ori.item(), input.size(0)) 862 | losses_kd_m.update(reduced_loss_kd.item(), input.size(0)) 863 | 864 | if args.local_rank == 0: 865 | _logger.info( 866 | 'Train: {} [{:>4d}/{} ({:>3.0f}%)] ' 867 | 'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) ' 868 | 'Loss_ori: {loss_ori.val:#.4g} ({loss_ori.avg:#.3g}) ' 869 | 'Loss_kd: {loss_kd.val:#.4g} ({loss_kd.avg:#.3g}) ' 870 | 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ' 871 | '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 872 | 'LR: {lr:.3e} ' 873 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( 874 | epoch, 875 | batch_idx, len(loader), 876 | 100. 
* batch_idx / last_idx, 877 | loss=losses_m, 878 | loss_ori=losses_ori_m, 879 | loss_kd=losses_kd_m, 880 | batch_time=batch_time_m, 881 | rate=input.size(0) * args.world_size / batch_time_m.val, 882 | rate_avg=input.size(0) * args.world_size / batch_time_m.avg, 883 | lr=lr, 884 | data_time=data_time_m)) 885 | # -------------------------------------------------------------------------------- 886 | 887 | if args.save_images and output_dir: 888 | torchvision.utils.save_image( 889 | input, 890 | os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx), 891 | padding=0, 892 | normalize=True) 893 | 894 | if saver is not None and args.recovery_interval and ( 895 | last_batch or (batch_idx + 1) % args.recovery_interval == 0): 896 | saver.save_recovery(epoch, batch_idx=batch_idx) 897 | 898 | if lr_scheduler is not None: 899 | lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) 900 | 901 | end = time.time() 902 | # end for 903 | 904 | if hasattr(optimizer, 'sync_lookahead'): 905 | optimizer.sync_lookahead() 906 | 907 | return OrderedDict([('loss', losses_m.avg)]) 908 | 909 | 910 | def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''): 911 | batch_time_m = AverageMeter() 912 | losses_m = AverageMeter() 913 | top1_m = AverageMeter() 914 | top5_m = AverageMeter() 915 | 916 | model.eval() 917 | 918 | end = time.time() 919 | last_idx = len(loader) - 1 920 | with torch.no_grad(): 921 | for batch_idx, (input, target) in enumerate(loader): 922 | last_batch = batch_idx == last_idx 923 | if not args.prefetcher: 924 | input = input.cuda() 925 | target = target.cuda() 926 | if args.channels_last: 927 | input = input.contiguous(memory_format=torch.channels_last) 928 | 929 | with amp_autocast(): 930 | output = model(input) 931 | if isinstance(output, (tuple, list)): 932 | output = output[0] 933 | 934 | # augmentation reduction 935 | reduce_factor = args.tta 936 | if reduce_factor > 1: 937 | output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) 938 | target = target[0:target.size(0):reduce_factor] 939 | 940 | loss = loss_fn(output, target) 941 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 942 | 943 | if args.distributed: 944 | reduced_loss = reduce_tensor(loss.data, args.world_size) 945 | acc1 = reduce_tensor(acc1, args.world_size) 946 | acc5 = reduce_tensor(acc5, args.world_size) 947 | else: 948 | reduced_loss = loss.data 949 | 950 | torch.cuda.synchronize() 951 | 952 | losses_m.update(reduced_loss.item(), input.size(0)) 953 | top1_m.update(acc1.item(), output.size(0)) 954 | top5_m.update(acc5.item(), output.size(0)) 955 | 956 | batch_time_m.update(time.time() - end) 957 | end = time.time() 958 | if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0): 959 | log_name = 'Test' + log_suffix 960 | _logger.info( 961 | '{0}: [{1:>4d}/{2}] ' 962 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 963 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 964 | 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' 965 | 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( 966 | log_name, batch_idx, last_idx, batch_time=batch_time_m, 967 | loss=losses_m, top1=top1_m, top5=top5_m)) 968 | 969 | metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)]) 970 | 971 | return metrics 972 | 973 | 974 | if __name__ == '__main__': 975 | main() 976 | -------------------------------------------------------------------------------- /train-fd.py: 
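train-fd.py, which follows, distills several student settings in one run, sharing a single teacher forward pass per batch. Its multi-valued flags are zipped positionally into per-student settings rather than expanded as a Cartesian product. A simplified sketch of that pairing, mirroring the _nargs_attrs handling in main() below, with illustrative values:

# --kd_loss kd dist --kd_loss_weight 1 2 --ori_loss_weight 1
multi = {'kd_loss': ['kd', 'dist'], 'kd_loss_weight': [1.0, 2.0], 'ori_loss_weight': [1.0]}
num = max(len(v) for v in multi.values())
settings = [{k: (v[i] if len(v) > 1 else v[0]) for k, v in multi.items()}
            for i in range(num)]
# -> [{'kd_loss': 'kd',   'kd_loss_weight': 1.0, 'ori_loss_weight': 1.0},
#     {'kd_loss': 'dist', 'kd_loss_weight': 2.0, 'ori_loss_weight': 1.0}]
# i.e. two settings, not the 2 x 2 product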
-------------------------------------------------------------------------------- 1 | # !/usr/bin/env python3 2 | """ ImageNet Training Script 3 | 4 | This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet 5 | training results with some of the latest networks and training techniques. It favours canonical PyTorch 6 | and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed 7 | and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit. 8 | 9 | This script was started from an early version of the PyTorch ImageNet example 10 | (https://github.com/pytorch/examples/tree/master/imagenet) 11 | 12 | NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples 13 | (https://github.com/NVIDIA/apex/tree/master/examples/imagenet) 14 | 15 | Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) 16 | 17 | Modifications by Zhiwei Hao (haozhw@bit.edu.cn) and Jianyuan Guo (jianyuan_guo@outlook.com) 18 | """ 19 | import argparse 20 | import logging 21 | import logging.handlers 22 | import os 23 | import time 24 | from collections import OrderedDict 25 | from contextlib import suppress 26 | from copy import deepcopy 27 | from datetime import datetime 28 | 29 | import numpy as np 30 | import torch 31 | import torch.nn as nn 32 | import yaml 33 | from torch.nn.parallel import DistributedDataParallel as NativeDDP 34 | from torchvision import transforms 35 | 36 | from losses import Correlation, ReviewKD, RKD 37 | from register import config, register_forward 38 | from timm.data import AugMixDataset, create_dataset, create_loader, FastCollateMixup, resolve_data_config 39 | from timm.loss import * 40 | from timm.models import convert_splitbn_model, create_model, model_parameters, safe_model_name 41 | from timm.optim import create_optimizer_v2 42 | from timm.scheduler import create_scheduler 43 | from timm.utils import * 44 | from timm.utils import ApexScaler, NativeScaler 45 | from utils import CheckpointSaverWithLogger, MultiSmoothingMixup, process_feat, setup_default_logging, TimePredictor 46 | import models 47 | 48 | try: 49 | from apex import amp 50 | from apex.parallel import DistributedDataParallel as ApexDDP 51 | from apex.parallel import convert_syncbn_model 52 | 53 | has_apex = True 54 | except ImportError: 55 | has_apex = False 56 | 57 | has_native_amp = False 58 | try: 59 | if getattr(torch.cuda.amp, 'autocast') is not None: 60 | has_native_amp = True 61 | except AttributeError: 62 | pass 63 | 64 | try: 65 | import wandb 66 | 67 | has_wandb = True 68 | except ImportError: 69 | has_wandb = False 70 | 71 | torch.backends.cudnn.benchmark = True 72 | _logger = logging.getLogger('train') 73 | 74 | # The first arg parser parses out only the --config argument, this argument is used to 75 | # load a yaml file containing key-values that override the defaults for the main parser below 76 | config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False) 77 | parser.add_argument('-c', '--config', default='', type=str, metavar='FILE', 78 | help='YAML config file specifying default arguments') 79 | 80 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 81 | 82 | # --------------------------------------------------------------------------------------- 83 | # KD parameters: CC, review, and RKD 84 | parser.add_argument('--teacher', default='beitv2_base_patch16_224', type=str) 85 | 
parser.add_argument('--teacher-pretrained', default=None, type=str) # teacher checkpoint path 86 | parser.add_argument('--teacher-resize', default=None, type=int) 87 | parser.add_argument('--student-resize', default=None, type=int) 88 | parser.add_argument('--input-size', default=None, nargs=3, type=int) 89 | 90 | # use torch.cuda.empty_cache() to save GPU memory 91 | parser.add_argument('--economic', action='store_true') 92 | 93 | # eval every 'eval-interval' epochs before epochs * eval_interval_end 94 | parser.add_argument('--eval-interval', type=int, default=1) 95 | parser.add_argument('--eval-interval-end', type=float, default=0.75) 96 | 97 | # one teacher forward, multiple students trained for saving time 98 | # e.g. --kd_loss kd dist --kd_loss_weight 1 2 99 | # then there are 2 settings which are (kd_loss=kd, kd_loss_weight=1) and (kd_loss=dist, kd_loss_weight=2), 100 | # but not 2 * 2 = 4 in total 101 | _nargs_attrs = ['kd_loss', 'kd_loss_weight', 'ori_loss_weight', 'model', 'opt', 'clip_grad', 'lr', 102 | 'weight_decay', 'drop', 'drop_path', 'model_ema_decay', 'smoothing', 'bce_loss'] 103 | 104 | parser.add_argument('--kd-loss', default=['kd'], type=str, nargs='+') 105 | parser.add_argument('--kd-loss-weight', default=[1.], type=float, nargs='+') 106 | parser.add_argument('--ori-loss-weight', default=[1.], type=float, nargs='+') 107 | parser.add_argument('--model', default=['resnet50'], type=str, nargs='+') 108 | parser.add_argument('--opt', default=['sgd'], type=str, nargs='+') 109 | parser.add_argument('--clip-grad', type=float, default=[None], nargs='+') 110 | parser.add_argument('--lr', default=[0.05], type=float, nargs='+') 111 | parser.add_argument('--weight-decay', type=float, default=[2e-5], nargs='+') 112 | parser.add_argument('--drop', type=float, default=[0.0], nargs='+') 113 | parser.add_argument('--drop-path', type=float, default=[None], nargs='+') 114 | parser.add_argument('--model-ema-decay', type=float, default=[0.9998], nargs='+') 115 | parser.add_argument('--smoothing', type=float, default=[0.1], nargs='+') 116 | parser.add_argument('--bce-loss', type=int, default=[0], nargs='+') # 0: disable; others: enable 117 | # --------------------------------------------------------------------------------------- 118 | 119 | # Dataset parameters 120 | parser.add_argument('data_dir', metavar='DIR', 121 | help='path to dataset') 122 | parser.add_argument('--dataset', '-d', metavar='NAME', default='', 123 | help='dataset type (default: ImageFolder/ImageTar if empty)') 124 | parser.add_argument('--train-split', metavar='NAME', default='train', 125 | help='dataset train split (default: train)') 126 | parser.add_argument('--val-split', metavar='NAME', default='validation', 127 | help='dataset validation split (default: validation)') 128 | parser.add_argument('--dataset-download', action='store_true', default=False, 129 | help='Allow download of dataset for torch/ and tfds/ datasets that support it.') 130 | parser.add_argument('--class-map', default='', type=str, metavar='FILENAME', 131 | help='path to class to idx mapping file (default: "")') 132 | 133 | # Model parameters 134 | parser.add_argument('--pretrained', action='store_true', default=False, 135 | help='Start with pretrained version of specified network (if avail)') 136 | parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', 137 | help='Initialize model from this checkpoint (default: none)') 138 | parser.add_argument('--num-classes', type=int, default=1000, metavar='N', 139 | help='number of 
label classes (Model default if None)') 140 | parser.add_argument('--gp', default=None, type=str, metavar='POOL', 141 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') 142 | parser.add_argument('--img-size', type=int, default=None, metavar='N', 143 | help='Image patch size (default: None => model default)') 144 | parser.add_argument('--crop-pct', default=None, type=float, 145 | metavar='N', help='Input image center crop percent (for validation only)') 146 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', 147 | help='Override mean pixel value of dataset') 148 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', 149 | help='Override std deviation of dataset') 150 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME', 151 | help='Image resize interpolation type (overrides model)') 152 | parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N', 153 | help='Input batch size for training (default: 128)') 154 | parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N', 155 | help='Validation batch size override (default: None)') 156 | parser.add_argument('--channels-last', action='store_true', default=False, 157 | help='Use channels_last memory layout') 158 | parser.add_argument('--torchscript', dest='torchscript', action='store_true', 159 | help='torch.jit.script the full model') 160 | parser.add_argument('--fuser', default='', type=str, 161 | help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") 162 | 163 | # Optimizer parameters 164 | parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', 165 | help='Optimizer Epsilon (default: None, use opt default)') 166 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 167 | help='Optimizer Betas (default: None, use opt default)') 168 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 169 | help='Optimizer momentum (default: 0.9)') 170 | parser.add_argument('--clip-mode', type=str, default='norm', 171 | help='Gradient clipping mode. 
One of ("norm", "value", "agc")') 172 | parser.add_argument('--layer-decay', type=float, default=None, 173 | help='layer-wise learning rate decay (default: None)') 174 | 175 | # Learning rate schedule parameters 176 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 177 | help='LR scheduler (default: "step"') 178 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 179 | help='learning rate noise on/off epoch percentages') 180 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 181 | help='learning rate noise limit percent (default: 0.67)') 182 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 183 | help='learning rate noise std-dev (default: 1.0)') 184 | parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT', 185 | help='learning rate cycle len multiplier (default: 1.0)') 186 | parser.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT', 187 | help='amount to decay each learning rate cycle (default: 0.5)') 188 | parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', 189 | help='learning rate cycle limit, cycles enabled if > 1') 190 | parser.add_argument('--lr-k-decay', type=float, default=1.0, 191 | help='learning rate k-decay for cosine/poly (default: 1.0)') 192 | parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', 193 | help='warmup learning rate (default: 0.0001)') 194 | parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR', 195 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 196 | parser.add_argument('--epochs', type=int, default=300, metavar='N', 197 | help='number of epochs to train (default: 300)') 198 | parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N', 199 | help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).') 200 | parser.add_argument('--start-epoch', default=None, type=int, metavar='N', 201 | help='manual epoch number (useful on restarts)') 202 | parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N', 203 | help='epochs to warmup LR, if scheduler supports') 204 | parser.add_argument('--decay-epochs', type=float, default=100, metavar='N', 205 | help='epoch interval to decay LR') 206 | parser.add_argument('--cooldown-epochs', type=int, default=0, metavar='N', 207 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 208 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 209 | help='patience epochs for Plateau LR scheduler (default: 10') 210 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 211 | help='LR decay rate (default: 0.1)') 212 | 213 | # Augmentation & regularization parameters 214 | parser.add_argument('--no-aug', action='store_true', default=False, 215 | help='Disable all training augmentation, override other train aug args') 216 | parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', 217 | help='Random resize scale (default: 0.08 1.0)') 218 | parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. 
/ 3.], metavar='RATIO', 219 | help='Random resize aspect ratio (default: 0.75 1.33)') 220 | parser.add_argument('--hflip', type=float, default=0.5, 221 | help='Horizontal flip training aug probability') 222 | parser.add_argument('--vflip', type=float, default=0., 223 | help='Vertical flip training aug probability') 224 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 225 | help='Color jitter factor (default: 0.4)') 226 | parser.add_argument('--aa', type=str, default=None, metavar='NAME', 227 | help='Use AutoAugment policy. "v0" or "original". (default: None)'), 228 | parser.add_argument('--aug-repeats', type=float, default=0, 229 | help='Number of augmentation repetitions (distributed training only) (default: 0)') 230 | parser.add_argument('--aug-splits', type=int, default=0, 231 | help='Number of augmentation splits (default: 0, valid: 0 or >=2)') 232 | parser.add_argument('--jsd-loss', action='store_true', default=False, 233 | help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') 234 | parser.add_argument('--bce-target-thresh', type=float, default=None, 235 | help='Threshold for binarizing softened BCE targets (default: None, disabled)') 236 | parser.add_argument('--reprob', type=float, default=0., metavar='PCT', 237 | help='Random erase prob (default: 0.)') 238 | parser.add_argument('--remode', type=str, default='pixel', 239 | help='Random erase mode (default: "pixel")') 240 | parser.add_argument('--recount', type=int, default=1, 241 | help='Random erase count (default: 1)') 242 | parser.add_argument('--resplit', action='store_true', default=False, 243 | help='Do not random erase first (clean) augmentation split') 244 | parser.add_argument('--mixup', type=float, default=0.0, 245 | help='mixup alpha, mixup enabled if > 0. (default: 0.)') 246 | parser.add_argument('--cutmix', type=float, default=0.0, 247 | help='cutmix alpha, cutmix enabled if > 0. (default: 0.)') 248 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 249 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 250 | parser.add_argument('--mixup-prob', type=float, default=1.0, 251 | help='Probability of performing mixup or cutmix when either/both is enabled') 252 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 253 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 254 | parser.add_argument('--mixup-mode', type=str, default='batch', 255 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 256 | parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', 257 | help='Turn off mixup after this epoch, disabled if 0 (default: 0)') 258 | parser.add_argument('--train-interpolation', type=str, default='random', 259 | help='Training interpolation (random, bilinear, bicubic default: "random")') 260 | parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT', 261 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)') 262 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', 263 | help='Drop block rate (default: None)') 264 | 265 | # Batch norm parameters (only works with gen_efficientnet based models currently) 266 | parser.add_argument('--bn-momentum', type=float, default=None, 267 | help='BatchNorm momentum override (if not None)') 268 | parser.add_argument('--bn-eps', type=float, default=None, 269 | help='BatchNorm epsilon override (if not None)') 270 | parser.add_argument('--sync-bn', action='store_true', 271 | help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') 272 | parser.add_argument('--dist-bn', type=str, default='reduce', 273 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') 274 | parser.add_argument('--split-bn', action='store_true', 275 | help='Enable separate BN layers per augmentation split.') 276 | 277 | # Model Exponential Moving Average 278 | parser.add_argument('--model-ema', action='store_true', default=False, 279 | help='Enable tracking moving average of model weights') 280 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, 281 | help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.') 282 | 283 | # Misc 284 | parser.add_argument('--seed', type=int, default=42, metavar='S', 285 | help='random seed (default: 42)') 286 | parser.add_argument('--worker-seeding', type=str, default='all', 287 | help='worker seed mode (default: all)') 288 | parser.add_argument('--log-interval', type=int, default=200, metavar='N', 289 | help='how many batches to wait before logging training status') 290 | parser.add_argument('--checkpoint-hist', type=int, default=10, metavar='N', 291 | help='number of checkpoints to keep (default: 10)') 292 | parser.add_argument('-j', '--workers', type=int, default=8, metavar='N', 293 | help='how many training processes to use (default: 4)') 294 | parser.add_argument('--amp', action='store_true', default=False, 295 | help='use NVIDIA Apex AMP or Native AMP for mixed precision training') 296 | parser.add_argument('--apex-amp', action='store_true', default=False, 297 | help='Use NVIDIA Apex AMP mixed precision') 298 | parser.add_argument('--native-amp', action='store_true', default=False, 299 | help='Use Native Torch AMP mixed precision') 300 | parser.add_argument('--no-ddp-bb', action='store_true', default=False, 301 | help='Force broadcast buffers for native DDP to off.') 302 | parser.add_argument('--pin-mem', action='store_true', default=False, 303 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 304 | parser.add_argument('--no-prefetcher', action='store_true', default=False, 305 | help='disable fast prefetcher') 306 | parser.add_argument('--output', default='', type=str, metavar='PATH', 307 | help='path to output folder (default: none, current dir)') 308 | parser.add_argument('--experiment', default='', type=str, metavar='NAME', 309 | help='name of train experiment, name of sub-folder for output') 
310 | parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC', 311 | help='Best metric (default: "top1"') 312 | parser.add_argument('--tta', type=int, default=0, metavar='N', 313 | help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') 314 | parser.add_argument("--local_rank", default=0, type=int) 315 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 316 | parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False, 317 | help='use the multi-epochs-loader to save time at the beginning of every epoch') 318 | parser.add_argument('--log-wandb', action='store_true', default=False, 319 | help='log training and validation metrics to wandb') 320 | 321 | 322 | def _parse_args(): 323 | # Do we have a config file to parse? 324 | args_config, remaining = config_parser.parse_known_args() 325 | if args_config.config: 326 | with open(args_config.config, 'r') as f: 327 | cfg = yaml.safe_load(f) 328 | parser.set_defaults(**cfg) 329 | 330 | # The main arg parser parses the rest of the args, the usual 331 | # defaults will have been overridden if config file specified. 332 | args = parser.parse_args(remaining) 333 | 334 | # Cache the args as a text string to save them in the output dir later 335 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) 336 | return args, args_text 337 | 338 | 339 | def main(): 340 | _logger.parent = None 341 | setup_default_logging(_logger, log_path='train.log') 342 | args, args_text = _parse_args() 343 | 344 | # parse nargs 345 | setting_num = 1 346 | setting_dicts = [dict()] 347 | for attr in _nargs_attrs: 348 | value = getattr(args, attr) 349 | if isinstance(value, list): 350 | if len(value) == 1: 351 | for d in setting_dicts: 352 | d[attr] = value[0] 353 | else: 354 | if setting_num == 1: 355 | setting_num = len(value) 356 | setting_dicts = [deepcopy(setting_dicts[0]) for _ in range(setting_num)] 357 | else: # ensure that args with multiple values have the same length 358 | assert setting_num == len(value) 359 | for i, v in enumerate(value): 360 | setting_dicts[i][attr] = v 361 | else: 362 | for d in setting_dicts: 363 | d[attr] = value 364 | 365 | # merge duplicating settings, only for non-nested dict 366 | setting_dicts = [dict(t) for t in sorted(list({tuple(sorted(d.items())) for d in setting_dicts}))] 367 | 368 | # merge settings with only different 'model_ema_decay' 369 | model_ema_decay_list = [] 370 | assist_dict = dict() 371 | for i, d in enumerate(deepcopy(setting_dicts)): 372 | model_ema_decay_list.append(d['model_ema_decay']) 373 | del d['model_ema_decay'] 374 | h = hash(tuple(sorted(d.items()))) 375 | if h not in assist_dict: 376 | assist_dict[h] = [i] 377 | else: 378 | assist_dict[h].append(i) 379 | 380 | merged_setting_dict_list = [] 381 | for v in assist_dict.values(): 382 | d = setting_dicts[v[0]] 383 | d['model_ema_decay'] = tuple([model_ema_decay_list[index] for index in v]) 384 | merged_setting_dict_list.append(d) 385 | 386 | # update 387 | setting_dicts = merged_setting_dict_list 388 | setting_num = len(setting_dicts) 389 | 390 | logger_list = [_logger] 391 | if setting_num > 1: 392 | if args.local_rank == 0: 393 | _logger.info(f'there are {setting_num} settings in total. 
creating individual logger for each setting')
394 | 
395 | logger_list = []
396 | for i in range(setting_num):
397 | logger = logging.getLogger(f'train-setting-{i}')
398 | logger.parent = None
399 | setup_default_logging(logger, log_path=f'train-setting-{i}.log')
400 | logger_list.append(logger)
401 | if args.local_rank == 0:
402 | logger.info(f'settings of index {i}: ' +
403 | ', '.join(f"{k}={setting_dicts[i][k]}" for k in setting_dicts[i]))
404 | else:
405 | _logger.info(f'settings: ' + ', '.join(f"{k}={setting_dicts[0][k]}" for k in setting_dicts[0]))
406 | 
407 | if args.log_wandb:
408 | if has_wandb:
409 | wandb.init(project=args.experiment, config=args)
410 | else:
411 | _logger.warning("You've requested to log metrics to wandb but package not found. "
412 | "Metrics not being logged to wandb, try `pip install wandb`")
413 | 
414 | args.prefetcher = not args.no_prefetcher
415 | args.distributed = False
416 | if 'WORLD_SIZE' in os.environ:
417 | args.distributed = int(os.environ['WORLD_SIZE']) > 1
418 | args.device = 'cuda:0'
419 | args.world_size = 1
420 | args.rank = 0 # global rank
421 | if args.distributed:
422 | assert 'RANK' in os.environ and 'WORLD_SIZE' in os.environ
423 | args.rank = int(os.environ["RANK"])
424 | args.world_size = int(os.environ['WORLD_SIZE'])
425 | args.device = int(os.environ['LOCAL_RANK'])
426 | 
427 | torch.cuda.set_device(args.device)
428 | 
429 | torch.distributed.init_process_group(backend='nccl', init_method=args.dist_url,
430 | world_size=args.world_size, rank=args.rank)
431 | torch.distributed.barrier()
432 | 
433 | _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
434 | % (args.rank, args.world_size))
435 | 
436 | else:
437 | _logger.info('Training with a single process on 1 GPU.')
438 | assert args.rank >= 0
439 | 
440 | # resolve AMP arguments based on PyTorch / Apex availability
441 | use_amp = None
442 | if args.amp:
443 | # `--amp` chooses native amp before apex (APEX ver not actively maintained)
444 | if has_native_amp:
445 | args.native_amp = True
446 | elif has_apex:
447 | args.apex_amp = True
448 | if args.apex_amp and has_apex:
449 | use_amp = 'apex'
450 | elif args.native_amp and has_native_amp:
451 | use_amp = 'native'
452 | elif args.apex_amp or args.native_amp:
453 | _logger.warning("Neither APEX nor native Torch AMP is available, using float32. 
" 454 | "Install NVIDA apex or upgrade to PyTorch 1.6") 455 | 456 | random_seed(args.seed, args.rank) 457 | 458 | if args.fuser: 459 | set_jit_fuser(args.fuser) 460 | 461 | teacher = create_model( 462 | args.teacher, 463 | checkpoint_path=args.teacher_pretrained, 464 | num_classes=args.num_classes) 465 | register_forward(teacher) 466 | teacher = teacher.cuda() 467 | teacher.eval() 468 | 469 | model_list = [] 470 | for i in range(setting_num): 471 | model = create_model( 472 | setting_dicts[i]['model'], 473 | pretrained=args.pretrained, 474 | num_classes=args.num_classes, 475 | drop_rate=setting_dicts[i]['drop'], 476 | drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path 477 | drop_path_rate=setting_dicts[i]['drop_path'], 478 | drop_block_rate=args.drop_block, 479 | global_pool=args.gp, 480 | bn_momentum=args.bn_momentum, 481 | bn_eps=args.bn_eps, 482 | scriptable=args.torchscript, 483 | checkpoint_path=args.initial_checkpoint) 484 | register_forward(model) 485 | model_list.append(model) 486 | 487 | if args.local_rank == 0: 488 | logger_list[i].info(f'Model {safe_model_name(setting_dicts[i]["model"])} created, ' 489 | f'param count:{sum([m.numel() for m in model.parameters()])}') 490 | 491 | # all settings must have the same data config to assure consistent teacher prediction 492 | if args.num_classes is None: 493 | assert hasattr(model_list[0], 494 | 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' 495 | args.num_classes = model_list[0].num_classes # FIXME handle model default vs config num_classes more elegantly 496 | 497 | data_config = resolve_data_config(vars(args), model=model_list[0], verbose=args.local_rank == 0) 498 | 499 | # setup augmentation batch splits for contrastive loss or split bn 500 | num_aug_splits = 0 501 | if args.aug_splits > 0: 502 | assert args.aug_splits > 1, 'A split of 1 makes no sense' 503 | num_aug_splits = args.aug_splits 504 | 505 | for i, model in enumerate(model_list): 506 | # enable split bn (separate bn stats per batch-portion) 507 | if args.split_bn: 508 | assert num_aug_splits > 1 or args.resplit 509 | model = convert_splitbn_model(model, max(num_aug_splits, 2)) 510 | 511 | # move model to GPU, enable channels last layout if set 512 | if setting_dicts[i]['kd_loss'] == 'correlation': 513 | kd_loss_fn = Correlation(feat_s_channel=config.get_pre_logit_dim(setting_dicts[i]['model']), 514 | feat_t_channel=config.get_pre_logit_dim(args.teacher)) 515 | elif setting_dicts[i]['kd_loss'] == 'review': 516 | feat_index_s = config.get_used_feature_index(setting_dicts[i]['model']) 517 | feat_index_t = config.get_used_feature_index(args.teacher) 518 | in_channels = [config.get_feature_size_by_index(setting_dicts[i]['model'], j)[0] for j in feat_index_s] 519 | out_channels = [config.get_feature_size_by_index(args.teacher, j)[0] for j in feat_index_t] 520 | in_channels = in_channels + [config.get_pre_logit_dim(setting_dicts[i]['model'])] 521 | out_channels = out_channels + [config.get_pre_logit_dim(args.teacher)] 522 | 523 | kd_loss_fn = ReviewKD(feat_index_s, feat_index_t, in_channels, out_channels) 524 | elif setting_dicts[i]['kd_loss'] == 'rkd': 525 | kd_loss_fn = RKD() 526 | else: 527 | raise NotImplementedError(f'unknown kd loss {args.kd_loss}') 528 | 529 | model.kd_loss_fn = kd_loss_fn 530 | model.cuda() 531 | if args.channels_last: 532 | model = model.to(memory_format=torch.channels_last) 533 | 534 | # setup synchronized BatchNorm for distributed training 535 | if args.distributed and args.sync_bn: 536 | assert not 
args.split_bn 537 | if has_apex and use_amp == 'apex': 538 | # Apex SyncBN preferred unless native amp is activated 539 | model = convert_syncbn_model(model) 540 | else: 541 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 542 | if args.local_rank == 0: 543 | _logger.info( 544 | 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 545 | 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') 546 | 547 | if args.torchscript: 548 | assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' 549 | assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' 550 | model = torch.jit.script(model) 551 | 552 | optimizer_list = [] 553 | # if setting_num == 1: 554 | # optimizer_list.append(create_optimizer_v2(model_list[0], **optimizer_kwargs(cfg=args))) 555 | # else: 556 | for i in range(setting_num): 557 | optimizer_list.append(create_optimizer_v2(model_list[i], 558 | opt=setting_dicts[i]['opt'], 559 | lr=setting_dicts[i]['lr'], 560 | weight_decay=setting_dicts[i]['weight_decay'], 561 | momentum=args.momentum)) 562 | 563 | # setup automatic mixed-precision (AMP) loss scaling and op casting 564 | amp_autocast = suppress # do nothing 565 | loss_scaler = None 566 | if use_amp == 'apex': 567 | new_model_list = [] 568 | new_optimizer_list = [] 569 | for model, optimizer in zip(model_list, optimizer_list): 570 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1') 571 | new_model_list.append(model) 572 | new_optimizer_list.append(optimizer) 573 | model_list = new_model_list 574 | optimizer_list = new_optimizer_list 575 | loss_scaler = ApexScaler() 576 | if args.local_rank == 0: 577 | _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') 578 | elif use_amp == 'native': 579 | amp_autocast = torch.cuda.amp.autocast 580 | loss_scaler = NativeScaler() 581 | if args.local_rank == 0: 582 | _logger.info('Using native Torch AMP. Training in mixed precision.') 583 | else: 584 | if args.local_rank == 0: 585 | _logger.info('AMP not enabled. 
Training in float32.') 586 | 587 | # setup exponential moving average of model weights, SWA could be used here too 588 | model_ema_list = [(None,) for _ in range(setting_num)] 589 | if args.model_ema: 590 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper 591 | for i in range(setting_num): 592 | model_emas = [] 593 | for decay in setting_dicts[i]['model_ema_decay']: 594 | model_ema = ModelEmaV2(model_list[i], decay=decay, 595 | device='cpu' if args.model_ema_force_cpu else None) 596 | model_emas.append((model_ema, decay)) 597 | model_ema_list[i] = tuple(model_emas) 598 | 599 | # setup distributed training 600 | if args.distributed: 601 | new_model_list = [] 602 | for i in range(setting_num): 603 | if has_apex and use_amp == 'apex': 604 | # Apex DDP preferred unless native amp is activated 605 | if args.local_rank == 0: 606 | _logger.info("Using NVIDIA APEX DistributedDataParallel.") 607 | model = ApexDDP(model_list[i], delay_allreduce=True) 608 | else: 609 | if args.local_rank == 0: 610 | _logger.info("Using native Torch DistributedDataParallel.") 611 | model = NativeDDP(model_list[i], device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb) 612 | # NOTE: EMA model does not need to be wrapped by DDP 613 | new_model_list.append(model) 614 | model_list = new_model_list 615 | 616 | start_epoch = 0 617 | if args.start_epoch is not None: 618 | # a specified start_epoch will always override the resume epoch 619 | start_epoch = args.start_epoch 620 | 621 | # setup learning rate schedule and starting epoch 622 | lr_scheduler_list = [] 623 | for i in range(setting_num): 624 | lr_scheduler, num_epochs = create_scheduler(args, optimizer_list[i]) 625 | if lr_scheduler is not None and start_epoch > 0: 626 | lr_scheduler.step(start_epoch) 627 | 628 | lr_scheduler_list.append(lr_scheduler) 629 | 630 | if args.local_rank == 0: 631 | logger_list[i].info('Scheduled epochs: {}'.format(num_epochs)) 632 | 633 | # create the train and eval datasets 634 | dataset_train = create_dataset( 635 | args.dataset, root=args.data_dir, split=args.train_split, is_training=True, 636 | class_map=args.class_map, 637 | download=args.dataset_download, 638 | batch_size=args.batch_size, 639 | repeats=args.epoch_repeats) 640 | dataset_eval = create_dataset( 641 | args.dataset, root=args.data_dir, split=args.val_split, is_training=False, 642 | class_map=args.class_map, 643 | download=args.dataset_download, 644 | batch_size=args.batch_size) 645 | 646 | # setup mixup / cutmix 647 | # smoothing is implemented in data loader when prefetcher=True, 648 | # so prefetcher should be turned off when multiple smoothing settings are used 649 | smoothing_setting_num = len(set([d['smoothing'] for d in setting_dicts])) 650 | 651 | collate_fn = None 652 | mixup_fn = None 653 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 654 | if mixup_active: 655 | if smoothing_setting_num == 1 and args.prefetcher: 656 | mixup_args = dict( 657 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 658 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 659 | label_smoothing=setting_dicts[0]['smoothing'], num_classes=args.num_classes) 660 | assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) 661 | collate_fn = FastCollateMixup(**mixup_args) 662 | else: 663 | smoothings = tuple([d['smoothing'] for d in setting_dicts]) 664 | mixup_args = dict( 665 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 666 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 667 | smoothings=smoothings, num_classes=args.num_classes) 668 | mixup_fn = MultiSmoothingMixup(**mixup_args) 669 | 670 | # wrap dataset in AugMix helper 671 | if num_aug_splits > 1: 672 | dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) 673 | 674 | # create data loaders w/ augmentation pipeiine 675 | train_interpolation = args.train_interpolation 676 | if args.no_aug or not train_interpolation: 677 | train_interpolation = data_config['interpolation'] 678 | loader_train = create_loader( 679 | dataset_train, 680 | input_size=data_config['input_size'], 681 | batch_size=args.batch_size, 682 | is_training=True, 683 | use_prefetcher=args.prefetcher, 684 | no_aug=args.no_aug, 685 | re_prob=args.reprob, 686 | re_mode=args.remode, 687 | re_count=args.recount, 688 | re_split=args.resplit, 689 | scale=args.scale, 690 | ratio=args.ratio, 691 | hflip=args.hflip, 692 | vflip=args.vflip, 693 | color_jitter=args.color_jitter, 694 | auto_augment=args.aa, 695 | num_aug_repeats=args.aug_repeats, 696 | num_aug_splits=num_aug_splits, 697 | interpolation=train_interpolation, 698 | mean=data_config['mean'], 699 | std=data_config['std'], 700 | num_workers=args.workers, 701 | distributed=args.distributed, 702 | collate_fn=collate_fn, 703 | pin_memory=args.pin_mem, 704 | use_multi_epochs_loader=args.use_multi_epochs_loader, 705 | worker_seeding=args.worker_seeding, 706 | ) 707 | 708 | loader_eval = create_loader( 709 | dataset_eval, 710 | input_size=(3, 224, 224), 711 | batch_size=args.validation_batch_size or args.batch_size, 712 | is_training=False, 713 | use_prefetcher=args.prefetcher, 714 | interpolation=data_config['interpolation'], 715 | mean=data_config['mean'], 716 | std=data_config['std'], 717 | num_workers=args.workers, 718 | distributed=args.distributed, 719 | crop_pct=data_config['crop_pct'], 720 | pin_memory=args.pin_mem, 721 | ) 722 | 723 | # setup loss function 724 | validate_loss_fn = nn.CrossEntropyLoss().cuda() 725 | 726 | train_loss_fn_list = [] 727 | for i in range(setting_num): 728 | if args.jsd_loss: 729 | assert num_aug_splits > 1 # JSD only valid with aug splits set 730 | train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=setting_dicts[i]['smoothing']) 731 | elif mixup_active: 732 | # smoothing is handled with mixup target transform which outputs sparse, soft targets 733 | if setting_dicts[i]['bce_loss']: 734 | train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh) 735 | else: 736 | train_loss_fn = SoftTargetCrossEntropy() 737 | elif setting_dicts[i]['smoothing']: 738 | if setting_dicts[i]['bce_loss']: 739 | train_loss_fn = BinaryCrossEntropy(smoothing=setting_dicts[i]['smoothing'], 740 | 
target_threshold=args.bce_target_thresh) 741 | else: 742 | train_loss_fn = LabelSmoothingCrossEntropy(smoothing=setting_dicts[i]['smoothing']) 743 | else: 744 | train_loss_fn = nn.CrossEntropyLoss() 745 | train_loss_fn = train_loss_fn.cuda() 746 | train_loss_fn_list.append(train_loss_fn) 747 | 748 | teacher_resizer = student_resizer = None 749 | if args.teacher_resize is not None: 750 | teacher_resizer = transforms.Resize(args.teacher_resize).cuda() 751 | if args.student_resize is not None: 752 | student_resizer = transforms.Resize(args.student_resize).cuda() 753 | 754 | # setup checkpoint saver and eval metric tracking 755 | eval_metric = args.eval_metric 756 | saver_list = [None for _ in range(setting_num)] 757 | ema_saver_list = [tuple([None for _ in range(len(model_ema_list[i]))]) for i in range(setting_num)] 758 | for ema, saver in zip(model_ema_list, ema_saver_list): 759 | assert len(ema) == len(saver) 760 | output_dir_list = [None for _ in range(setting_num)] 761 | for i in range(setting_num): 762 | if args.rank == 0: 763 | if args.experiment: 764 | exp_name = args.experiment + f'-setting-{i}' 765 | else: 766 | exp_name = '-'.join([ 767 | datetime.now().strftime("%Y%m%d-%H%M%S"), 768 | safe_model_name(setting_dicts[i]['model']), 769 | str(data_config['input_size'][-1]), 770 | f'-setting-{i}' 771 | ]) 772 | output_dir = get_outdir(args.output if args.output else './output/train', exp_name) 773 | output_dir_list[i] = output_dir 774 | decreasing = True if eval_metric == 'loss' else False 775 | saver_dir = os.path.join(output_dir, 'checkpoint') 776 | os.makedirs(saver_dir) 777 | saver = CheckpointSaverWithLogger( 778 | logger=logger_list[i], model=model_list[i], optimizer=optimizer_list[i], args=args, 779 | amp_scaler=loss_scaler, checkpoint_dir=saver_dir, recovery_dir=saver_dir, 780 | decreasing=decreasing, max_history=args.checkpoint_hist) 781 | saver_list[i] = saver 782 | if model_ema_list[i][0] is not None: 783 | ema_savers = [] 784 | for ema, decay in model_ema_list[i]: 785 | ema_saver_dir = os.path.join(output_dir, f'ema{decay}_checkpoint') 786 | os.makedirs(ema_saver_dir) 787 | ema_saver = CheckpointSaverWithLogger( 788 | logger=logger_list[i], model=model_list[i], optimizer=optimizer_list[i], args=args, 789 | model_ema=ema, amp_scaler=loss_scaler, checkpoint_dir=ema_saver_dir, 790 | recovery_dir=ema_saver_dir, decreasing=decreasing, max_history=args.checkpoint_hist) 791 | ema_savers.append(ema_saver) 792 | ema_saver_list[i] = tuple(ema_savers) 793 | with open(os.path.join(get_outdir(args.output if args.output else './output/train'), 794 | 'args.yaml'), 'w') as f: 795 | f.write(args_text) 796 | 797 | best_metric_list = [None for _ in range(setting_num)] 798 | best_epoch_list = [None for _ in range(setting_num)] 799 | try: 800 | tp = TimePredictor(num_epochs - start_epoch) 801 | for epoch in range(start_epoch, num_epochs): 802 | if args.distributed and hasattr(loader_train.sampler, 'set_epoch'): 803 | loader_train.sampler.set_epoch(epoch) 804 | 805 | ori_loss_weight_list = tuple([d['ori_loss_weight'] for d in setting_dicts]) 806 | kd_loss_weight_list = tuple([d['kd_loss_weight'] for d in setting_dicts]) 807 | clip_grad_list = tuple([d['clip_grad'] for d in setting_dicts]) 808 | train_metrics_list = train_one_epoch( 809 | epoch, model_list, teacher, loader_train, optimizer_list, train_loss_fn_list, args, 810 | lr_scheduler_list=lr_scheduler_list, amp_autocast=amp_autocast, 811 | loss_scaler=loss_scaler, model_ema_list=model_ema_list, mixup_fn=mixup_fn, 812 | 
teacher_resizer=teacher_resizer, student_resizer=student_resizer, 813 | ori_loss_weight_list=ori_loss_weight_list, kd_loss_weight_list=kd_loss_weight_list, 814 | clip_grad_list=clip_grad_list, logger_list=logger_list) 815 | 816 | for i in range(setting_num): 817 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'): 818 | if args.local_rank == 0: 819 | logger_list[i].info("Distributing BatchNorm running means and vars") 820 | distribute_bn(model_list[i], args.world_size, args.dist_bn == 'reduce') 821 | 822 | is_eval = epoch > int(args.eval_interval_end * args.epochs) or epoch % args.eval_interval == 0 823 | if is_eval: 824 | eval_metrics = validate(model_list[i], loader_eval, validate_loss_fn, args, 825 | logger=logger_list[i], amp_autocast=amp_autocast) 826 | 827 | if saver_list[i] is not None: 828 | # save proper checkpoint with eval metric 829 | save_metric = eval_metrics[eval_metric] 830 | best_metric, best_epoch = saver_list[i].save_checkpoint(epoch, metric=save_metric) 831 | best_metric_list[i] = best_metric 832 | best_epoch_list[i] = best_epoch 833 | 834 | if model_ema_list[i][0] is not None and not args.model_ema_force_cpu: 835 | for j, ((ema, decay), saver) in enumerate(zip(model_ema_list[i], ema_saver_list[i])): 836 | 837 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'): 838 | distribute_bn(ema, args.world_size, args.dist_bn == 'reduce') 839 | 840 | ema_eval_metrics = validate(ema.module, loader_eval, validate_loss_fn, args, 841 | logger=logger_list[i], amp_autocast=amp_autocast, 842 | log_suffix=f' (EMA {decay:.5f})') 843 | 844 | if saver is not None: 845 | # save proper checkpoint with eval metric 846 | save_metric = ema_eval_metrics[eval_metric] 847 | saver.save_checkpoint(epoch, metric=save_metric) 848 | 849 | if output_dir_list[i] is not None: 850 | update_summary( 851 | epoch, train_metrics_list[i], eval_metrics, os.path.join(output_dir_list[i], 'summary.csv'), 852 | write_header=best_metric_list[i] is None, log_wandb=args.log_wandb and has_wandb) 853 | 854 | metrics = eval_metrics[eval_metric] 855 | else: 856 | metrics = None 857 | 858 | if lr_scheduler_list[i] is not None: 859 | # step LR for next epoch 860 | lr_scheduler_list[i].step(epoch + 1, metrics) 861 | 862 | tp.update() 863 | if args.rank == 0: 864 | print(f'Will finish at {tp.get_pred_text()}') 865 | print(f'Avg running time of latest {len(tp.time_list)} epochs: {np.mean(tp.time_list):.2f}s/ep.') 866 | 867 | except KeyboardInterrupt: 868 | pass 869 | 870 | for i in range(setting_num): 871 | if best_metric_list[i] is not None: 872 | logger_list[i].info('*** Best metric: {0} (epoch {1})'.format(best_metric_list[i], best_epoch_list[i])) 873 | 874 | if args.rank == 0: 875 | if setting_num == 1: 876 | os.system(f'mv train.log {args.output}') 877 | else: 878 | os.system(f'mv train.log {args.output}') 879 | for i in range(setting_num): 880 | os.system(f'mv train-setting-{i}.log {args.output}') 881 | 882 | 883 | def train_one_epoch( 884 | epoch, model_list, teacher, loader, optimizer_list, loss_fn_list, args, 885 | lr_scheduler_list=(None,), amp_autocast=suppress, loss_scaler=None, model_ema_list=(None,), mixup_fn=None, 886 | teacher_resizer=None, student_resizer=None, ori_loss_weight_list=(None,), 887 | kd_loss_weight_list=(None,), clip_grad_list=(None,), logger_list=(None,)): 888 | setting_num = len(model_list) 889 | 890 | if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: 891 | if args.prefetcher and loader.mixup_enabled: 892 | loader.mixup_enabled = False 893 | elif mixup_fn is 
not None: 894 | mixup_fn.mixup_enabled = False 895 | 896 | second_order_list = [hasattr(o, 'is_second_order') and o.is_second_order for o in optimizer_list] 897 | batch_time_m = AverageMeter() 898 | data_time_m = AverageMeter() 899 | losses_m_list = [AverageMeter() for _ in range(setting_num)] 900 | losses_ori_m_list = [AverageMeter() for _ in range(setting_num)] 901 | losses_kd_m_list = [AverageMeter() for _ in range(setting_num)] 902 | 903 | for model in model_list: 904 | model.train() 905 | 906 | end = time.time() 907 | last_idx = len(loader) - 1 908 | num_updates = epoch * len(loader) 909 | for batch_idx, (input, target) in enumerate(loader): 910 | last_batch = batch_idx == last_idx 911 | data_time_m.update(time.time() - end) 912 | if not args.prefetcher: 913 | input, target = input.cuda(), target.cuda() 914 | if mixup_fn is not None: 915 | input, targets = mixup_fn(input, target) 916 | else: 917 | targets = [target for _ in range(setting_num)] 918 | else: 919 | targets = [target for _ in range(setting_num)] 920 | 921 | if args.channels_last: 922 | input = input.contiguous(memory_format=torch.channels_last) 923 | 924 | # teacher forward 925 | with amp_autocast(): 926 | if teacher_resizer is not None: 927 | teacher_input = teacher_resizer(input) 928 | else: 929 | teacher_input = input 930 | 931 | if args.economic: 932 | torch.cuda.empty_cache() 933 | with torch.no_grad(): 934 | output_t, feat_t = teacher(teacher_input, requires_feat=True) 935 | 936 | if args.economic: 937 | torch.cuda.empty_cache() 938 | 939 | # student forward 940 | for i in range(setting_num): 941 | 942 | if setting_num != 1: # more than 1 model 943 | torch.cuda.empty_cache() 944 | 945 | with amp_autocast(): 946 | if student_resizer is not None: 947 | student_input = student_resizer(input) 948 | else: 949 | student_input = input 950 | 951 | output, feat = model_list[i](student_input, requires_feat=True) 952 | 953 | loss_ori = ori_loss_weight_list[i] * loss_fn_list[i](output, targets[i]) 954 | 955 | try: 956 | kd_loss_fn = model_list[i].module.kd_loss_fn 957 | except AttributeError: 958 | kd_loss_fn = model_list[i].kd_loss_fn 959 | 960 | loss_kd = kd_loss_weight_list[i] * kd_loss_fn(z_s=output, z_t=output_t.detach(), 961 | target=targets[i], 962 | epoch=epoch, 963 | feature_student=process_feat(kd_loss_fn, feat), 964 | feature_teacher=process_feat(kd_loss_fn, feat_t)) 965 | loss = loss_ori + loss_kd 966 | 967 | if not args.distributed: 968 | losses_m_list[i].update(loss.item(), input.size(0)) 969 | losses_ori_m_list[i].update(loss_ori.item(), input.size(0)) 970 | losses_kd_m_list[i].update(loss_kd.item(), input.size(0)) 971 | 972 | optimizer_list[i].zero_grad() 973 | if loss_scaler is not None: 974 | loss_scaler( 975 | loss, optimizer_list[i], 976 | clip_grad=clip_grad_list[i], clip_mode=args.clip_mode, 977 | parameters=model_parameters(model_list[i], exclude_head='agc' in args.clip_mode), 978 | create_graph=second_order_list[i]) 979 | else: 980 | loss.backward(create_graph=second_order_list[i]) 981 | if clip_grad_list[i] is not None: 982 | dispatch_clip_grad( 983 | model_parameters(model_list[i], exclude_head='agc' in args.clip_mode), 984 | value=clip_grad_list[i], mode=args.clip_mode) 985 | optimizer_list[i].step() 986 | 987 | if model_ema_list[i][0] is not None: 988 | for ema, _ in model_ema_list[i]: 989 | ema.update(model_list[i]) 990 | 991 | torch.cuda.synchronize() 992 | batch_time_m.update(time.time() - end) 993 | if last_batch or batch_idx % args.log_interval == 0: 994 | lrl = [param_group['lr'] for param_group 
in optimizer_list[i].param_groups] 995 | lr = sum(lrl) / len(lrl) 996 | 997 | if args.distributed: 998 | reduced_loss = reduce_tensor(loss.data, args.world_size) 999 | reduced_loss_ori = reduce_tensor(loss_ori.data, args.world_size) 1000 | reduced_loss_kd = reduce_tensor(loss_kd.data, args.world_size) 1001 | losses_m_list[i].update(reduced_loss.item(), input.size(0)) 1002 | losses_ori_m_list[i].update(reduced_loss_ori.item(), input.size(0)) 1003 | losses_kd_m_list[i].update(reduced_loss_kd.item(), input.size(0)) 1004 | 1005 | if args.local_rank == 0: 1006 | logger_list[i].info( 1007 | 'Train: {} [{:>4d}/{} ({:>3.0f}%)] ' 1008 | 'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) ' 1009 | 'Loss_ori: {loss_ori.val:#.4g} ({loss_ori.avg:#.3g}) ' 1010 | 'Loss_kd: {loss_kd.val:#.4g} ({loss_kd.avg:#.3g}) ' 1011 | 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ' 1012 | '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 1013 | 'LR: {lr:.3e} ' 1014 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( 1015 | epoch, 1016 | batch_idx, len(loader), 1017 | 100. * batch_idx / last_idx, 1018 | loss=losses_m_list[i], 1019 | loss_ori=losses_ori_m_list[i], 1020 | loss_kd=losses_kd_m_list[i], 1021 | batch_time=batch_time_m, 1022 | rate=input.size(0) * args.world_size / batch_time_m.val, 1023 | rate_avg=input.size(0) * args.world_size / batch_time_m.avg, 1024 | lr=lr, 1025 | data_time=data_time_m)) 1026 | 1027 | if setting_num != 1: # more than 1 model 1028 | torch.cuda.empty_cache() 1029 | 1030 | num_updates += 1 1031 | for i in range(setting_num): 1032 | if lr_scheduler_list[i] is not None: 1033 | lr_scheduler_list[i].step_update(num_updates=num_updates, metric=losses_m_list[i].avg) 1034 | 1035 | end = time.time() 1036 | # end for 1037 | 1038 | for i in range(setting_num): 1039 | if hasattr(optimizer_list[i], 'sync_lookahead'): 1040 | optimizer_list[i].sync_lookahead() 1041 | 1042 | return [OrderedDict([('loss', losses_m.avg)]) for losses_m in losses_m_list] 1043 | 1044 | 1045 | def validate(model, loader, loss_fn, args, logger, amp_autocast=suppress, log_suffix=''): 1046 | batch_time_m = AverageMeter() 1047 | losses_m = AverageMeter() 1048 | top1_m = AverageMeter() 1049 | top5_m = AverageMeter() 1050 | 1051 | model.eval() 1052 | 1053 | end = time.time() 1054 | last_idx = len(loader) - 1 1055 | with torch.no_grad(): 1056 | for batch_idx, (input, target) in enumerate(loader): 1057 | last_batch = batch_idx == last_idx 1058 | if not args.prefetcher: 1059 | input = input.cuda() 1060 | target = target.cuda() 1061 | if args.channels_last: 1062 | input = input.contiguous(memory_format=torch.channels_last) 1063 | 1064 | with amp_autocast(): 1065 | output = model(input) 1066 | if isinstance(output, (tuple, list)): 1067 | output = output[0] 1068 | 1069 | # augmentation reduction 1070 | reduce_factor = args.tta 1071 | if reduce_factor > 1: 1072 | output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) 1073 | target = target[0:target.size(0):reduce_factor] 1074 | 1075 | loss = loss_fn(output, target) 1076 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 1077 | 1078 | if args.distributed: 1079 | reduced_loss = reduce_tensor(loss.data, args.world_size) 1080 | acc1 = reduce_tensor(acc1, args.world_size) 1081 | acc5 = reduce_tensor(acc5, args.world_size) 1082 | else: 1083 | reduced_loss = loss.data 1084 | 1085 | torch.cuda.synchronize() 1086 | 1087 | losses_m.update(reduced_loss.item(), input.size(0)) 1088 | top1_m.update(acc1.item(), output.size(0)) 1089 | top5_m.update(acc5.item(), output.size(0)) 1090 | 1091 | 
batch_time_m.update(time.time() - end)
1092 |             end = time.time()
1093 |             if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
1094 |                 log_name = 'Test' + log_suffix
1095 |                 logger.info(
1096 |                     '{0}: [{1:>4d}/{2}] '
1097 |                     'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
1098 |                     'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
1099 |                     'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
1100 |                     'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
1101 |                         log_name, batch_idx, last_idx, batch_time=batch_time_m,
1102 |                         loss=losses_m, top1=top1_m, top5=top5_m))
1103 | 
1104 |     metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
1105 | 
1106 |     return metrics
1107 | 
1108 | 
1109 | if __name__ == '__main__':
1110 |     main()
1111 | 
--------------------------------------------------------------------------------
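
A note on how the training script above fits together: for every student setting, `train_one_epoch` optimizes a weighted sum of the ordinary classification loss and a distillation loss, i.e. `loss = ori_loss_weight * loss_fn(output, target) + kd_loss_weight * kd_loss_fn(z_s=output, z_t=output_t.detach(), ...)`. The sketch below reproduces that objective in isolation, with a generic temperature-scaled KL term standing in for the losses in `losses/`; `KLDivSketch` and its temperature are illustrative assumptions, not this repo's API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class KLDivSketch(nn.Module):
    """Hypothetical stand-in for a logit-distillation loss (cf. losses/kd.py)."""
    def __init__(self, temperature=4.0):
        super().__init__()
        self.t = temperature

    def forward(self, z_s, z_t, **kwargs):
        # soften both distributions; scale by T^2 so gradient magnitude
        # stays comparable across temperatures
        log_p_s = F.log_softmax(z_s / self.t, dim=1)
        p_t = F.softmax(z_t / self.t, dim=1)
        return F.kl_div(log_p_s, p_t, reduction='batchmean') * self.t ** 2

# one student step, mirroring the weighted objective in train_one_epoch
z_s = torch.randn(8, 1000, requires_grad=True)  # student logits
z_t = torch.randn(8, 1000)                      # teacher logits, no grad needed
target = torch.randint(0, 1000, (8,))
ori_loss_weight, kd_loss_weight = 1.0, 1.0
loss = (ori_loss_weight * F.cross_entropy(z_s, target)
        + kd_loss_weight * KLDivSketch()(z_s, z_t.detach()))
loss.backward()
```
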
/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import time
4 | from datetime import datetime
5 | 
6 | import numpy as np
7 | 
8 | from timm.data import ImageDataset
9 | from timm.data.mixup import Mixup, mixup_target
10 | from timm.utils import CheckpointSaver, unwrap_model
11 | 
12 | 
13 | class ImageNetInstanceSample(ImageDataset):
14 |     """Folder dataset which returns (img, label, index) and, if is_sample, (img, label, index, contrast_index).
15 |     """
16 | 
17 |     def __init__(self, root, name, class_map, load_bytes, is_sample=False, k=4096, **kwargs):
18 |         super().__init__(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
19 |         self.k = k
20 |         self.is_sample = is_sample
21 |         if self.is_sample:
22 |             print('preparing contrastive data...')
23 |             num_classes = 1000
24 |             num_samples = len(self.parser)
25 |             label = np.zeros(num_samples, dtype=np.int32)
26 |             for i in range(num_samples):
27 |                 _, target = self.parser[i]
28 |                 label[i] = target
29 | 
30 |             self.cls_positive = [[] for _ in range(num_classes)]
31 |             for i in range(num_samples):
32 |                 self.cls_positive[label[i]].append(i)
33 | 
34 |             self.cls_negative = [[] for _ in range(num_classes)]
35 |             for i in range(num_classes):
36 |                 for j in range(num_classes):
37 |                     if j == i:
38 |                         continue
39 |                     self.cls_negative[i].extend(self.cls_positive[j])
40 | 
41 |             self.cls_positive = [np.asarray(self.cls_positive[i], dtype=np.int32) for i in range(num_classes)]
42 |             self.cls_negative = [np.asarray(self.cls_negative[i], dtype=np.int32) for i in range(num_classes)]
43 |             print('done.')
44 | 
45 |     def __getitem__(self, index):
46 |         """
47 |         Args:
48 |             index (int): Index
49 |         Returns:
50 |             tuple: (image, target, index) or, if is_sample, (image, target, index, sample_idx).
51 |         """
52 |         img, target = super().__getitem__(index)
53 | 
54 |         if self.is_sample:
55 |             # sample contrastive examples
56 |             pos_idx = index
57 |             neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=True)
58 |             sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
59 |             return img, target, index, sample_idx
60 |         else:
61 |             return img, target, index
62 | 
63 | 
64 | class MultiSmoothingMixup(Mixup):
65 |     def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
66 |                  mode='batch', correct_lam=True, smoothings=(0.1,), num_classes=1000):
67 |         super(MultiSmoothingMixup, self).__init__(mixup_alpha, cutmix_alpha, cutmix_minmax, prob, switch_prob,
68 |                                                   mode, correct_lam, 0, num_classes)
69 |         self.smoothings = smoothings
70 | 
71 |     def __call__(self, x, target):
72 |         assert len(x) % 2 == 0, 'Batch size should be even when using this'
73 |         if self.mode == 'elem':
74 |             lam = self._mix_elem(x)
75 |         elif self.mode == 'pair':
76 |             lam = self._mix_pair(x)
77 |         else:
78 |             lam = self._mix_batch(x)
79 |         targets = []
80 |         for smoothing in self.smoothings:
81 |             targets.append(mixup_target(target, self.num_classes, lam, smoothing, x.device))
82 |         return x, targets
83 | 
84 | 
85 | class CheckpointSaverWithLogger(CheckpointSaver):
86 |     def __init__(
87 |             self,
88 |             logger,
89 |             model,
90 |             optimizer,
91 |             args=None,
92 |             model_ema=None,
93 |             amp_scaler=None,
94 |             checkpoint_prefix='checkpoint',
95 |             recovery_prefix='recovery',
96 |             checkpoint_dir='',
97 |             recovery_dir='',
98 |             decreasing=False,
99 |             max_history=10,
100 |             unwrap_fn=unwrap_model):
101 |         super(CheckpointSaverWithLogger, self).__init__(model, optimizer, args, model_ema, amp_scaler,
102 |                                                         checkpoint_prefix, recovery_prefix, checkpoint_dir,
103 |                                                         recovery_dir, decreasing, max_history, unwrap_fn)
104 |         self.logger = logger
105 | 
106 |     def save_checkpoint(self, epoch, metric=None):
107 |         assert epoch >= 0
108 |         tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
109 |         last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
110 |         self._save(tmp_save_path, epoch, metric)
111 |         if os.path.exists(last_save_path):
112 |             os.unlink(last_save_path)  # required for Windows support.
113 | os.rename(tmp_save_path, last_save_path) 114 | worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None 115 | if (len(self.checkpoint_files) < self.max_history 116 | or metric is None or self.cmp(metric, worst_file[1])): 117 | if len(self.checkpoint_files) >= self.max_history: 118 | self._cleanup_checkpoints(1) 119 | filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension 120 | save_path = os.path.join(self.checkpoint_dir, filename) 121 | os.link(last_save_path, save_path) 122 | self.checkpoint_files.append((save_path, metric)) 123 | self.checkpoint_files = sorted( 124 | self.checkpoint_files, key=lambda x: x[1], 125 | reverse=not self.decreasing) # sort in descending order if a lower metric is not better 126 | 127 | checkpoints_str = "Current checkpoints:\n" 128 | for c in self.checkpoint_files: 129 | checkpoints_str += ' {}\n'.format(c) 130 | self.logger.info(checkpoints_str) 131 | 132 | if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)): 133 | self.best_epoch = epoch 134 | self.best_metric = metric 135 | best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension) 136 | if os.path.exists(best_save_path): 137 | os.unlink(best_save_path) 138 | os.link(last_save_path, best_save_path) 139 | 140 | return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch) 141 | 142 | def _cleanup_checkpoints(self, trim=0): 143 | trim = min(len(self.checkpoint_files), trim) 144 | delete_index = self.max_history - trim 145 | if delete_index < 0 or len(self.checkpoint_files) <= delete_index: 146 | return 147 | to_delete = self.checkpoint_files[delete_index:] 148 | for d in to_delete: 149 | try: 150 | self.logger.debug("Cleaning checkpoint: {}".format(d)) 151 | os.remove(d[0]) 152 | except Exception as e: 153 | self.logger.error("Exception '{}' while deleting checkpoint".format(e)) 154 | self.checkpoint_files = self.checkpoint_files[:delete_index] 155 | 156 | def save_recovery(self, epoch, batch_idx=0): 157 | assert epoch >= 0 158 | filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension 159 | save_path = os.path.join(self.recovery_dir, filename) 160 | self._save(save_path, epoch) 161 | if os.path.exists(self.last_recovery_file): 162 | try: 163 | self.logger.debug("Cleaning recovery: {}".format(self.last_recovery_file)) 164 | os.remove(self.last_recovery_file) 165 | except Exception as e: 166 | self.logger.error("Exception '{}' while removing {}".format(e, self.last_recovery_file)) 167 | self.last_recovery_file = self.curr_recovery_file 168 | self.curr_recovery_file = save_path 169 | 170 | 171 | def setup_default_logging(logger, default_level=logging.INFO, log_path=''): 172 | console_handler = logging.StreamHandler() 173 | console_formatter = logging.Formatter("%(name)15s: %(message)s") 174 | console_handler.setFormatter(console_formatter) 175 | # console_handler.setFormatter(FormatterNoInfo()) 176 | logger.addHandler(console_handler) 177 | logger.setLevel(default_level) 178 | if log_path: 179 | file_handler = logging.FileHandler(log_path) 180 | file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s") 181 | file_handler.setFormatter(file_formatter) 182 | logger.addHandler(file_handler) 183 | 184 | 185 | class TimePredictor: 186 | def __init__(self, steps, most_recent=30, drop_first=True): 187 | self.init_time = time.time() 188 | self.steps = steps 189 | self.most_recent = most_recent 190 | self.drop_first = drop_first 
# drop iter 0
191 | 
192 |         self.time_list = []
193 |         self.temp_time = self.init_time
194 | 
195 |     def update(self):
196 |         time_interval = time.time() - self.temp_time
197 |         self.time_list.append(time_interval)
198 | 
199 |         if self.drop_first and len(self.time_list) > 1:
200 |             self.time_list = self.time_list[1:]
201 |             self.drop_first = False
202 | 
203 |         self.time_list = self.time_list[-self.most_recent:]
204 |         self.temp_time = time.time()
205 | 
206 |     def get_pred_text(self):
207 |         single_step_time = np.mean(self.time_list)
208 |         end_timestamp = self.init_time + single_step_time * self.steps
209 |         return datetime.fromtimestamp(end_timestamp).strftime('%Y-%m-%d %H:%M:%S')
210 | 
211 | 
212 | def process_feat(distiller, source_feat):
213 |     if getattr(distiller, 'pre_act_feat', False):
214 |         feat = source_feat[0]
215 |     else:
216 |         feat = source_feat[1]
217 |     return feat
218 | 
--------------------------------------------------------------------------------
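
Before moving on to validate.py, a small usage sketch for `TimePredictor` above: call `update()` once per step and it extrapolates the finish time from the mean duration of the most recent steps, dropping the first (typically slower, warmup) step. The loop body here is a hypothetical stand-in for a training epoch, not part of this repo.

```python
import time
from utils import TimePredictor

tp = TimePredictor(steps=100)  # total number of epochs to run
for epoch in range(100):
    time.sleep(0.01)           # stand-in for one training epoch
    tp.update()
    print(f'Will finish at {tp.get_pred_text()}')
```
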
/validate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """ ImageNet Validation Script
3 | 
4 | This is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained
5 | models or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes
6 | canonical PyTorch, standard Python style, and good performance. Repurpose as you see fit.
7 | 
8 | Hacked together by Ross Wightman (https://github.com/rwightman)
9 | """
10 | import argparse
11 | import os
12 | import csv
13 | import glob
14 | import json
15 | import sys
16 | import time
17 | import logging
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.parallel
21 | from collections import OrderedDict
22 | from contextlib import suppress
23 | 
24 | if os.path.exists('./timm'):
25 |     sys.path.insert(0, './')
26 | 
27 | from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
28 | from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
29 | from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
30 | import models
31 | 
32 | has_apex = False
33 | try:
34 |     from apex import amp
35 |     has_apex = True
36 | except ImportError:
37 |     pass
38 | 
39 | has_native_amp = False
40 | try:
41 |     if getattr(torch.cuda.amp, 'autocast') is not None:
42 |         has_native_amp = True
43 | except AttributeError:
44 |     pass
45 | 
46 | torch.backends.cudnn.benchmark = True
47 | _logger = logging.getLogger('validate')
48 | 
49 | 
50 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
51 | parser.add_argument('data', metavar='DIR',
52 |                     help='path to dataset')
53 | parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
54 |                     help='model architecture (default: dpn92)')
55 | parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
56 |                     help='path to latest checkpoint (default: none)')
57 | parser.add_argument('--use-ema', dest='use_ema', action='store_true',
58 |                     help='use ema version of weights if present')
59 | 
60 | 
61 | parser.add_argument('--dataset', '-d', metavar='NAME', default='',
62 |                     help='dataset type (default: ImageFolder/ImageTar if empty)')
63 | parser.add_argument('--split', metavar='NAME', default='validation',
64 |                     help='dataset split (default: validation)')
65 | parser.add_argument('--dataset-download', action='store_true', default=False,
66 |                     help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
67 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
68 |                     help='number of data loading workers (default: 4)')
69 | parser.add_argument('-b', '--batch-size', default=256, type=int,
70 |                     metavar='N', help='mini-batch size (default: 256)')
71 | parser.add_argument('--img-size', default=None, type=int,
72 |                     metavar='N', help='Input image dimension, uses model default if empty')
73 | parser.add_argument('--input-size', default=None, nargs=3, type=int,
74 |                     metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
75 | parser.add_argument('--crop-pct', default=None, type=float,
76 |                     metavar='N', help='Input image center crop pct')
77 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
78 |                     help='Override mean pixel value of dataset')
79 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
80 |                     help='Override std deviation of dataset')
81 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
82 |                     help='Image resize interpolation type (overrides model)')
83 | parser.add_argument('--num-classes', type=int, default=1000,
84 |                     help='Number of classes in dataset')
85 | parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
86 |                     help='path to class to idx mapping file (default: "")')
87 | parser.add_argument('--gp', default=None, type=str, metavar='POOL',
88 |                     help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
89 | parser.add_argument('--log-freq', default=50, type=int,
90 |                     metavar='N', help='batch logging frequency (default: 50)')
91 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
92 |                     help='use pre-trained model')
93 | parser.add_argument('--num-gpu', type=int, default=1,
94 |                     help='Number of GPUs to use')
95 | parser.add_argument('--test-pool', dest='test_pool', action='store_true',
96 |                     help='enable test time pool')
97 | parser.add_argument('--no-prefetcher', action='store_true', default=False,
98 |                     help='disable fast prefetcher')
99 | parser.add_argument('--pin-mem', action='store_true', default=False,
100 |                     help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
101 | parser.add_argument('--channels-last', action='store_true', default=False,
102 |                     help='Use channels_last memory layout')
103 | parser.add_argument('--amp', action='store_true', default=False,
104 |                     help='Use AMP mixed precision. Defaults to native Torch AMP, fallback to Apex.')
105 | parser.add_argument('--apex-amp', action='store_true', default=False,
106 |                     help='Use NVIDIA Apex AMP mixed precision')
107 | parser.add_argument('--native-amp', action='store_true', default=False,
108 |                     help='Use Native Torch AMP mixed precision')
109 | parser.add_argument('--tf-preprocessing', action='store_true', default=False,
110 |                     help='Use Tensorflow preprocessing pipeline (requires CPU TF installed)')
111 | parser.add_argument('--torchscript', dest='torchscript', action='store_true',
112 |                     help='convert model to torchscript for inference')
113 | parser.add_argument('--fuser', default='', type=str,
114 |                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
115 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
116 |                     help='Output csv file for validation results (summary)')
117 | parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
118 |                     help='Real labels JSON file for imagenet evaluation')
119 | parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
120 |                     help='Valid label indices txt file for validation of partial label space')
121 | 
122 | 
123 | def validate(args):
124 |     # might as well try to validate something
125 |     args.pretrained = args.pretrained or not args.checkpoint
126 |     args.prefetcher = not args.no_prefetcher
127 |     amp_autocast = suppress  # do nothing
128 |     if args.amp:
129 |         if has_native_amp:
130 |             args.native_amp = True
131 |         elif has_apex:
132 |             args.apex_amp = True
133 |         else:
134 |             _logger.warning("Neither APEX nor native Torch AMP is available.")
135 |     assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
136 |     if args.native_amp:
137 |         amp_autocast = torch.cuda.amp.autocast
138 |         _logger.info('Validating in mixed precision with native PyTorch AMP.')
139 |     elif args.apex_amp:
140 |         _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
141 |     else:
142 |         _logger.info('Validating in float32. AMP not enabled.')
143 | 
144 |     # create model
145 |     model = create_model(
146 |         args.model,
147 |         pretrained=args.pretrained,
148 |         num_classes=args.num_classes,
149 |         in_chans=3,
150 |         global_pool=args.gp,
151 |         scriptable=args.torchscript)
152 |     if args.num_classes is None:
153 |         assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
154 |         args.num_classes = model.num_classes
155 | 
156 |     if args.checkpoint:
157 |         incompatible_keys = load_checkpoint(model, args.checkpoint, args.use_ema, strict=False)
158 |         _logger.info(f'incompatible_keys: {incompatible_keys}')
159 | 
160 |     param_count = sum([m.numel() for m in model.parameters()])
161 |     _logger.info('Model %s created, param count: %d' % (args.model, param_count))
162 | 
163 |     data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
164 |     test_time_pool = False
165 |     if args.test_pool:
166 |         model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)
167 | 
168 |     if args.torchscript:
169 |         torch.jit.optimized_execution(True)
170 |         model = torch.jit.script(model)
171 | 
172 |     model = model.cuda()
173 |     if args.apex_amp:
174 |         model = amp.initialize(model, opt_level='O1')
175 | 
176 |     if args.channels_last:
177 |         model = model.to(memory_format=torch.channels_last)
178 | 
179 |     if args.num_gpu > 1:
180 |         model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
181 | 
182 |     criterion = nn.CrossEntropyLoss().cuda()
183 | 
184 |     dataset = create_dataset(
185 |         root=args.data, name=args.dataset, split=args.split,
186 |         download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map)
187 | 
188 |     if args.valid_labels:
189 |         with open(args.valid_labels, 'r') as f:
190 |             valid_labels = {int(line.rstrip()) for line in f}
191 |             valid_labels = [i in valid_labels for i in range(args.num_classes)]
192 |     else:
193 |         valid_labels = None
194 | 
195 |     if args.real_labels:
196 |         real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
197 |     else:
198 |         real_labels = None
199 | 
200 |     crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
201 |     loader = create_loader(
202 |         dataset,
203 |         input_size=data_config['input_size'],
204 |         batch_size=args.batch_size,
205 |         use_prefetcher=args.prefetcher,
206 |         interpolation=data_config['interpolation'],
207 |         mean=data_config['mean'],
208 |         std=data_config['std'],
209 |         num_workers=args.workers,
210 |         crop_pct=crop_pct,
211 |         pin_memory=args.pin_mem,
212 |         tf_preprocessing=args.tf_preprocessing)
213 | 
214 |     batch_time = AverageMeter()
215 |     losses = AverageMeter()
216 |     top1 = AverageMeter()
217 |     top5 = AverageMeter()
218 | 
219 |     model.eval()
220 |     with torch.no_grad():
221 |         # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
222 |         input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
223 |         if args.channels_last:
224 |             input = input.contiguous(memory_format=torch.channels_last)
225 |         with amp_autocast():
226 |             model(input)
227 | 
228 |         end = time.time()
229 |         for batch_idx, (input, target) in enumerate(loader):
230 |             if args.no_prefetcher:
231 |                 target = target.cuda()
232 |                 input = input.cuda()
233 |             if args.channels_last:
234 |                 input = input.contiguous(memory_format=torch.channels_last)
235 | 
236 |             # compute output
237 |             with amp_autocast():
238 |                 output = model(input)
239 | 
240 |                 if valid_labels is not None:
241 |                     output = output[:, valid_labels]
242 |                 loss = criterion(output, target)
243 | 
244 |             if real_labels is not None:
245 |                 real_labels.add_result(output)
246 | 
247 |             # measure accuracy and record loss
248 |             acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
249 |             losses.update(loss.item(), input.size(0))
250 |             top1.update(acc1.item(), input.size(0))
251 |             top5.update(acc5.item(), input.size(0))
252 | 
253 |             # measure elapsed time
254 |             batch_time.update(time.time() - end)
255 |             end = time.time()
256 | 
257 |             if batch_idx % args.log_freq == 0:
258 |                 _logger.info(
259 |                     'Test: [{0:>4d}/{1}] '
260 |                     'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
261 |                     'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
262 |                     'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
263 |                     'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
264 |                         batch_idx, len(loader), batch_time=batch_time,
265 |                         rate_avg=input.size(0) / batch_time.avg,
266 |                         loss=losses, top1=top1, top5=top5))
267 | 
268 |     if real_labels is not None:
269 |         # real labels mode replaces topk values at the end
270 |         top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
271 |     else:
272 |         top1a, top5a = top1.avg, top5.avg
273 |     results = OrderedDict(
274 |         model=args.model,
275 |         top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
276 |         top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
277 |         param_count=round(param_count / 1e6, 2),
278 |         img_size=data_config['input_size'][-1],
279 |         crop_pct=crop_pct,
280 |         interpolation=data_config['interpolation'])
281 | 
282 |     _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
283 |         results['top1'], results['top1_err'], results['top5'], results['top5_err']))
284 | 
285 |     return results
286 | 
287 | 
288 | def _try_run(args, initial_batch_size):
289 |     batch_size = initial_batch_size
290 |     results = OrderedDict()
291 |     error_str = 'Unknown'
292 |     while batch_size >= 1:
293 |         args.batch_size = batch_size
294 |         torch.cuda.empty_cache()
295 |         try:
296 |             results = validate(args)
297 |             return results
298 |         except RuntimeError as e:
299 |             error_str = str(e)
300 |             if 'channels_last' in error_str:
301 |                 break
302 |             batch_size = batch_size // 2
303 |             _logger.warning(f'"{error_str}" while running validation. Reducing batch size to {batch_size} for retry.')
304 |     results['error'] = error_str
305 |     _logger.error(f'{args.model} failed to validate ({error_str}).')
306 |     return results
307 | 
308 | 
309 | def main():
310 |     setup_default_logging()
311 |     args = parser.parse_args()
312 |     model_cfgs = []
313 |     model_names = []
314 |     if os.path.isdir(args.checkpoint):
315 |         # validate all checkpoints in a path with same model
316 |         checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
317 |         checkpoints += glob.glob(args.checkpoint + '/*.pth')
318 |         model_names = list_models(args.model)
319 |         model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
320 |     else:
321 |         if args.model == 'all':
322 |             # validate all models in a list of names with pretrained checkpoints
323 |             args.pretrained = True
324 |             model_names = list_models(pretrained=True, exclude_filters=['*_in21k', '*_in22k', '*_dino'])
325 |             model_cfgs = [(n, '') for n in model_names]
326 |         elif not is_model(args.model):
327 |             # model name doesn't exist, try as wildcard filter
328 |             model_names = list_models(args.model)
329 |             model_cfgs = [(n, '') for n in model_names]
330 | 
331 |         if not model_cfgs and os.path.isfile(args.model):
332 |             with open(args.model) as f:
333 |                 model_names = [line.rstrip() for line in f]
334 |             model_cfgs = [(n, None) for n in model_names if n]
335 | 
336 |     if len(model_cfgs):
337 |         results_file = args.results_file or './results-all.csv'
338 |         _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
339 |         results = []
340 |         try:
341 |             initial_batch_size = args.batch_size
342 |             for m, c in model_cfgs:
343 |                 args.model = m
344 |                 args.checkpoint = c
345 |                 r = _try_run(args, initial_batch_size)
346 |                 if 'error' in r:
347 |                     continue
348 |                 if args.checkpoint:
349 |                     r['checkpoint'] = args.checkpoint
350 |                 results.append(r)
351 |         except KeyboardInterrupt:
352 |             pass
353 |         results = sorted(results, key=lambda x: x['top1'], reverse=True)
354 |         if len(results):
355 |             write_results(results_file, results)
356 |     else:
357 |         results = validate(args)
358 |         # output results in JSON to stdout w/ delimiter for runner script
359 |         print(f'--result\n{json.dumps(results, indent=4)}')
360 | 
361 | 
362 | def write_results(results_file, results):
363 |     with open(results_file, mode='w') as cf:
364 |         dw = csv.DictWriter(cf, fieldnames=results[0].keys())
365 |         dw.writeheader()
366 |         for r in results:
367 |             dw.writerow(r)
368 |         cf.flush()
369 | 
370 | 
371 | if __name__ == '__main__':
372 |     main()
373 | 
--------------------------------------------------------------------------------
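
Closing note: validate.py can also be driven programmatically rather than from the shell. In this sketch the dataset root and model name are placeholders, not values from this repo; `_try_run` halves the batch size and retries whenever `validate` raises a `RuntimeError` (typically CUDA out-of-memory), as implemented above.

```python
from validate import parser, _try_run

# hypothetical invocation; '/path/to/imagenet' and 'resnet50' are placeholders
args = parser.parse_args([
    '/path/to/imagenet',
    '--model', 'resnet50',
    '--pretrained',
    '--batch-size', '256',
    '--amp',
])
results = _try_run(args, initial_batch_size=args.batch_size)
print(results.get('top1'), results.get('top5'))
```
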