├── fig
│   └── data_scale_bar.png
├── losses
│   ├── __init__.py
│   ├── bkd.py
│   ├── correlation.py
│   ├── crd.py
│   ├── dist.py
│   ├── dkd.py
│   ├── kd.py
│   ├── review.py
│   └── rkd.py
├── models
│   ├── __init__.py
│   ├── beit.py
│   └── convnextv2.py
├── object_detection
│   ├── README.md
│   ├── configs
│   │   ├── convnextv2
│   │   │   ├── cascade_mask_rcnn_convnextv2_3x_coco.py
│   │   │   └── mask_rcnn_convnextv2_3x_coco.py
│   │   └── resnet
│   │       └── mask_rcnn_r50_1x_coco.py
│   └── mmdet
│       └── models
│           └── backbones
│               ├── __init__.py
│               └── convnextv2.py
├── readme.md
├── register.py
├── requirements.txt
├── train-crd.py
├── train-fd.py
├── train-kd.py
├── utils.py
└── validate.py
/fig/data_scale_bar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hao840/vanillaKD/a27b029f693a13d67b85c61a52e25a85f9414e70/fig/data_scale_bar.png
--------------------------------------------------------------------------------
/losses/__init__.py:
--------------------------------------------------------------------------------
1 | '''
2 | The losses are adapted from the DKD (mdistiller) and DIST codebases:
3 | https://github.com/megvii-research/mdistiller
4 | https://github.com/hunto/DIST_KD
5 |
6 | Modifications by Zhiwei Hao (haozhw@bit.edu.cn) and Jianyuan Guo (jianyuan_guo@outlook.com)
7 | '''
8 |
9 | from .bkd import BinaryKLDiv
10 | from .correlation import Correlation
11 | from .crd import CRD
12 | from .dist import DIST
13 | from .dkd import DKD
14 | from .kd import KLDiv
15 | from .review import ReviewKD
16 | from .rkd import RKD
17 |
--------------------------------------------------------------------------------
/losses/bkd.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class BinaryKLDiv(nn.Module):
6 | def __init__(self):
7 | super(BinaryKLDiv, self).__init__()
8 |
9 | def forward(self, z_s, z_t, **kwargs):
10 | kd_loss = F.binary_cross_entropy_with_logits(z_s, z_t.softmax(1))
11 | return kd_loss
12 |
--------------------------------------------------------------------------------
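
A minimal sketch of calling `BinaryKLDiv` on dummy logits (shapes are illustrative). Despite the name, the module computes binary cross-entropy between the student's logits and the teacher's softmax output taken as per-class soft targets:

```python
import torch

from losses import BinaryKLDiv

z_s = torch.randn(4, 1000)  # student logits
z_t = torch.randn(4, 1000)  # teacher logits

criterion = BinaryKLDiv()
loss = criterion(z_s, z_t)  # BCE-with-logits against the teacher's softmax
```
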
/losses/correlation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Correlation(nn.Module):
6 | scale = 0.02
7 | def __init__(self, feat_s_channel, feat_t_channel, feat_dim=128):
8 | super(Correlation, self).__init__()
9 | self.embed_s = LinearEmbed(feat_s_channel, feat_dim)
10 | self.embed_t = LinearEmbed(feat_t_channel, feat_dim)
11 |
12 | def forward(self, z_s, z_t, **kwargs):
13 | f_s = self.embed_s(kwargs['feature_student'][-1])
14 | f_t = self.embed_t(kwargs['feature_teacher'][-1])
15 |
16 | delta = torch.abs(f_s - f_t)
17 | kd_loss = self.scale * torch.mean((delta[:-1] * delta[1:]).sum(1))
18 | return kd_loss
19 |
20 |
21 | class LinearEmbed(nn.Module):
22 | def __init__(self, dim_in=1024, dim_out=128):
23 | super(LinearEmbed, self).__init__()
24 | self.linear = nn.Linear(dim_in, dim_out)
25 |
26 | def forward(self, x):
27 | x = x.view(x.shape[0], -1)
28 | x = self.linear(x)
29 | return x
30 |
--------------------------------------------------------------------------------
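
`Correlation` reads only the last entry of the `feature_student` / `feature_teacher` kwargs (the pre-logit features) and correlates the embedded student-teacher differences of adjacent samples in the batch. A sketch with hypothetical feature widths:

```python
import torch

from losses import Correlation

# Hypothetical pre-logit widths for a student/teacher pair.
criterion = Correlation(feat_s_channel=512, feat_t_channel=768, feat_dim=128)

f_s = torch.randn(8, 512)  # student pre-logit features
f_t = torch.randn(8, 768)  # teacher pre-logit features

# Logits are ignored by this loss, so None is passed for both.
loss = criterion(None, None, feature_student=[f_s], feature_teacher=[f_t])
```
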
/losses/crd.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import nn
5 |
6 |
7 | class CRD(nn.Module):
8 | def __init__(self, feat_s_channel, feat_t_channel, feat_dim, num_data, k=16384, momentum=0.5, temperature=0.07):
9 | super(CRD, self).__init__()
10 | self.embed_s = Embed(feat_s_channel, feat_dim)
11 | self.embed_t = Embed(feat_t_channel, feat_dim)
12 | self.contrast = ContrastMemory(feat_dim, num_data, k, temperature, momentum)
13 | self.criterion_s = ContrastLoss(num_data)
14 | self.criterion_t = ContrastLoss(num_data)
15 |
16 | def forward(self, z_s, z_t, **kwargs):
17 | f_s = self.embed_s(kwargs['feature_student'][-1])
18 | f_t = self.embed_t(kwargs['feature_teacher'][-1])
19 | out_s, out_t = self.contrast(f_s, f_t, kwargs['index'], kwargs['contrastive_index'])
20 | s_loss = self.criterion_s(out_s)
21 | t_loss = self.criterion_t(out_t)
22 | kd_loss = s_loss + t_loss
23 | return kd_loss
24 |
25 |     def get_learnable_parameters(self):
26 |         # nn.Module defines no get_learnable_parameters() to chain to; the
27 |         # embedding layers are the only learnable parameters this loss adds.
28 |         return (
29 |             list(self.embed_s.parameters())
30 |             + list(self.embed_t.parameters()))
31 |
32 | def get_extra_parameters(self):
33 | params = (
34 | list(self.embed_s.parameters())
35 | + list(self.embed_t.parameters())
36 | + list(self.contrast.buffers())
37 | )
38 | num_p = 0
39 | for p in params:
40 | num_p += p.numel()
41 | return num_p
42 |
43 |
44 | class Normalize(nn.Module):
45 | """normalization layer"""
46 |
47 | def __init__(self, power=2):
48 | super(Normalize, self).__init__()
49 | self.power = power
50 |
51 | def forward(self, x):
52 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1.0 / self.power)
53 | out = x.div(norm)
54 | return out
55 |
56 |
57 | class Embed(nn.Module):
58 | """Embedding module"""
59 |
60 | def __init__(self, dim_in=1024, dim_out=128):
61 | super(Embed, self).__init__()
62 | self.linear = nn.Linear(dim_in, dim_out)
63 | self.l2norm = Normalize(2)
64 |
65 | def forward(self, x):
66 | x = x.reshape(x.shape[0], -1)
67 | x = self.linear(x)
68 | x = self.l2norm(x)
69 | return x
70 |
71 |
72 | class ContrastLoss(nn.Module):
73 | """contrastive loss"""
74 |
75 | def __init__(self, num_data):
76 | super(ContrastLoss, self).__init__()
77 | self.num_data = num_data
78 |
79 | def forward(self, x):
80 | eps = 1e-7
81 | bsz = x.shape[0]
82 | m = x.size(1) - 1
83 |
84 | # noise distribution
85 | Pn = 1 / float(self.num_data)
86 |
87 | # loss for positive pair
88 | P_pos = x.select(1, 0)
89 | log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_()
90 |
91 | # loss for K negative pair
92 | P_neg = x.narrow(1, 1, m)
93 | log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_()
94 |
95 | loss = -(log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bsz
96 |
97 | return loss
98 |
99 |
100 | class ContrastMemory(nn.Module):
101 |     """Memory buffer that supplies a large number of negative samples."""
102 |
103 | def __init__(self, inputSize, output_size, K, T=0.07, momentum=0.5):
104 | super(ContrastMemory, self).__init__()
105 | self.n_lem = output_size
106 | self.unigrams = torch.ones(self.n_lem)
107 | self.multinomial = AliasMethod(self.unigrams)
108 | self.K = K
109 |
110 | self.register_buffer("params", torch.tensor([K, T, -1, -1, momentum]))
111 | stdv = 1.0 / math.sqrt(inputSize / 3)
112 | self.register_buffer(
113 | "memory_v1", torch.rand(output_size, inputSize).mul_(2 * stdv).add_(-stdv)
114 | )
115 | self.register_buffer(
116 | "memory_v2", torch.rand(output_size, inputSize).mul_(2 * stdv).add_(-stdv)
117 | )
118 |
119 | def cuda(self, *args, **kwargs):
120 | super(ContrastMemory, self).cuda(*args, **kwargs)
121 | self.multinomial.cuda()
122 |
123 | def forward(self, v1, v2, y, idx=None):
124 | K = int(self.params[0].item())
125 | T = self.params[1].item()
126 | Z_v1 = self.params[2].item()
127 | Z_v2 = self.params[3].item()
128 |
129 | momentum = self.params[4].item()
130 | batchSize = v1.size(0)
131 | outputSize = self.memory_v1.size(0)
132 | inputSize = self.memory_v1.size(1)
133 |
134 | # original score computation
135 | if idx is None:
136 | idx = self.multinomial.draw(batchSize * (self.K + 1)).view(batchSize, -1)
137 | idx.select(1, 0).copy_(y.data)
138 | # sample
139 | weight_v1 = torch.index_select(self.memory_v1, 0, idx.view(-1)).detach()
140 | weight_v1 = weight_v1.view(batchSize, K + 1, inputSize)
141 | out_v2 = torch.bmm(weight_v1, v2.view(batchSize, inputSize, 1))
142 | out_v2 = torch.exp(torch.div(out_v2, T))
143 | # sample
144 | weight_v2 = torch.index_select(self.memory_v2, 0, idx.view(-1)).detach()
145 | weight_v2 = weight_v2.view(batchSize, K + 1, inputSize)
146 | out_v1 = torch.bmm(weight_v2, v1.view(batchSize, inputSize, 1))
147 | out_v1 = torch.exp(torch.div(out_v1, T))
148 |
149 | # set Z if haven't been set yet
150 | if Z_v1 < 0:
151 | self.params[2] = out_v1.mean() * outputSize
152 | Z_v1 = self.params[2].clone().detach().item()
153 | # print("normalization constant Z_v1 is set to {:.1f}".format(Z_v1))
154 | if Z_v2 < 0:
155 | self.params[3] = out_v2.mean() * outputSize
156 | Z_v2 = self.params[3].clone().detach().item()
157 | # print("normalization constant Z_v2 is set to {:.1f}".format(Z_v2))
158 |
159 | # compute out_v1, out_v2
160 | out_v1 = torch.div(out_v1, Z_v1).contiguous()
161 | out_v2 = torch.div(out_v2, Z_v2).contiguous()
162 |
163 | # update memory
164 | with torch.no_grad():
165 | l_pos = torch.index_select(self.memory_v1, 0, y.view(-1))
166 | l_pos.mul_(momentum)
167 | l_pos.add_(torch.mul(v1, 1 - momentum))
168 | l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5)
169 | updated_v1 = l_pos.div(l_norm)
170 | self.memory_v1.index_copy_(0, y, updated_v1)
171 |
172 | ab_pos = torch.index_select(self.memory_v2, 0, y.view(-1))
173 | ab_pos.mul_(momentum)
174 | ab_pos.add_(torch.mul(v2, 1 - momentum))
175 | ab_norm = ab_pos.pow(2).sum(1, keepdim=True).pow(0.5)
176 | updated_v2 = ab_pos.div(ab_norm)
177 | self.memory_v2.index_copy_(0, y, updated_v2)
178 |
179 | return out_v1, out_v2
180 |
181 |
182 | class AliasMethod(object):
183 | """
184 | From: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/
185 | """
186 |
187 | def __init__(self, probs):
188 |
189 | if probs.sum() > 1:
190 | probs.div_(probs.sum())
191 | K = len(probs)
192 | self.prob = torch.zeros(K)
193 | self.alias = torch.LongTensor([0] * K)
194 |
195 | # Sort the data into the outcomes with probabilities
196 | # that are larger and smaller than 1/K.
197 | smaller = []
198 | larger = []
199 | for kk, prob in enumerate(probs):
200 | self.prob[kk] = K * prob
201 | if self.prob[kk] < 1.0:
202 | smaller.append(kk)
203 | else:
204 | larger.append(kk)
205 |
206 |         # Loop through and create little binary mixtures that
207 | # appropriately allocate the larger outcomes over the
208 | # overall uniform mixture.
209 | while len(smaller) > 0 and len(larger) > 0:
210 | small = smaller.pop()
211 | large = larger.pop()
212 |
213 | self.alias[small] = large
214 | self.prob[large] = (self.prob[large] - 1.0) + self.prob[small]
215 |
216 | if self.prob[large] < 1.0:
217 | smaller.append(large)
218 | else:
219 | larger.append(large)
220 |
221 | for last_one in smaller + larger:
222 | self.prob[last_one] = 1
223 |
224 | def cuda(self):
225 | self.prob = self.prob.cuda()
226 | self.alias = self.alias.cuda()
227 |
228 | def draw(self, N):
229 | """Draw N samples from multinomial"""
230 | K = self.alias.size(0)
231 |
232 | kk = torch.zeros(N, dtype=torch.long, device=self.prob.device).random_(0, K)
233 | prob = self.prob.index_select(0, kk)
234 | alias = self.alias.index_select(0, kk)
235 | # b is whether a random number is greater than q
236 | b = torch.bernoulli(prob)
237 | oq = kk.mul(b.long())
238 | oj = alias.mul((1 - b).long())
239 |
240 | return oq + oj
241 |
--------------------------------------------------------------------------------
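
`CRD` is the only loss here that needs dataset-level bookkeeping: each batch must carry the dataset index of every sample (`index`) and pre-drawn negative indices whose first column is the sample itself (`contrastive_index`), matching what `ContrastMemory.forward` expects when `idx` is given. A toy-sized sketch; a real run would set `num_data` to the training-set size and keep the default `k=16384`:

```python
import torch

from losses import CRD

num_data = 10000  # toy dataset size, just for the sketch
k = 256           # negatives per anchor (default is 16384)

criterion = CRD(feat_s_channel=512, feat_t_channel=768, feat_dim=128,
                num_data=num_data, k=k)

f_s = torch.randn(8, 512)
f_t = torch.randn(8, 768)
index = torch.randint(0, num_data, (8,))                    # dataset index of each sample
contrastive_index = torch.randint(0, num_data, (8, k + 1))  # K negatives per sample ...
contrastive_index[:, 0] = index                             # ... plus the positive in column 0

loss = criterion(None, None, feature_student=[f_s], feature_teacher=[f_t],
                 index=index, contrastive_index=contrastive_index)
```
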
/losses/dist.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def cosine_similarity(a, b, eps=1e-8):
5 | return (a * b).sum(1) / (a.norm(dim=1) * b.norm(dim=1) + eps)
6 |
7 |
8 | def pearson_correlation(a, b, eps=1e-8):
9 | return cosine_similarity(a - a.mean(1).unsqueeze(1),
10 | b - b.mean(1).unsqueeze(1), eps)
11 |
12 |
13 | def inter_class_relation(y_s, y_t):
14 | return 1 - pearson_correlation(y_s, y_t).mean()
15 |
16 |
17 | def intra_class_relation(y_s, y_t):
18 | return inter_class_relation(y_s.transpose(0, 1), y_t.transpose(0, 1))
19 |
20 |
21 | class DIST(nn.Module):
22 | def __init__(self, beta=1.0, gamma=1.0, tau=1.0):
23 | super(DIST, self).__init__()
24 | self.beta = beta
25 | self.gamma = gamma
26 | self.tau = tau
27 |
28 | def forward(self, z_s, z_t, **kwargs):
29 | y_s = (z_s / self.tau).softmax(dim=1)
30 | y_t = (z_t / self.tau).softmax(dim=1)
31 | inter_loss = self.tau ** 2 * inter_class_relation(y_s, y_t)
32 | intra_loss = self.tau ** 2 * intra_class_relation(y_s, y_t)
33 | kd_loss = self.beta * inter_loss + self.gamma * intra_loss
34 | return kd_loss
35 |
--------------------------------------------------------------------------------
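
`DIST` replaces pointwise KL matching with Pearson-correlation matching: `inter_class_relation` correlates each sample's predicted distribution (rows), while `intra_class_relation` transposes and correlates each class across the batch (columns). A quick sketch:

```python
import torch

from losses import DIST

criterion = DIST(beta=1.0, gamma=1.0, tau=4.0)  # tau=4.0 is an illustrative temperature

z_s = torch.randn(8, 1000)
z_t = torch.randn(8, 1000)

print(criterion(z_s, z_t).item())  # relational mismatch between student and teacher
print(criterion(z_t, z_t).item())  # near zero: identical predictions correlate perfectly
```
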
/losses/dkd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):
7 | gt_mask = _get_gt_mask(logits_student, target)
8 | other_mask = _get_other_mask(logits_student, target)
9 | pred_student = F.softmax(logits_student / temperature, dim=1)
10 | pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
11 | pred_student = cat_mask(pred_student, gt_mask, other_mask)
12 | pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)
13 | log_pred_student = torch.log(pred_student)
14 | tckd_loss = (
15 | F.kl_div(log_pred_student, pred_teacher, reduction='batchmean')
16 | * (temperature ** 2)
17 | )
18 | pred_teacher_part2 = F.softmax(
19 | logits_teacher / temperature - 1000.0 * gt_mask, dim=1
20 | )
21 | log_pred_student_part2 = F.log_softmax(
22 | logits_student / temperature - 1000.0 * gt_mask, dim=1
23 | )
24 | nckd_loss = (
25 | F.kl_div(log_pred_student_part2, pred_teacher_part2, reduction='batchmean')
26 | * (temperature ** 2)
27 | )
28 | return alpha * tckd_loss + beta * nckd_loss
29 |
30 |
31 | def _get_gt_mask(logits, target):
32 | target = target.reshape(-1)
33 | mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()
34 | return mask
35 |
36 |
37 | def _get_other_mask(logits, target):
38 | target = target.reshape(-1)
39 | mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()
40 | return mask
41 |
42 |
43 | def cat_mask(t, mask1, mask2):
44 | t1 = (t * mask1).sum(dim=1, keepdims=True)
45 | t2 = (t * mask2).sum(1, keepdims=True)
46 | rt = torch.cat([t1, t2], dim=1)
47 | return rt
48 |
49 |
50 | class DKD(nn.Module):
51 | def __init__(self, alpha=1., beta=2., temperature=1.):
52 | super(DKD, self).__init__()
53 | self.alpha = alpha
54 | self.beta = beta
55 | self.temperature = temperature
56 |
57 | def forward(self, z_s, z_t, **kwargs):
58 | target = kwargs['target']
59 | if len(target.shape) == 2: # mixup / smoothing
60 | target = target.max(1)[1]
61 | kd_loss = dkd_loss(z_s, z_t, target, self.alpha, self.beta, self.temperature)
62 | return kd_loss
63 |
--------------------------------------------------------------------------------
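
`DKD` splits vanilla KD into a target-class term (TCKD, weighted by `alpha`) and a non-target-class term (NCKD, weighted by `beta`); the label is passed through `kwargs`. A sketch with illustrative hyperparameters:

```python
import torch

from losses import DKD

criterion = DKD(alpha=1.0, beta=2.0, temperature=4.0)

z_s = torch.randn(8, 100)
z_t = torch.randn(8, 100)
target = torch.randint(0, 100, (8,))

loss = criterion(z_s, z_t, target=target)

# 2D soft targets (e.g. from mixup or label smoothing) are reduced to their argmax class.
soft_target = torch.zeros(8, 100).scatter_(1, target.unsqueeze(1), 1.0)
loss_soft = criterion(z_s, z_t, target=soft_target)
```
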
/losses/kd.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class KLDiv(nn.Module):
6 | def __init__(self, temperature=1.0):
7 | super(KLDiv, self).__init__()
8 | self.temperature = temperature
9 |
10 | def forward(self, z_s, z_t, **kwargs):
11 | log_pred_student = F.log_softmax(z_s / self.temperature, dim=1)
12 | pred_teacher = F.softmax(z_t / self.temperature, dim=1)
13 | kd_loss = F.kl_div(log_pred_student, pred_teacher, reduction="none").sum(1).mean()
14 | kd_loss *= self.temperature ** 2
15 | return kd_loss
16 |
--------------------------------------------------------------------------------
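
`KLDiv` is classic Hinton-style logit distillation. In a training step it is typically combined with the hard-label cross-entropy; the equal weighting below is an assumption for the sketch, not the repo's training recipe:

```python
import torch
import torch.nn.functional as F

from losses import KLDiv

criterion_kd = KLDiv(temperature=4.0)  # tau > 1 softens both distributions

z_s = torch.randn(8, 1000, requires_grad=True)  # student logits
z_t = torch.randn(8, 1000)                      # teacher logits
target = torch.randint(0, 1000, (8,))

loss = 0.5 * F.cross_entropy(z_s, target) + 0.5 * criterion_kd(z_s, z_t.detach())
loss.backward()
```
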
/losses/review.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ReviewKD(nn.Module):
7 | pre_act_feat = True
8 | def __init__(self, feat_index_s, feat_index_t, in_channels, out_channels,
9 | shapes=(1, 7, 14, 28, 56), out_shapes=(1, 7, 14, 28, 56),
10 | warmup_epochs=1, max_mid_channel=512):
11 | super(ReviewKD, self).__init__()
12 | self.feat_index_s = feat_index_s
13 | self.feat_index_t = feat_index_t
14 | self.shapes = shapes
15 | self.out_shapes = out_shapes
16 | self.warmup_epochs = warmup_epochs
17 |
18 | abfs = nn.ModuleList()
19 | mid_channel = min(max_mid_channel, in_channels[-1])
20 | for idx, in_channel in enumerate(in_channels):
21 | abfs.append(ABF(in_channel, mid_channel, out_channels[idx], idx < len(in_channels) - 1))
22 |
23 | self.abfs = abfs[::-1]
24 |
25 | def forward(self, z_s, z_t, **kwargs):
26 | f_s = [kwargs['feature_student'][i] for i in self.feat_index_s]
27 | pre_logit_feat_s = kwargs['feature_student'][-1]
28 | if len(pre_logit_feat_s.shape) == 2:
29 | pre_logit_feat_s = pre_logit_feat_s.unsqueeze(-1).unsqueeze(-1)
30 | f_s.append(pre_logit_feat_s)
31 |
32 | f_s = f_s[::-1]
33 | results = []
34 | out_features, res_features = self.abfs[0](f_s[0], out_shape=self.out_shapes[0])
35 | results.append(out_features)
36 | for features, abf, shape, out_shape in zip(f_s[1:], self.abfs[1:], self.shapes[1:], self.out_shapes[1:]):
37 | out_features, res_features = abf(features, res_features, shape, out_shape)
38 | results.insert(0, out_features)
39 |
40 | f_t = [kwargs['feature_teacher'][i] for i in self.feat_index_t]
41 | pre_logit_feat_t = kwargs['feature_teacher'][-1]
42 | if len(pre_logit_feat_t.shape) == 2:
43 | pre_logit_feat_t = pre_logit_feat_t.unsqueeze(-1).unsqueeze(-1)
44 | f_t.append(pre_logit_feat_t)
45 |
46 | kd_loss = min(kwargs["epoch"] / self.warmup_epochs, 1.0) * hcl_loss(results, f_t)
47 |
48 | return kd_loss
49 |
50 |
51 | class ABF(nn.Module):
52 | def __init__(self, in_channel, mid_channel, out_channel, fuse):
53 | super(ABF, self).__init__()
54 | self.conv1 = nn.Sequential(
55 | nn.Conv2d(in_channel, mid_channel, kernel_size=1, bias=False),
56 | nn.BatchNorm2d(mid_channel),
57 | )
58 | self.conv2 = nn.Sequential(
59 | nn.Conv2d(
60 | mid_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False
61 | ),
62 | nn.BatchNorm2d(out_channel),
63 | )
64 | if fuse:
65 | self.att_conv = nn.Sequential(
66 | nn.Conv2d(mid_channel * 2, 2, kernel_size=1),
67 | nn.Sigmoid(),
68 | )
69 | else:
70 | self.att_conv = None
71 | nn.init.kaiming_uniform_(self.conv1[0].weight, a=1) # pyre-ignore
72 | nn.init.kaiming_uniform_(self.conv2[0].weight, a=1) # pyre-ignore
73 |
74 | def forward(self, x, y=None, shape=None, out_shape=None):
75 | n, _, h, w = x.shape
76 | # transform student features
77 | x = self.conv1(x)
78 | if self.att_conv is not None:
79 | # upsample residual features
80 | y = F.interpolate(y, (shape, shape), mode="nearest")
81 | # fusion
82 | z = torch.cat([x, y], dim=1)
83 | z = self.att_conv(z)
84 | x = x * z[:, 0].view(n, 1, h, w) + y * z[:, 1].view(n, 1, h, w)
85 | # output
86 | if x.shape[-1] != out_shape:
87 | x = F.interpolate(x, (out_shape, out_shape), mode="nearest")
88 | y = self.conv2(x)
89 | return y, x
90 |
91 |
92 | def hcl_loss(fstudent, fteacher):
93 | loss_all = 0.0
94 | for fs, ft in zip(fstudent, fteacher):
95 | n, c, h, w = fs.shape
96 | loss = F.mse_loss(fs, ft, reduction="mean")
97 | cnt = 1.0
98 | tot = 1.0
99 | for l in [4, 2, 1]:
100 | if l >= h:
101 | continue
102 | tmpfs = F.adaptive_avg_pool2d(fs, (l, l))
103 | tmpft = F.adaptive_avg_pool2d(ft, (l, l))
104 | cnt /= 2.0
105 | loss += F.mse_loss(tmpfs, tmpft, reduction="mean") * cnt
106 | tot += cnt
107 | loss = loss / tot
108 | loss_all = loss_all + loss
109 | return loss_all
110 |
--------------------------------------------------------------------------------
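
`ReviewKD` fuses multi-scale student features through the ABF modules and compares the fused pyramid to the teacher's features via `hcl_loss`. The sketch below assumes a hypothetical ResNet-style channel layout (student stages 64-512, teacher stages 256-2048, plus one pre-logit feature each) that matches the default `shapes=(1, 7, 14, 28, 56)`:

```python
import torch

from losses import ReviewKD

criterion = ReviewKD(
    feat_index_s=[0, 1, 2, 3],                  # which student feature maps to use
    feat_index_t=[0, 1, 2, 3],                  # which teacher feature maps to match
    in_channels=[64, 128, 256, 512, 512],       # student stages + pre-logit width
    out_channels=[256, 512, 1024, 2048, 2048],  # teacher stages + pre-logit width
)

n = 2
feature_student = [torch.randn(n, c, s, s)
                   for c, s in [(64, 56), (128, 28), (256, 14), (512, 7)]]
feature_student.append(torch.randn(n, 512))     # 2D pre-logit feature, unsqueezed inside
feature_teacher = [torch.randn(n, c, s, s)
                   for c, s in [(256, 56), (512, 28), (1024, 14), (2048, 7)]]
feature_teacher.append(torch.randn(n, 2048))

loss = criterion(None, None, feature_student=feature_student,
                 feature_teacher=feature_teacher, epoch=1)  # epoch drives the warmup factor
```
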
/losses/rkd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def _pdist(e, squared, eps):
7 | e_square = e.pow(2).sum(dim=1)
8 | prod = e @ e.t()
9 | res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)
10 |
11 | if not squared:
12 | res = res.sqrt()
13 |
14 | res = res.clone()
15 | res[range(len(e)), range(len(e))] = 0
16 | return res
17 |
18 |
19 | class RKD(nn.Module):
20 | def __init__(self, distance_weight=25, angle_weight=50, eps=1e-12, squared=False):
21 | super(RKD, self).__init__()
22 | self.distance_weight = distance_weight
23 | self.angle_weight = angle_weight
24 | self.eps = eps
25 | self.squared = squared
26 |
27 | def forward(self, z_s, z_t, **kwargs):
28 | f_s = kwargs['feature_student'][-1]
29 | f_t = kwargs['feature_teacher'][-1]
30 |
31 | stu = f_s.view(f_s.shape[0], -1)
32 | tea = f_t.view(f_t.shape[0], -1)
33 |
34 | # RKD distance loss
35 | with torch.no_grad():
36 | t_d = _pdist(tea, self.squared, self.eps)
37 | mean_td = t_d[t_d > 0].mean()
38 | t_d = t_d / mean_td
39 |
40 | d = _pdist(stu, self.squared, self.eps)
41 | mean_d = d[d > 0].mean()
42 | d = d / mean_d
43 |
44 | loss_d = F.smooth_l1_loss(d, t_d)
45 |
46 | # RKD Angle loss
47 | with torch.no_grad():
48 | td = tea.unsqueeze(0) - tea.unsqueeze(1)
49 | norm_td = F.normalize(td, p=2, dim=2)
50 | t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)
51 |
52 | sd = stu.unsqueeze(0) - stu.unsqueeze(1)
53 | norm_sd = F.normalize(sd, p=2, dim=2)
54 | s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1)
55 |
56 | loss_a = F.smooth_l1_loss(s_angle, t_angle)
57 |
58 | kd_loss = self.distance_weight * loss_d + self.angle_weight * loss_a
59 | return kd_loss
60 |
--------------------------------------------------------------------------------
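
`RKD` compares only the relational structure of the pre-logit features: pairwise distances (normalized by their batch mean) and pairwise angles. Feature widths below are illustrative:

```python
import torch

from losses import RKD

criterion = RKD(distance_weight=25, angle_weight=50)

f_s = torch.randn(8, 512)   # student pre-logit features
f_t = torch.randn(8, 2048)  # teacher pre-logit features

# Logits are unused; only the feature kwargs matter.
loss = criterion(None, None, feature_student=[f_s], feature_teacher=[f_t])
```
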
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .beit import *
2 | from .convnextv2 import *
--------------------------------------------------------------------------------
/models/beit.py:
--------------------------------------------------------------------------------
1 | """ BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
2 |
3 | Model from official source: https://github.com/microsoft/unilm/tree/master/beit
4 |
5 | At this point only the 1k fine-tuned classification weights and model configs have been added;
6 | see the original source above for the pre-training models and procedure.
7 |
8 | Modifications by / Copyright 2021 Ross Wightman, original copyrights below
9 | """
10 | # --------------------------------------------------------
11 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
12 | # Github source: https://github.com/microsoft/unilm/tree/master/beit
13 | # Copyright (c) 2021 Microsoft
14 | # Licensed under The MIT License [see LICENSE for details]
15 | # By Hangbo Bao
16 | # Based on timm and DeiT code bases
17 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
18 | # https://github.com/facebookresearch/deit/
19 | # https://github.com/facebookresearch/dino
20 | # --------------------------------------------------------
21 | import math
22 | from functools import partial
23 | from typing import Optional, Tuple
24 |
25 | import torch
26 | import torch.nn as nn
27 | import torch.nn.functional as F
28 | from torch.utils.checkpoint import checkpoint
29 |
30 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
31 | from timm.models.helpers import build_model_with_cfg
32 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_
33 | from timm.models.registry import register_model
34 | from timm.models.vision_transformer import checkpoint_filter_fn
35 |
36 |
37 | def _cfg(url='', **kwargs):
38 | return {
39 | 'url': url,
40 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
41 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
42 | 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
43 | 'first_conv': 'patch_embed.proj', 'classifier': 'head',
44 | **kwargs
45 | }
46 |
47 |
48 | default_cfgs = {
49 | 'beit_base_patch16_224': _cfg(
50 | url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'),
51 | 'beit_base_patch16_384': _cfg(
52 | url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
53 | input_size=(3, 384, 384), crop_pct=1.0,
54 | ),
55 | 'beit_base_patch16_224_in22k': _cfg(
56 | url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22k.pth',
57 | num_classes=21841,
58 | ),
59 | 'beit_large_patch16_224': _cfg(
60 | url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'),
61 | 'beit_large_patch16_384': _cfg(
62 | url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
63 | input_size=(3, 384, 384), crop_pct=1.0,
64 | ),
65 | 'beit_large_patch16_512': _cfg(
66 | url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
67 | input_size=(3, 512, 512), crop_pct=1.0,
68 | ),
69 | 'beit_large_patch16_224_in22k': _cfg(
70 | url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_224_pt22k_ft22k.pth',
71 | num_classes=21841,
72 | ),
73 |
74 | 'beitv2_base_patch16_224': _cfg(
75 | url='', input_size=(3, 224, 224), crop_pct=0.9,
76 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
77 | ),
78 | 'beitv2_large_patch16_224': _cfg(
79 | url='', input_size=(3, 224, 224), crop_pct=0.95,
80 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
81 | ),
82 | }
83 |
84 |
85 | def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
86 | num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
87 |     # +3 extra distances: cls-to-token, token-to-cls, and cls-to-cls
88 | # get pair-wise relative position index for each token inside the window
89 | window_area = window_size[0] * window_size[1]
90 | coords = torch.stack(torch.meshgrid(
91 | [torch.arange(window_size[0]),
92 | torch.arange(window_size[1])])) # 2, Wh, Ww
93 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
94 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
95 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
96 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
97 | relative_coords[:, :, 1] += window_size[1] - 1
98 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1
99 | relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
100 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
101 | relative_position_index[0, 0:] = num_relative_distance - 3
102 | relative_position_index[0:, 0] = num_relative_distance - 2
103 | relative_position_index[0, 0] = num_relative_distance - 1
104 | return relative_position_index
105 |
106 |
107 | class Attention(nn.Module):
108 | def __init__(
109 | self, dim, num_heads=8, qkv_bias=False, attn_drop=0.,
110 | proj_drop=0., window_size=None, attn_head_dim=None):
111 | super().__init__()
112 | self.num_heads = num_heads
113 | head_dim = dim // num_heads
114 | if attn_head_dim is not None:
115 | head_dim = attn_head_dim
116 | all_head_dim = head_dim * self.num_heads
117 | self.scale = head_dim ** -0.5
118 |
119 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
120 | if qkv_bias:
121 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
122 | self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
123 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
124 | else:
125 | self.q_bias = None
126 | self.k_bias = None
127 | self.v_bias = None
128 |
129 | if window_size:
130 | self.window_size = window_size
131 | self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
132 | self.relative_position_bias_table = nn.Parameter(
133 | torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
134 | self.register_buffer("relative_position_index", gen_relative_position_index(window_size))
135 | else:
136 | self.window_size = None
137 | self.relative_position_bias_table = None
138 | self.relative_position_index = None
139 |
140 | self.attn_drop = nn.Dropout(attn_drop)
141 | self.proj = nn.Linear(all_head_dim, dim)
142 | self.proj_drop = nn.Dropout(proj_drop)
143 |
144 | def _get_rel_pos_bias(self):
145 | relative_position_bias = self.relative_position_bias_table[
146 | self.relative_position_index.view(-1)].view(
147 | self.window_size[0] * self.window_size[1] + 1,
148 | self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
149 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
150 | return relative_position_bias.unsqueeze(0)
151 |
152 | def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None):
153 | B, N, C = x.shape
154 |
155 | qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
156 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
157 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
158 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
159 |
160 | q = q * self.scale
161 | attn = (q @ k.transpose(-2, -1))
162 |
163 | if self.relative_position_bias_table is not None:
164 | attn = attn + self._get_rel_pos_bias()
165 | if shared_rel_pos_bias is not None:
166 | attn = attn + shared_rel_pos_bias
167 |
168 | attn = attn.softmax(dim=-1)
169 | attn = self.attn_drop(attn)
170 |
171 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
172 | x = self.proj(x)
173 | x = self.proj_drop(x)
174 | return x
175 |
176 |
177 | class Block(nn.Module):
178 |
179 | def __init__(
180 | self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
181 | drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
182 | window_size=None, attn_head_dim=None):
183 | super().__init__()
184 | self.norm1 = norm_layer(dim)
185 | self.attn = Attention(
186 | dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
187 | window_size=window_size, attn_head_dim=attn_head_dim)
188 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
189 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
190 | self.norm2 = norm_layer(dim)
191 | mlp_hidden_dim = int(dim * mlp_ratio)
192 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
193 |
194 | if init_values:
195 | self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
196 | self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
197 | else:
198 | self.gamma_1, self.gamma_2 = None, None
199 |
200 | def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None):
201 | if self.gamma_1 is None:
202 | x = x + self.drop_path(self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
203 | x = x + self.drop_path(self.mlp(self.norm2(x)))
204 | else:
205 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
206 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
207 | return x
208 |
209 |
210 | class RelativePositionBias(nn.Module):
211 |
212 | def __init__(self, window_size, num_heads):
213 | super().__init__()
214 | self.window_size = window_size
215 | self.window_area = window_size[0] * window_size[1]
216 | num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
217 | self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
218 | # trunc_normal_(self.relative_position_bias_table, std=.02)
219 | self.register_buffer("relative_position_index", gen_relative_position_index(window_size))
220 |
221 | def forward(self):
222 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
223 | self.window_area + 1, self.window_area + 1, -1) # Wh*Ww,Wh*Ww,nH
224 | return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
225 |
226 |
227 | class Beit(nn.Module):
228 | """ Vision Transformer with support for patch or hybrid CNN input stage
229 | """
230 |
231 | def __init__(
232 | self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg',
233 | embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.,
234 | attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
235 | init_values=None, use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
236 | head_init_scale=0.001):
237 | super().__init__()
238 | self.num_classes = num_classes
239 | self.global_pool = global_pool
240 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
241 | self.grad_checkpointing = False
242 |
243 | self.patch_embed = PatchEmbed(
244 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
245 | num_patches = self.patch_embed.num_patches
246 |
247 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
248 | # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
249 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if use_abs_pos_emb else None
250 | self.pos_drop = nn.Dropout(p=drop_rate)
251 |
252 | if use_shared_rel_pos_bias:
253 | self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.grid_size, num_heads=num_heads)
254 | else:
255 | self.rel_pos_bias = None
256 |
257 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
258 | self.blocks = nn.ModuleList([
259 | Block(
260 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
261 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
262 | init_values=init_values, window_size=self.patch_embed.grid_size if use_rel_pos_bias else None)
263 | for i in range(depth)])
264 | use_fc_norm = self.global_pool == 'avg'
265 | self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
266 | self.fc_norm = norm_layer(embed_dim) if use_fc_norm else None
267 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
268 |
269 | self.apply(self._init_weights)
270 | if self.pos_embed is not None:
271 | trunc_normal_(self.pos_embed, std=.02)
272 | trunc_normal_(self.cls_token, std=.02)
273 | # trunc_normal_(self.mask_token, std=.02)
274 | self.fix_init_weight()
275 | if isinstance(self.head, nn.Linear):
276 | trunc_normal_(self.head.weight, std=.02)
277 | self.head.weight.data.mul_(head_init_scale)
278 | self.head.bias.data.mul_(head_init_scale)
279 |
280 | def fix_init_weight(self):
281 | def rescale(param, layer_id):
282 | param.div_(math.sqrt(2.0 * layer_id))
283 |
284 | for layer_id, layer in enumerate(self.blocks):
285 | rescale(layer.attn.proj.weight.data, layer_id + 1)
286 | rescale(layer.mlp.fc2.weight.data, layer_id + 1)
287 |
288 | def _init_weights(self, m):
289 | if isinstance(m, nn.Linear):
290 | trunc_normal_(m.weight, std=.02)
291 | if isinstance(m, nn.Linear) and m.bias is not None:
292 | nn.init.constant_(m.bias, 0)
293 | elif isinstance(m, nn.LayerNorm):
294 | nn.init.constant_(m.bias, 0)
295 | nn.init.constant_(m.weight, 1.0)
296 |
297 | @torch.jit.ignore
298 | def no_weight_decay(self):
299 | nwd = {'pos_embed', 'cls_token'}
300 | for n, _ in self.named_parameters():
301 | if 'relative_position_bias_table' in n:
302 | nwd.add(n)
303 | return nwd
304 |
305 | @torch.jit.ignore
306 | def set_grad_checkpointing(self, enable=True):
307 | self.grad_checkpointing = enable
308 |
309 | @torch.jit.ignore
310 | def group_matcher(self, coarse=False):
311 | matcher = dict(
312 | stem=r'^cls_token|pos_embed|patch_embed|rel_pos_bias', # stem and embed
313 | blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))],
314 | )
315 | return matcher
316 |
317 | @torch.jit.ignore
318 | def get_classifier(self):
319 | return self.head
320 |
321 | def reset_classifier(self, num_classes, global_pool=None):
322 | self.num_classes = num_classes
323 | if global_pool is not None:
324 | self.global_pool = global_pool
325 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
326 |
327 | def forward_features(self, x):
328 | x = self.patch_embed(x)
329 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
330 | if self.pos_embed is not None:
331 | x = x + self.pos_embed
332 | x = self.pos_drop(x)
333 |
334 | rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
335 | for blk in self.blocks:
336 | if self.grad_checkpointing and not torch.jit.is_scripting():
337 | x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias)
338 | else:
339 | x = blk(x, shared_rel_pos_bias=rel_pos_bias)
340 | x = self.norm(x)
341 | return x
342 |
343 | def forward_head(self, x, pre_logits: bool = False):
344 | if self.fc_norm is not None:
345 | x = x[:, 1:].mean(dim=1)
346 | x = self.fc_norm(x)
347 | else:
348 | x = x[:, 0]
349 | return x if pre_logits else self.head(x)
350 |
351 | def forward(self, x):
352 | x = self.forward_features(x)
353 | x = self.forward_head(x)
354 | return x
355 |
356 |
357 | def _create_beit(variant, pretrained=False, **kwargs):
358 | if kwargs.get('features_only', None):
359 | raise RuntimeError('features_only not implemented for Beit models.')
360 |
361 | model = build_model_with_cfg(
362 | Beit, variant, pretrained,
363 | # FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes
364 | pretrained_filter_fn=checkpoint_filter_fn,
365 | **kwargs)
366 | return model
367 |
368 |
369 | @register_model
370 | def beit_base_patch16_224(pretrained=False, **kwargs):
371 | model_kwargs = dict(
372 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
373 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs)
374 | model = _create_beit('beit_base_patch16_224', pretrained=pretrained, **model_kwargs)
375 | return model
376 |
377 |
378 | @register_model
379 | def beit_base_patch16_384(pretrained=False, **kwargs):
380 | model_kwargs = dict(
381 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
382 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs)
383 | model = _create_beit('beit_base_patch16_384', pretrained=pretrained, **model_kwargs)
384 | return model
385 |
386 |
387 | @register_model
388 | def beit_base_patch16_224_in22k(pretrained=False, **kwargs):
389 | model_kwargs = dict(
390 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
391 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs)
392 | model = _create_beit('beit_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
393 | return model
394 |
395 |
396 | @register_model
397 | def beit_large_patch16_224(pretrained=False, **kwargs):
398 | model_kwargs = dict(
399 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
400 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
401 | model = _create_beit('beit_large_patch16_224', pretrained=pretrained, **model_kwargs)
402 | return model
403 |
404 |
405 | @register_model
406 | def beit_large_patch16_384(pretrained=False, **kwargs):
407 | model_kwargs = dict(
408 | img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
409 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
410 | model = _create_beit('beit_large_patch16_384', pretrained=pretrained, **model_kwargs)
411 | return model
412 |
413 |
414 | @register_model
415 | def beit_large_patch16_512(pretrained=False, **kwargs):
416 | model_kwargs = dict(
417 | img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
418 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
419 | model = _create_beit('beit_large_patch16_512', pretrained=pretrained, **model_kwargs)
420 | return model
421 |
422 |
423 | @register_model
424 | def beit_large_patch16_224_in22k(pretrained=False, **kwargs):
425 | model_kwargs = dict(
426 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
427 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
428 | model = _create_beit('beit_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
429 | return model
430 |
431 |
432 | # models with suffix "_1k" are duplicated to load different checkpoint
433 | @register_model
434 | def beitv2_base_patch16_224_1k(pretrained=False, **kwargs):
435 | model_kwargs = dict(
436 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
437 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
438 | model = _create_beit('beitv2_base_patch16_224', pretrained=pretrained, **model_kwargs)
439 | return model
440 |
441 |
442 | @register_model
443 | def beitv2_large_patch16_224_1k(pretrained=False, **kwargs):
444 | model_kwargs = dict(
445 | patch_size=16, embed_dim=1024, depth=24, num_heads=16,
446 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
447 | model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs)
448 | return model
449 |
450 |
451 | @register_model
452 | def beitv2_base_patch16_224(pretrained=False, **kwargs):
453 | model_kwargs = dict(
454 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
455 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
456 | model = _create_beit('beitv2_base_patch16_224', pretrained=pretrained, **model_kwargs)
457 | return model
458 |
459 |
460 | @register_model
461 | def beitv2_large_patch16_224(pretrained=False, **kwargs):
462 | model_kwargs = dict(
463 | patch_size=16, embed_dim=1024, depth=24, num_heads=16,
464 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
465 | model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs)
466 | return model
--------------------------------------------------------------------------------
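
These variants are registered with timm (overriding timm's built-in definitions of the same names), so a teacher can be created by name once the package is imported. A sketch, assuming it is run from the repo root; `pretrained=True` would pull the fine-tuned weights from the URLs in `default_cfgs`:

```python
import timm
import torch

import models  # importing the package registers the local BEiT variants with timm

teacher = timm.create_model('beit_large_patch16_224', pretrained=False)
teacher.eval()

with torch.no_grad():
    logits = teacher(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```
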
/models/convnextv2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 |
9 | # from MinkowskiEngine import SparseTensor
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 |
21 | from timm.models.layers import DropPath, trunc_normal_
22 | from timm.models.registry import register_model
23 |
24 |
25 | class LayerNorm(nn.Module):
26 | """ LayerNorm that supports two data formats: channels_last (default) or channels_first.
27 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
28 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
29 | with shape (batch_size, channels, height, width).
30 | """
31 |
32 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
33 | super().__init__()
34 | self.weight = nn.Parameter(torch.ones(normalized_shape))
35 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
36 | self.eps = eps
37 | self.data_format = data_format
38 | if self.data_format not in ["channels_last", "channels_first"]:
39 | raise NotImplementedError
40 | self.normalized_shape = (normalized_shape,)
41 |
42 | def forward(self, x):
43 | if self.data_format == "channels_last":
44 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
45 | elif self.data_format == "channels_first":
46 | u = x.mean(1, keepdim=True)
47 | s = (x - u).pow(2).mean(1, keepdim=True)
48 | x = (x - u) / torch.sqrt(s + self.eps)
49 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
50 | return x
51 |
52 |
53 | class GRN(nn.Module):
54 | """ GRN (Global Response Normalization) layer
55 | """
56 |
57 | def __init__(self, dim):
58 | super().__init__()
59 | self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
60 | self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
61 |
62 | def forward(self, x):
63 | Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
64 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
65 | return self.gamma * (x * Nx) + self.beta + x
66 |
67 |
68 | class Block(nn.Module):
69 | """ ConvNeXtV2 Block.
70 |
71 | Args:
72 | dim (int): Number of input channels.
73 | drop_path (float): Stochastic depth rate. Default: 0.0
74 | """
75 |
76 | def __init__(self, dim, drop_path=0.):
77 | super().__init__()
78 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
79 | self.norm = LayerNorm(dim, eps=1e-6)
80 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
81 | self.act = nn.GELU()
82 | self.grn = GRN(4 * dim)
83 | self.pwconv2 = nn.Linear(4 * dim, dim)
84 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
85 |
86 | def forward(self, x):
87 | input = x
88 | x = self.dwconv(x)
89 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
90 | x = self.norm(x)
91 | x = self.pwconv1(x)
92 | x = self.act(x)
93 | x = self.grn(x)
94 | x = self.pwconv2(x)
95 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
96 |
97 | x = input + self.drop_path(x)
98 | return x
99 |
100 |
101 | class ConvNeXtV2(nn.Module):
102 | """ ConvNeXt V2
103 |
104 | Args:
105 | in_chans (int): Number of input image channels. Default: 3
106 | num_classes (int): Number of classes for classification head. Default: 1000
107 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
108 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
109 | drop_path_rate (float): Stochastic depth rate. Default: 0.
110 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
111 | """
112 |
113 | def __init__(self, in_chans=3, num_classes=1000,
114 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
115 | drop_path_rate=0., head_init_scale=1., **kwargs
116 | ):
117 | super().__init__()
118 | self.depths = depths
119 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
120 | stem = nn.Sequential(
121 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
122 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
123 | )
124 | self.downsample_layers.append(stem)
125 | for i in range(3):
126 | downsample_layer = nn.Sequential(
127 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
128 | nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
129 | )
130 | self.downsample_layers.append(downsample_layer)
131 |
132 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
133 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
134 | cur = 0
135 | for i in range(4):
136 | stage = nn.Sequential(
137 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
138 | )
139 | self.stages.append(stage)
140 | cur += depths[i]
141 |
142 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
143 | self.head = nn.Linear(dims[-1], num_classes)
144 |
145 | self.apply(self._init_weights)
146 | self.head.weight.data.mul_(head_init_scale)
147 | self.head.bias.data.mul_(head_init_scale)
148 |
149 | def _init_weights(self, m):
150 | if isinstance(m, (nn.Conv2d, nn.Linear)):
151 | trunc_normal_(m.weight, std=.02)
152 | nn.init.constant_(m.bias, 0)
153 |
154 | def forward_features(self, x):
155 | for i in range(4):
156 | x = self.downsample_layers[i](x)
157 | x = self.stages[i](x)
158 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
159 |
160 | def forward(self, x):
161 | x = self.forward_features(x)
162 | x = self.head(x)
163 | return x
164 |
165 |
166 | @register_model
167 | def convnextv2_atto(**kwargs):
168 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs)
169 | return model
170 |
171 |
172 | @register_model
173 | def convnextv2_femto(**kwargs):
174 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs)
175 | return model
176 |
177 |
178 | @register_model
179 | def convnext_pico(**kwargs):
180 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs)
181 | return model
182 |
183 |
184 | @register_model
185 | def convnextv2_nano(**kwargs):
186 | model = ConvNeXtV2(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs)
187 | return model
188 |
189 |
190 | @register_model
191 | def convnextv2_tiny(**kwargs):
192 | model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
193 | return model
194 |
195 |
196 | @register_model
197 | def convnextv2_base(**kwargs):
198 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
199 | return model
200 |
201 |
202 | @register_model
203 | def convnextv2_large(**kwargs):
204 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
205 | return model
206 |
207 |
208 | @register_model
209 | def convnextv2_huge(**kwargs):
210 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs)
211 | return model
212 |
--------------------------------------------------------------------------------
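
Likewise registered with timm, so the variants can serve as students or teachers by name; `drop_path_rate` below is an illustrative value, and the sketch again assumes it is run from the repo root:

```python
import timm
import torch

import models  # registers the local ConvNeXt V2 variants with timm

student = timm.create_model('convnextv2_tiny', drop_path_rate=0.1)
out = student(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])
```
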
/object_detection/README.md:
--------------------------------------------------------------------------------
1 | # COCO Object detection
2 |
3 | ## Getting started
4 |
5 | We add ConvNeXtV2 model and config files based on [mmdetection-2.x](https://github.com/open-mmlab/mmdetection/tree/2.x). Please refer to [get_started.md](https://github.com/open-mmlab/mmdetection/blob/2.x/docs/en/get_started.md) for mmdetection installation and dataset preparation instructions.
6 |
7 | ## Results and Fine-tuned Models
8 |
9 | | Framework | Backbone | LR Schedule | box AP | mask AP |
10 | | :---------------: | :----------: | :---------: | :------------: | :------------: |
11 | | Mask RCNN | ResNet50 | 1x | 41.8 | 37.7 |
12 | | | ResNet50 | 2x | 42.1 | 38.0 |
13 | | Mask RCNN | ConvNeXtV2-T | 1x | 45.7 | 42.0 |
14 | | | ConvNeXtV2-T | 3x | 47.9 | 43.3 |
15 | | Cascade Mask RCNN | ConvNeXtV2-T | 1x | 50.6 | 44.3 |
16 | | | ConvNeXtV2-T | 3x | 52.1 | 45.4 |
17 |
18 |
19 | ### Training
20 |
21 | To train a model with 8 GPUs, run (replace the placeholders with a config from `configs/` and an output directory):
22 | ```
23 | python -m torch.distributed.launch --nproc_per_node=8 tools/train.py <CONFIG_FILE> --launcher pytorch --work-dir <WORK_DIR>
24 | ```
25 |
26 |
27 | ## Acknowledgment
28 |
29 | This code is built on the [mmdetection](https://github.com/open-mmlab/mmdetection) and [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) repositories.
--------------------------------------------------------------------------------
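
The configs above plug into the standard mmdetection-2.x tooling, so a trained checkpoint can be evaluated with `tools/test.py`; the checkpoint path below is a placeholder:

```
python tools/test.py configs/convnextv2/mask_rcnn_convnextv2_3x_coco.py \
    /path/to/checkpoint.pth --eval bbox segm
```
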
/object_detection/configs/convnextv2/cascade_mask_rcnn_convnextv2_3x_coco.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/cascade_mask_rcnn_r50_fpn.py',
3 | '../_base_/datasets/coco_instance.py',
4 | '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
5 | ]
6 |
7 | # you can download ckpt from:
8 | # https://github.com/Hao840/vanillaKD/releases/download/checkpoint/convnextv2_tiny-85.030.pth
9 | checkpoint_file = '/your_path_to/convnextv2_tiny-85.030.pth'
10 |
11 | model = dict(
12 | backbone=dict(
13 | _delete_=True,
14 | type='ConvNeXtV2',
15 | dims=[96, 192, 384, 768],
16 | drop_path_rate=0.4,
17 | layer_scale_init_value=0.,
18 | init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file)),
19 | neck=dict(in_channels=[96, 192, 384, 768]),
20 | roi_head=dict(bbox_head=[
21 | dict(
22 | type='ConvFCBBoxHead',
23 | num_shared_convs=4,
24 | num_shared_fcs=1,
25 | in_channels=256,
26 | conv_out_channels=256,
27 | fc_out_channels=1024,
28 | roi_feat_size=7,
29 | num_classes=80,
30 | bbox_coder=dict(
31 | type='DeltaXYWHBBoxCoder',
32 | target_means=[0., 0., 0., 0.],
33 | target_stds=[0.1, 0.1, 0.2, 0.2]),
34 | reg_class_agnostic=False,
35 | reg_decoded_bbox=True,
36 | norm_cfg=dict(type='SyncBN', requires_grad=True),
37 | loss_cls=dict(
38 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
39 | loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
40 | dict(
41 | type='ConvFCBBoxHead',
42 | num_shared_convs=4,
43 | num_shared_fcs=1,
44 | in_channels=256,
45 | conv_out_channels=256,
46 | fc_out_channels=1024,
47 | roi_feat_size=7,
48 | num_classes=80,
49 | bbox_coder=dict(
50 | type='DeltaXYWHBBoxCoder',
51 | target_means=[0., 0., 0., 0.],
52 | target_stds=[0.05, 0.05, 0.1, 0.1]),
53 | reg_class_agnostic=False,
54 | reg_decoded_bbox=True,
55 | norm_cfg=dict(type='SyncBN', requires_grad=True),
56 | loss_cls=dict(
57 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
58 | loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
59 | dict(
60 | type='ConvFCBBoxHead',
61 | num_shared_convs=4,
62 | num_shared_fcs=1,
63 | in_channels=256,
64 | conv_out_channels=256,
65 | fc_out_channels=1024,
66 | roi_feat_size=7,
67 | num_classes=80,
68 | bbox_coder=dict(
69 | type='DeltaXYWHBBoxCoder',
70 | target_means=[0., 0., 0., 0.],
71 | target_stds=[0.033, 0.033, 0.067, 0.067]),
72 | reg_class_agnostic=False,
73 | reg_decoded_bbox=True,
74 | norm_cfg=dict(type='SyncBN', requires_grad=True),
75 | loss_cls=dict(
76 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
77 | loss_bbox=dict(type='GIoULoss', loss_weight=10.0))
78 | ]))
79 |
80 | # dataset settings
81 | dataset_type = 'CocoDataset'
82 | data_root = 'data/coco/'
83 | img_norm_cfg = dict(
84 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
85 | train_pipeline = [
86 | dict(type='LoadImageFromFile'),
87 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
88 | dict(type='RandomFlip', flip_ratio=0.5),
89 | dict(
90 | type='AutoAugment',
91 | policies=[[
92 | dict(
93 | type='Resize',
94 | img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
95 | (608, 1333), (640, 1333), (672, 1333), (704, 1333),
96 | (736, 1333), (768, 1333), (800, 1333)],
97 | multiscale_mode='value',
98 | keep_ratio=True)
99 | ],
100 | [
101 | dict(
102 | type='Resize',
103 | img_scale=[(400, 1333), (500, 1333), (600, 1333)],
104 | multiscale_mode='value',
105 | keep_ratio=True),
106 | dict(
107 | type='RandomCrop',
108 | crop_type='absolute_range',
109 | crop_size=(384, 600),
110 | allow_negative_crop=True),
111 | dict(
112 | type='Resize',
113 | img_scale=[(480, 1333), (512, 1333), (544, 1333),
114 | (576, 1333), (608, 1333), (640, 1333),
115 | (672, 1333), (704, 1333), (736, 1333),
116 | (768, 1333), (800, 1333)],
117 | multiscale_mode='value',
118 | override=True,
119 | keep_ratio=True)
120 | ]]),
121 | dict(type='Normalize', **img_norm_cfg),
122 | dict(type='Pad', size_divisor=32),
123 | dict(type='DefaultFormatBundle'),
124 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
125 | ]
126 | test_pipeline = [
127 | dict(type='LoadImageFromFile'),
128 | dict(
129 | type='MultiScaleFlipAug',
130 | img_scale=(1333, 800),
131 | flip=False,
132 | transforms=[
133 | dict(type='Resize', keep_ratio=True),
134 | dict(type='RandomFlip'),
135 | dict(type='Normalize', **img_norm_cfg),
136 | dict(type='Pad', size_divisor=32),
137 | dict(type='ImageToTensor', keys=['img']),
138 | dict(type='Collect', keys=['img']),
139 | ])
140 | ]
141 | data = dict(
142 | samples_per_gpu=2,
143 | workers_per_gpu=2,
144 | train=dict(
145 | type=dataset_type,
146 | ann_file=data_root + 'annotations/instances_train2017.json',
147 | img_prefix=data_root + 'train2017/',
148 | pipeline=train_pipeline),
149 | val=dict(
150 | type=dataset_type,
151 | ann_file=data_root + 'annotations/instances_val2017.json',
152 | img_prefix=data_root + 'val2017/',
153 | pipeline=test_pipeline),
154 | test=dict(
155 | type=dataset_type,
156 | ann_file=data_root + 'annotations/instances_val2017.json',
157 | img_prefix=data_root + 'val2017/',
158 | pipeline=test_pipeline))
159 | evaluation = dict(metric=['bbox', 'segm'])
160 |
161 | optimizer = dict(
162 | _delete_=True,
163 | constructor='LearningRateDecayOptimizerConstructor',
164 | type='AdamW',
165 | lr=0.0002,
166 | betas=(0.9, 0.999),
167 | weight_decay=0.05,
168 | paramwise_cfg={
169 | 'decay_rate': 0.7,
170 | 'decay_type': 'layer_wise',
171 | 'num_layers': 6
172 | })
173 |
174 | lr_config = dict(warmup_iters=1000, step=[27, 33])
175 | runner = dict(max_epochs=36)
176 |
177 | log_config = dict(
178 | interval=200,
179 | hooks=[
180 | dict(type='TextLoggerHook'),
181 | # dict(type='TensorboardLoggerHook')
182 | ])
183 |
184 | # you need to set mode='dynamic' if you are using PyTorch <= 1.5.0
185 | fp16 = dict(loss_scale=dict(init_scale=512))
186 |
--------------------------------------------------------------------------------
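
Following the README's launch recipe, training with this cascade config would look like the following; the work-dir name is arbitrary:

```
python -m torch.distributed.launch --nproc_per_node=8 tools/train.py \
    configs/convnextv2/cascade_mask_rcnn_convnextv2_3x_coco.py \
    --launcher pytorch --work-dir work_dirs/cascade_mask_rcnn_convnextv2_3x
```
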
/object_detection/configs/convnextv2/mask_rcnn_convnextv2_3x_coco.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/mask_rcnn_r50_fpn.py',
3 | '../_base_/datasets/coco_instance.py',
4 | '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
5 | ]
6 |
7 | # you can download ckpt from:
8 | # https://github.com/Hao840/vanillaKD/releases/download/checkpoint/convnextv2_tiny-85.030.pth
9 | checkpoint_file = '/your_path_to/convnextv2_tiny-85.030.pth'
10 |
11 | model = dict(
12 | backbone=dict(
13 | _delete_=True,
14 | type='ConvNeXtV2',
15 | dims=[96, 192, 384, 768],
16 | drop_path_rate=0.4,
17 | layer_scale_init_value=0.,
18 | init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file)),
19 | neck=dict(in_channels=[96, 192, 384, 768]))
20 |
21 | # dataset settings
22 | dataset_type = 'CocoDataset'
23 | data_root = 'data/coco/'
24 | img_norm_cfg = dict(
25 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
26 | train_pipeline = [
27 | dict(type='LoadImageFromFile'),
28 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
29 | dict(type='RandomFlip', flip_ratio=0.5),
30 | dict(
31 | type='AutoAugment',
32 | policies=[[
33 | dict(
34 | type='Resize',
35 | img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
36 | (608, 1333), (640, 1333), (672, 1333), (704, 1333),
37 | (736, 1333), (768, 1333), (800, 1333)],
38 | multiscale_mode='value',
39 | keep_ratio=True)
40 | ],
41 | [
42 | dict(
43 | type='Resize',
44 | img_scale=[(400, 1333), (500, 1333), (600, 1333)],
45 | multiscale_mode='value',
46 | keep_ratio=True),
47 | dict(
48 | type='RandomCrop',
49 | crop_type='absolute_range',
50 | crop_size=(384, 600),
51 | allow_negative_crop=True),
52 | dict(
53 | type='Resize',
54 | img_scale=[(480, 1333), (512, 1333), (544, 1333),
55 | (576, 1333), (608, 1333), (640, 1333),
56 | (672, 1333), (704, 1333), (736, 1333),
57 | (768, 1333), (800, 1333)],
58 | multiscale_mode='value',
59 | override=True,
60 | keep_ratio=True)
61 | ]]),
62 | dict(type='Normalize', **img_norm_cfg),
63 | dict(type='Pad', size_divisor=32),
64 | dict(type='DefaultFormatBundle'),
65 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
66 | ]
67 | test_pipeline = [
68 | dict(type='LoadImageFromFile'),
69 | dict(
70 | type='MultiScaleFlipAug',
71 | img_scale=(1333, 800),
72 | flip=False,
73 | transforms=[
74 | dict(type='Resize', keep_ratio=True),
75 | dict(type='RandomFlip'),
76 | dict(type='Normalize', **img_norm_cfg),
77 | dict(type='Pad', size_divisor=32),
78 | dict(type='ImageToTensor', keys=['img']),
79 | dict(type='Collect', keys=['img']),
80 | ])
81 | ]
82 | data = dict(
83 | samples_per_gpu=2,
84 | workers_per_gpu=2,
85 | train=dict(
86 | type=dataset_type,
87 | ann_file=data_root + 'annotations/instances_train2017.json',
88 | img_prefix=data_root + 'train2017/',
89 | pipeline=train_pipeline),
90 | val=dict(
91 | type=dataset_type,
92 | ann_file=data_root + 'annotations/instances_val2017.json',
93 | img_prefix=data_root + 'val2017/',
94 | pipeline=test_pipeline),
95 | test=dict(
96 | type=dataset_type,
97 | ann_file=data_root + 'annotations/instances_val2017.json',
98 | img_prefix=data_root + 'val2017/',
99 | pipeline=test_pipeline))
100 | evaluation = dict(metric=['bbox', 'segm'])
101 |
102 | optimizer = dict(
103 | _delete_=True,
104 | constructor='LearningRateDecayOptimizerConstructor',
105 | type='AdamW',
106 | lr=0.0001,
107 | betas=(0.9, 0.999),
108 | weight_decay=0.05,
109 | paramwise_cfg={
110 | 'decay_rate': 0.95,
111 | 'decay_type': 'layer_wise',
112 | 'num_layers': 6
113 | })
114 |
115 | lr_config = dict(warmup_iters=1000, step=[27, 33])
116 | runner = dict(max_epochs=36)
117 |
118 | log_config = dict(
119 | interval=200,
120 | hooks=[
121 | dict(type='TextLoggerHook'),
122 | # dict(type='TensorboardLoggerHook')
123 | ])
124 |
125 | # you need to set mode='dynamic' if you are using pytorch<=1.5.0
126 | fp16 = dict(loss_scale=dict(init_scale=512))
127 |
--------------------------------------------------------------------------------
/object_detection/configs/resnet/mask_rcnn_r50_1x_coco.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/mask_rcnn_r50_fpn.py',
3 | '../_base_/datasets/coco_instance.py',
4 | '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
5 | ]
6 |
7 | # you can download ckpt from:
8 | # https://github.com/Hao840/vanillaKD/releases/download/checkpoint/resnet50-83.078.pth
9 | checkpoint_file = '/your_path_to/resnet50-83.078.pth'
10 |
11 | model = dict(
12 | backbone=dict(
13 | _delete_=True,
14 | type='ResNet',
15 | depth=50,
16 | num_stages=4,
17 | out_indices=(0, 1, 2, 3),
18 | frozen_stages=1,
19 | norm_cfg=dict(type='SyncBN', requires_grad=True),
20 | norm_eval=False,
21 | style='pytorch',
22 | init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file)),
23 | )
24 |
25 | # dataset settings
26 | dataset_type = 'CocoDataset'
27 | data_root = 'data/coco/'
28 | img_norm_cfg = dict(
29 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
30 | train_pipeline = [
31 | dict(type='LoadImageFromFile'),
32 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
33 | dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
34 | dict(type='RandomFlip', flip_ratio=0.5),
35 | dict(type='Normalize', **img_norm_cfg),
36 | dict(type='Pad', size_divisor=32),
37 | dict(type='DefaultFormatBundle'),
38 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
39 | ]
40 | test_pipeline = [
41 | dict(type='LoadImageFromFile'),
42 | dict(
43 | type='MultiScaleFlipAug',
44 | img_scale=(1333, 800),
45 | flip=False,
46 | transforms=[
47 | dict(type='Resize', keep_ratio=True),
48 | dict(type='RandomFlip'),
49 | dict(type='Normalize', **img_norm_cfg),
50 | dict(type='Pad', size_divisor=32),
51 | dict(type='ImageToTensor', keys=['img']),
52 | dict(type='Collect', keys=['img']),
53 | ])
54 | ]
55 | data = dict(
56 | samples_per_gpu=4,
57 | workers_per_gpu=4,
58 | train=dict(
59 | type=dataset_type,
60 | ann_file=data_root + 'annotations/instances_train2017.json',
61 | img_prefix=data_root + 'train2017/',
62 | pipeline=train_pipeline),
63 | val=dict(
64 | type=dataset_type,
65 | ann_file=data_root + 'annotations/instances_val2017.json',
66 | img_prefix=data_root + 'val2017/',
67 | pipeline=test_pipeline),
68 | test=dict(
69 | type=dataset_type,
70 | ann_file=data_root + 'annotations/instances_val2017.json',
71 | img_prefix=data_root + 'val2017/',
72 | pipeline=test_pipeline))
73 | evaluation = dict(metric=['bbox', 'segm'])
74 |
75 | # optimizer
76 | optimizer = dict(type='SGD', lr=0.04, momentum=0.9, weight_decay=0.0001)
77 | optimizer_config = dict(grad_clip=None)
78 | # learning policy
79 | lr_config = dict(
80 | policy='step',
81 | warmup='linear',
82 | warmup_iters=500,
83 | warmup_ratio=0.001,
84 | step=[8, 11])
85 | runner = dict(type='EpochBasedRunner', max_epochs=12)
86 |
87 | log_config = dict(
88 | interval=200,
89 | hooks=[
90 | dict(type='TextLoggerHook'),
91 | # dict(type='TensorboardLoggerHook')
92 | ])
--------------------------------------------------------------------------------
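The `lr=0.04` above is consistent with the linear LR scaling rule: a sketch under the assumption of the common mmdet default of `lr=0.02` for a total batch of 16 (8 GPUs x 2 images), scaled for this config's `samples_per_gpu=4`:

```python
# Minimal sketch of the linear LR scaling rule. The 8-GPU setup is an
# assumption taken from this repo's training commands; 0.02 for a total
# batch of 16 is the common mmdet Mask R-CNN default.
default_lr, default_batch = 0.02, 16
gpus, samples_per_gpu = 8, 4
scaled_lr = default_lr * (gpus * samples_per_gpu) / default_batch
print(scaled_lr)  # 0.04, matching optimizer = dict(type='SGD', lr=0.04, ...)
```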
/object_detection/mmdet/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | from .darknet import Darknet
2 | from .detectors_resnet import DetectoRS_ResNet
3 | from .detectors_resnext import DetectoRS_ResNeXt
4 | from .hourglass import HourglassNet
5 | from .hrnet import HRNet
6 | from .regnet import RegNet
7 | from .res2net import Res2Net
8 | from .resnest import ResNeSt
9 | from .resnet import ResNet, ResNetV1d
10 | from .resnext import ResNeXt
11 | from .ssd_vgg import SSDVGG
12 | from .trident_resnet import TridentResNet
13 | from .swin_transformer import SwinTransformer
14 | from .convnextv2 import ConvNeXtV2
15 |
16 | __all__ = [
17 | 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'Res2Net',
18 | 'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt', 'Darknet',
19 | 'ResNeSt', 'TridentResNet', 'SwinTransformer', 'ConvNeXtV2'
20 | ]
21 |
--------------------------------------------------------------------------------
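Importing `ConvNeXtV2` in this `__init__.py` is what makes the config string `type='ConvNeXtV2'` resolvable: registration happens as a side effect of the import. A schematic sketch of the registry pattern, heavily simplified relative to the real `mmcv.utils.Registry`:

```python
# Simplified illustration of how @BACKBONES.register_module() and
# dict(type='ConvNeXtV2', ...) connect; not mmcv's actual implementation.
class Registry:
    def __init__(self):
        self._modules = {}

    def register_module(self):
        def _register(cls):
            self._modules[cls.__name__] = cls
            return cls
        return _register

    def build(self, cfg):
        cfg = dict(cfg)  # copy so pop() does not mutate the caller's config
        return self._modules[cfg.pop('type')](**cfg)

BACKBONES = Registry()

@BACKBONES.register_module()
class ConvNeXtV2:  # stand-in for the real backbone class
    def __init__(self, dims=(96, 192, 384, 768), **kwargs):
        self.dims = dims

backbone = BACKBONES.build(dict(type='ConvNeXtV2', dims=[96, 192, 384, 768]))
print(type(backbone).__name__, backbone.dims)  # ConvNeXtV2 [96, 192, 384, 768]
```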
/object_detection/mmdet/models/backbones/convnextv2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | # Modified by Jianyuan Guo, Zhiwei Hao
4 |
5 | import os
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import math
10 |
11 | from timm.models.layers import DropPath
12 |
13 | from mmcv.cnn import (Conv2d, build_activation_layer, build_norm_layer,
14 | constant_init, normal_init, trunc_normal_init)
15 | from mmcv.cnn.bricks.drop import build_dropout
16 | from mmcv.cnn.utils.weight_init import trunc_normal_
17 | from mmcv.runner import (BaseModule, ModuleList, Sequential, _load_checkpoint,
18 | load_state_dict)
19 |
20 | from ...utils import get_root_logger
21 | from ..builder import BACKBONES
22 |
23 |
24 | class LayerNorm2d(nn.LayerNorm):
25 | """LayerNorm on channels for 2d images.
26 |
27 | Args:
28 | num_channels (int): The number of channels of the input tensor.
29 | eps (float): a value added to the denominator for numerical stability.
30 | Defaults to 1e-5.
31 | elementwise_affine (bool): a boolean value that when set to ``True``,
32 | this module has learnable per-element affine parameters initialized
33 | to ones (for weights) and zeros (for biases). Defaults to True.
34 | """
35 |
36 | def __init__(self, num_channels: int, **kwargs) -> None:
37 | super().__init__(num_channels, **kwargs)
38 | self.num_channels = self.normalized_shape[0]
39 |
40 | def forward(self, x, data_format='channel_first'):
41 | """Forward method.
42 |
43 | Args:
44 | x (torch.Tensor): The input tensor.
45 | data_format (str): The format of the input tensor. If
46 | ``"channel_first"``, the shape of the input tensor should be
47 | (B, C, H, W). If ``"channel_last"``, the shape of the input
48 | tensor should be (B, H, W, C). Defaults to "channel_first".
49 | """
50 | assert x.dim() == 4, 'LayerNorm2d only supports inputs with shape ' \
51 | f'(N, C, H, W), but got tensor with shape {x.shape}'
52 | if data_format == 'channel_last':
53 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias,
54 | self.eps)
55 | elif data_format == 'channel_first':
56 | x = x.permute(0, 2, 3, 1)
57 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias,
58 | self.eps)
59 | # If the output is discontiguous, it may cause some unexpected
60 | # problem in the downstream tasks
61 | x = x.permute(0, 3, 1, 2).contiguous()
62 | return x
63 |
64 | class LayerNorm(nn.Module):
65 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
66 | super().__init__()
67 | self.weight = nn.Parameter(torch.ones(normalized_shape))
68 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
69 | self.eps = eps
70 | self.data_format = data_format
71 | if self.data_format not in ["channels_last", "channels_first"]:
72 | raise NotImplementedError
73 | self.normalized_shape = (normalized_shape, )
74 |
75 | def forward(self, x):
76 | if self.data_format == "channels_last":
77 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
78 | elif self.data_format == "channels_first":
79 | u = x.mean(1, keepdim=True)
80 | s = (x - u).pow(2).mean(1, keepdim=True)
81 | x = (x - u) / torch.sqrt(s + self.eps)
82 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
83 | return x
84 |
85 | class GRN(nn.Module):
86 | def __init__(self, dim):
87 | super().__init__()
88 | self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
89 | self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
90 |
91 | def forward(self, x):
92 | Gx = torch.norm(x, p=2, dim=(1,2), keepdim=True)
93 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
94 | return self.gamma * (x * Nx) + self.beta + x
95 |
96 | class Block(nn.Module):
97 | def __init__(self, dim, drop_path=0.):
98 | super().__init__()
99 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
100 | self.norm = LayerNorm(dim, eps=1e-6)
101 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
102 | self.act = nn.GELU()
103 | self.grn = GRN(4 * dim)
104 | self.pwconv2 = nn.Linear(4 * dim, dim)
105 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
106 |
107 | def forward(self, x):
108 | input = x
109 | x = self.dwconv(x)
110 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
111 | x = self.norm(x)
112 | x = self.pwconv1(x)
113 | x = self.act(x)
114 | x = self.grn(x)
115 | x = self.pwconv2(x)
116 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
117 |
118 | x = input + self.drop_path(x)
119 | return x
120 |
121 |
122 | @BACKBONES.register_module()
123 | class ConvNeXtV2(BaseModule):
124 | def __init__(self, in_chans=3,
125 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
126 | drop_path_rate=0., head_init_scale=1.,
127 | out_indices=[0,1,2,3], init_cfg=None, **kwargs
128 | ):
129 | super().__init__()
130 | self.init_cfg = init_cfg
131 | self.depths = depths
132 | self.out_indices = out_indices
133 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
134 | stem = nn.Sequential(
135 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
136 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
137 | )
138 | self.downsample_layers.append(stem)
139 | for i in range(3):
140 | downsample_layer = nn.Sequential(
141 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
142 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
143 | )
144 | self.downsample_layers.append(downsample_layer)
145 |
146 | self.stages = nn.ModuleList()
147 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
148 | cur = 0
149 | for i in range(4):
150 | stage = nn.Sequential(
151 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
152 | )
153 | self.stages.append(stage)
154 | cur += depths[i]
155 |
156 | for i in out_indices:
157 | layer = LayerNorm2d(dims[i])
158 | layer_name = f'norm{i}'
159 | self.add_module(layer_name, layer)
160 |
161 | def init_weights(self):
162 | if self.init_cfg is None:
163 | logger = get_root_logger()
164 |             logger.warning(f'No pre-trained weights for '
165 |                            f'{self.__class__.__name__}, '
166 |                            f'training starts from scratch')
167 | for m in self.modules():
168 | if isinstance(m, nn.Linear):
169 | trunc_normal_init(m, std=.02, bias=0.)
170 | elif isinstance(m, nn.Conv2d):
171 | fan_out = m.kernel_size[0] * m.kernel_size[
172 | 1] * m.out_channels
173 | fan_out //= m.groups
174 | normal_init(m, 0, math.sqrt(2.0 / fan_out))
175 | else:
176 | state_dict = torch.load(self.init_cfg.checkpoint, map_location='cpu')
177 | if len(state_dict.keys()) <= 10 and 'model' in state_dict.keys():
178 | msg = self.load_state_dict(state_dict['model'], strict=False)
179 | else:
180 | msg = self.load_state_dict(state_dict, strict=False)
181 | print(msg)
182 |             print('Successfully loaded backbone ckpt.')
183 |
184 | def forward(self, x):
185 |
186 | outs = []
187 |
188 | for i in range(4):
189 | x = self.downsample_layers[i](x)
190 | x = self.stages[i](x)
191 | if i in self.out_indices:
192 | norm_layer = getattr(self, f'norm{i}')
193 | outs.append(norm_layer(x))
194 |
195 | return outs
196 |
--------------------------------------------------------------------------------
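A self-contained sanity check of the GRN block defined above, re-deriving its forward pass on a toy NHWC tensor. Note that with `gamma` and `beta` initialized to zero, GRN starts out as an identity mapping:

```python
import torch

# Toy NHWC tensor, matching the layout inside Block (after the permute).
x = torch.randn(2, 7, 7, 96)                       # (N, H, W, C)
gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)  # (N, 1, 1, C): per-channel spatial L2 norm
nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)   # divisive normalization across channels
gamma = torch.zeros(1, 1, 1, 96)                   # zero-init, as in GRN.__init__
beta = torch.zeros(1, 1, 1, 96)
out = gamma * (x * nx) + beta + x                  # gated residual, as in GRN.forward

assert torch.allclose(out, x)  # with zero-init gamma/beta, GRN is the identity
print(out.shape)               # torch.Size([2, 7, 7, 96])
```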
/readme.md:
--------------------------------------------------------------------------------
1 | # VanillaKD: Revisit the Power of Vanilla Knowledge Distillation from Small Scale to Large Scale
2 |
3 |
4 |
5 |
6 |
7 | Official PyTorch implementation of **VanillaKD**, from the following paper: \
8 | [VanillaKD: Revisit the Power of Vanilla Knowledge Distillation from Small Scale to Large Scale](https://arxiv.org/abs/2305.15781) \
9 | Zhiwei Hao, Jianyuan Guo, Kai Han, Han Hu, Chang Xu, Yunhe Wang
10 |
11 |
12 | This paper emphasizes the importance of scale in achieving superior results. It reveals that previous KD methods, designed and evaluated solely on small-scale datasets, have underestimated the effectiveness of vanilla KD on large-scale datasets, a limitation referred to as the **small data pitfall**. By incorporating **stronger data augmentation** and **larger datasets**, the performance gap between vanilla KD and other approaches is narrowed:
13 |
14 |
15 |
16 | Without bells and whistles, state-of-the-art results are achieved for ResNet-50, ViT-S, and ConvNeXtV2-T models on ImageNet, showcasing that vanilla KD is elegantly simple yet astonishingly effective in large-scale scenarios.
17 |
18 | If you find this project useful in your research, please cite:
19 |
20 | ```
21 | @article{hao2023vanillakd,
22 | title={VanillaKD: Revisit the Power of Vanilla Knowledge Distillation from Small Scale to Large Scale},
23 | author={Hao, Zhiwei and Guo, Jianyuan and Han, Kai and Hu, Han and Xu, Chang and Wang, Yunhe},
24 | journal={arXiv preprint arXiv:2305.15781},
25 | year={2023}
26 | }
27 | ```
28 |
29 | ## Model Zoo
30 |
31 | We provide models trained by vanilla KD on ImageNet.
32 |
33 | | name | acc@1 | acc@5 | model |
34 | |:---:|:---:|:---:|:---:|
35 | |resnet50|83.08|96.35|[model](https://github.com/Hao840/vanillaKD/releases/download/checkpoint/resnet50-83.078.pth)|
36 | |vit_tiny_patch16_224|78.11|94.26|[model](https://github.com/Hao840/vanillaKD/releases/download/checkpoint/vit_tiny_patch16_224-78.106.pth)|
37 | |vit_small_patch16_224|84.33|97.09|[model](https://github.com/Hao840/vanillaKD/releases/download/checkpoint/vit_small_patch16_224-84.328.pth)|
38 | |convnextv2_tiny|85.03|97.44|[model](https://github.com/Hao840/vanillaKD/releases/download/checkpoint/convnextv2_tiny-85.030.pth)|
39 |
40 |
41 | ## Usage
42 | First, clone the repository locally:
43 |
44 | ```
45 | git clone https://github.com/Hao840/vanillaKD.git
46 | ```
47 |
48 | Then, install PyTorch and [timm 0.6.5](https://github.com/huggingface/pytorch-image-models/tree/v0.6.5):
49 |
50 | ```
51 | conda install -c pytorch pytorch torchvision
52 | pip install timm==0.6.5
53 | ```
54 |
55 | Our results are produced with `torch==1.10.2+cu113 torchvision==0.11.3+cu113 timm==0.6.5`. Other versions might also work.
56 |
57 | ### Data preparation
58 |
59 | Download and extract ImageNet train and val images from http://image-net.org/. The directory structure is:
60 |
61 | ```
62 | │path/to/imagenet/
63 | ├──train/
64 | │ ├── n01440764
65 | │ │ ├── n01440764_10026.JPEG
66 | │ │ ├── n01440764_10027.JPEG
67 | │ │ ├── ......
68 | │ ├── ......
69 | ├──val/
70 | │ ├── n01440764
71 | │ │ ├── ILSVRC2012_val_00000293.JPEG
72 | │ │ ├── ILSVRC2012_val_00002138.JPEG
73 | │ │ ├── ......
74 | │ ├── ......
75 | ```
76 |
77 | ### Evaluation
78 |
79 | To evaluate a distilled model on ImageNet val with a single GPU, run:
80 |
81 | ```
82 | python validate.py /path/to/imagenet --model <model_name> --checkpoint /path/to/checkpoint
83 | ```
84 |
85 |
86 | ### Training
87 |
88 | To train a ResNet50 student with a BEiTv2-B teacher on ImageNet on a single node with 8 GPUs, run:
89 |
90 | Strategy A2:
91 |
92 | ```
93 | python -m torch.distributed.launch --nproc_per_node=8 train-kd.py /path/to/imagenet --model resnet50 --teacher beitv2_base_patch16_224 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss kd --amp --epochs 300 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.02 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.0 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.1 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 1
94 | ```
95 |
96 | Strategy A1:
97 |
98 | ```
99 | python -m torch.distributed.launch --nproc_per_node=8 train-kd.py /path/to/imagenet --model resnet50 --teacher beitv2_base_patch16_224 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss kd --amp --epochs 600 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.01 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.1 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.2 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 1
100 | ```
101 |
102 |
103 |
104 | Commands for reproducing baseline results:
105 |
106 |
107 |
108 | DKD
109 |
110 | Training with ResNet50 student, BEiTv2-B teacher, and strategy A2 for 300 epochs
111 |
112 | ```
113 | python -m torch.distributed.launch --nproc_per_node=8 train-kd.py /path/to/imagenet --model resnet50 --teacher beitv2_base_patch16_224 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss dkd --amp --epochs 300 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.02 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.0 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.1 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 1
114 | ```
115 |
116 |
117 |
118 |
119 |
120 |
121 | DIST
122 |
123 | Training with ResNet50 student, BEiTv2-B teacher, and strategy A2 for 300 epochs
124 |
125 | ```
126 | python -m torch.distributed.launch --nproc_per_node=8 train-kd.py /path/to/imagenet --model resnet50 --teacher beitv2_base_patch16_224 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss dist --amp --epochs 300 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.02 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.0 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.1 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 1
127 | ```
128 |
129 |
130 |
131 |
132 |
133 |
134 | Correlation
135 |
136 | Training with ResNet50 student, ResNet152 teacher, and strategy A2 for 300 epochs
137 |
138 | ```
139 | python -m torch.distributed.launch --nproc_per_node=8 train-fd.py /path/to/imagenet --model resnet50 --teacher resnet152 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss correlation --amp --epochs 300 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.02 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.0 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.1 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 0
140 | ```
141 |
142 |
143 |
144 |
145 |
146 |
147 | RKD
148 |
149 | Training with ResNet50 student, ResNet152 teacher, and strategy A2 for 300 epochs
150 |
151 | ```
152 | python -m torch.distributed.launch --nproc_per_node=8 train-fd.py /path/to/imagenet --model resnet50 --teacher resnet152 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss rkd --amp --epochs 300 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.02 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.0 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.1 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 0
153 | ```
154 |
155 |
156 |
157 |
158 |
159 |
160 | ReviewKD
161 |
162 | Training with ResNet50 student, ResNet152 teacher, and strategy A2 for 300 epochs
163 |
164 | ```
165 | python -m torch.distributed.launch --nproc_per_node=8 train-fd.py /path/to/imagenet --model resnet50 --teacher resnet152 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss review --amp --epochs 300 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.02 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.0 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.1 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 0
166 | ```
167 |
168 |
169 |
170 |
171 |
172 |
173 | CRD
174 |
175 | Training with ResNet50 student, ResNet152 teacher, and strategy A2 for 300 epochs
176 |
177 | ```
178 | python -m torch.distributed.launch --nproc_per_node=8 train-crd.py /path/to/imagenet --model resnet50 --teacher resnet152 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss crd --amp --epochs 300 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.02 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.0 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.1 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 0
179 |
180 | ```
181 |
182 |
183 | ## Acknowledgement
184 |
185 | This repository is built using the [timm](https://github.com/rwightman/pytorch-image-models) library, [DKD](https://github.com/megvii-research/mdistiller), [DIST](https://github.com/hunto/DIST_KD), [DeiT](https://github.com/facebookresearch/deit), [BEiT v2](https://github.com/microsoft/unilm/tree/master/beit2), and [ConvNeXt v2](https://github.com/facebookresearch/ConvNeXt-V2) repositories.
186 |
--------------------------------------------------------------------------------
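Beyond `validate.py`, a minimal standalone inference sketch for the released checkpoints. It assumes `timm==0.6.5`, a locally downloaded file, and that the weights are stored either flat or under a `'model'` key (the key name is an assumption, not verified against the released files):

```python
import torch
from timm.models import create_model

ckpt_path = 'resnet50-83.078.pth'  # hypothetical local path; see the Model Zoo table

model = create_model('resnet50', num_classes=1000)
state_dict = torch.load(ckpt_path, map_location='cpu')
if isinstance(state_dict, dict) and 'model' in state_dict:
    state_dict = state_dict['model']  # unwrap nested checkpoints -- assumption
model.load_state_dict(state_dict)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```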
/register.py:
--------------------------------------------------------------------------------
1 | from types import MethodType
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from timm.models.beit import Beit
7 | from timm.models.resnet import ResNet
8 |
9 |
10 | class Config:
11 | _feat_dim = {
12 | 'resnet50': (
13 | (64, 112, 112), (256, 56, 56), (512, 28, 28), (1024, 14, 14), (2048, 7, 7), (2048, None, None)),
14 | 'resnet152': (
15 | (64, 112, 112), (256, 56, 56), (512, 28, 28), (1024, 14, 14), (2048, 7, 7), (2048, None, None)),
16 | 'swin_small_patch4_window7_224': (
17 | (96, 56, 56), (192, 28, 28), (384, 14, 14), (768, 7, 7), (768, 7, 7), (768, None, None)),
18 | 'swin_base_patch4_window7_224': (
19 | (128, 56, 56), (256, 28, 28), (512, 14, 14), (1024, 7, 7), (1024, 7, 7), (1024, None, None)),
20 | 'swin_large_patch4_window7_224': (
21 | (192, 56, 56), (384, 28, 28), (768, 14, 14), (1536, 7, 7), (1536, 7, 7), (1536, None, None)),
22 | 'beitv2_large_patch16_224': (
23 | (64, 56, 56), (64, 56, 56), (256, 28, 28), (1024, 14, 14), (1024, 7, 7), (1024, None, None)),
24 | 'bit_r152x2': (
25 | (128, 112, 112), (512, 56, 56), (1024, 28, 28), (2048, 14, 14), (4096, 7, 7), (4096, 1, 1)),
26 | }
27 |
28 | _kd_feat_index = {
29 | 'resnet50': (1, 2, 3, 4),
30 | 'resnet152': (1, 2, 3, 4),
31 | 'swin_small_patch4_window7_224': (0, 1, 2, 4),
32 | 'swin_base_patch4_window7_224': (0, 1, 2, 4),
33 | 'swin_large_patch4_window7_224': (0, 1, 2, 4),
34 | 'beitv2_large_patch16_224': (1, 2, 3, 4),
35 | 'bit_r152x2': (1, 2, 3, 4),
36 | }
37 |
38 | def get_pre_logit_dim(self, model):
39 | feat_sizes = self._feat_dim[model]
40 | if isinstance(feat_sizes, tuple):
41 | return feat_sizes[-1][0]
42 | else:
43 | return feat_sizes
44 |
45 | def get_used_feature_index(self, model):
46 | index = self._kd_feat_index[model]
47 | if index is None:
48 | raise NotImplementedError(f'undefined feature kd for model {model}')
49 | return index
50 |
51 | def get_feature_size_by_index(self, model, index):
52 | valid_index = self.get_used_feature_index(model)
53 | feat_sizes = self._feat_dim[model]
54 | assert index in valid_index
55 | return feat_sizes[index]
56 |
57 |
58 | config = Config()
59 |
60 |
61 | def register_forward(model): # only ResNet implements pre-activation features
62 | if isinstance(model, ResNet): # ResNet
63 | model.forward = MethodType(ResNet_forward, model)
64 | model.forward_features = MethodType(ResNet_forward_features, model)
65 | elif isinstance(model, Beit): # Beit
66 | model.forward = MethodType(Beitv2_forward, model)
67 | model.forward_features = MethodType(Beitv2_forward_features, model)
68 | else:
69 | raise NotImplementedError('undefined forward method to get feature, check the exp setting carefully!')
70 |
71 |
72 | def _unpatchify(x, p, remove_token=0):
73 | """
74 |     x: (N, L, patch_size**2 * C)
75 | imgs: (N, C, H, W)
76 | """
77 | # p = self.patch_embed.patch_size[0]
78 | x = x[:, remove_token:, :]
79 | h = w = int(x.shape[1] ** .5)
80 | assert h * w == x.shape[1]
81 |
82 | x = x.reshape(shape=(x.shape[0], h, w, p, p, -1))
83 | x = torch.einsum('nhwpqc->nchpwq', x)
84 | imgs = x.reshape(shape=(x.shape[0], -1, h * p, h * p))
85 | return imgs
86 |
87 |
88 | # ResNet
89 | def bottleneck_forward(self, x):
90 | shortcut = x
91 |
92 | x = self.conv1(x)
93 | x = self.bn1(x)
94 | x = self.act1(x)
95 |
96 | x = self.conv2(x)
97 | x = self.bn2(x)
98 | x = self.drop_block(x)
99 | x = self.act2(x)
100 | x = self.aa(x)
101 |
102 | x = self.conv3(x)
103 | x = self.bn3(x)
104 |
105 | if self.se is not None:
106 | x = self.se(x)
107 |
108 | if self.drop_path is not None:
109 | x = self.drop_path(x)
110 |
111 | if self.downsample is not None:
112 | shortcut = self.downsample(shortcut)
113 | x += shortcut
114 | pre_act_x = x
115 | x = self.act3(pre_act_x)
116 |
117 | return x, pre_act_x
118 |
119 |
120 | def ResNet_forward_features(self, x, requires_feat):
121 | pre_act_feat = []
122 | feat = []
123 | x = self.conv1(x)
124 | x = self.bn1(x)
125 | pre_act_feat.append(x)
126 | x = self.act1(x)
127 | feat.append(x)
128 | x = self.maxpool(x)
129 |
130 | for layer in [self.layer1, self.layer2, self.layer3, self.layer4]:
131 | for bottleneck in layer:
132 | x, pre_act_x = bottleneck_forward(bottleneck, x)
133 |
134 | pre_act_feat.append(pre_act_x)
135 | feat.append(x)
136 |
137 | return (x, (pre_act_feat, feat)) if requires_feat else x
138 |
139 |
140 | def ResNet_forward(self, x, requires_feat=False):
141 | if requires_feat:
142 | x, (pre_act_feat, feat) = self.forward_features(x, requires_feat=True)
143 | x = self.forward_head(x, pre_logits=True)
144 | feat.append(x)
145 | pre_act_feat.append(x)
146 | x = self.fc(x)
147 | return x, (pre_act_feat, feat)
148 | else:
149 | x = self.forward_features(x, requires_feat=False)
150 | x = self.forward_head(x)
151 | return x
152 |
153 |
154 | def Beitv2_forward_features(self, x, requires_feat):
155 | x = self.patch_embed(x)
156 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
157 | if self.pos_embed is not None:
158 | x = x + self.pos_embed
159 | x = self.pos_drop(x)
160 |
161 | pre_act_feat = [_unpatchify(x, 4, 1)] # stem
162 | feat = [_unpatchify(x, 4, 1)] # stem
163 |
164 | rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
165 |     for i, blk in enumerate(self.blocks): # fixme: currently implemented for beitv2_large only
166 | x = blk(x, shared_rel_pos_bias=rel_pos_bias)
167 | f = None
168 | if i == 1:
169 | f = _unpatchify(x, 4, 1)
170 | elif i == 3:
171 | f = _unpatchify(x, 2, 1)
172 | elif i == 21:
173 | f = _unpatchify(x, 1, 1)
174 | elif i == 23:
175 | f = F.adaptive_avg_pool2d(_unpatchify(x, 1, 1), (7, 7))
176 | if f is not None:
177 | pre_act_feat.append(f)
178 | feat.append(f)
179 | x = self.norm(x)
180 | return (x, (pre_act_feat, feat)) if requires_feat else x
181 |
182 |
183 | def Beitv2_forward(self, x, requires_feat=False):
184 | if requires_feat:
185 | x, (pre_act_feat, feat) = self.forward_features(x, requires_feat=True)
186 | x = self.forward_head(x, pre_logits=True)
187 | feat.append(x)
188 | pre_act_feat.append(x)
189 | x = self.head(x)
190 | return x, (pre_act_feat, feat)
191 | else:
192 | x = self.forward_features(x, requires_feat=False)
193 | x = self.forward_head(x)
194 | return x
195 |
--------------------------------------------------------------------------------
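A minimal usage sketch for `register.py`: patch a timm ResNet-50 so that `forward(x, requires_feat=True)` also returns intermediate features, whose shapes follow `Config._feat_dim['resnet50']`. Assumes the repo root is on `PYTHONPATH` and `timm==0.6.5`:

```python
import torch
from timm.models import create_model

from register import config, register_forward

model = create_model('resnet50', num_classes=1000)
register_forward(model)  # swaps in ResNet_forward / ResNet_forward_features
model.eval()

x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    logits, (pre_act_feat, feat) = model(x, requires_feat=True)

for i in config.get_used_feature_index('resnet50'):  # (1, 2, 3, 4)
    print(i, feat[i].shape)  # e.g. index 1 -> torch.Size([2, 256, 56, 56])
print(logits.shape)          # torch.Size([2, 1000])
```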
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.10.2+cu113
2 | torchvision==0.11.3+cu113
3 | timm==0.6.5
--------------------------------------------------------------------------------
/train-crd.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """ ImageNet Training Script
3 |
4 | This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
5 | training results with some of the latest networks and training techniques. It favours canonical PyTorch
6 | and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
7 | and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.
8 |
9 | This script was started from an early version of the PyTorch ImageNet example
10 | (https://github.com/pytorch/examples/tree/master/imagenet)
11 |
12 | NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
13 | (https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
14 |
15 | Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
16 |
17 | Modifications by Zhiwei Hao (haozhw@bit.edu.cn) and Jianyuan Guo (jianyuan_guo@outlook.com)
18 | """
19 | import argparse
20 | import logging
21 | import os
22 | import time
23 | from collections import OrderedDict
24 | from contextlib import suppress
25 | from datetime import datetime
26 |
27 | import numpy as np
28 | import torch
29 | import torch.nn as nn
30 | import torchvision.utils
31 | import yaml
32 | from torch.nn.parallel import DistributedDataParallel as NativeDDP
33 | from torchvision import transforms
34 |
35 | from losses import CRD
36 | from register import config, register_forward
37 | from timm.data import AugMixDataset, create_dataset, create_loader, FastCollateMixup, Mixup, \
38 | resolve_data_config
39 | from timm.loss import *
40 | from timm.models import convert_splitbn_model, create_model, load_checkpoint, model_parameters, resume_checkpoint, \
41 | safe_model_name
42 | from timm.optim import create_optimizer_v2, optimizer_kwargs
43 | from timm.scheduler import create_scheduler
44 | from timm.utils import *
45 | from timm.utils import ApexScaler, NativeScaler
46 | from utils import ImageNetInstanceSample, process_feat, setup_default_logging, TimePredictor
47 | import models
48 |
49 | try:
50 | from apex import amp
51 | from apex.parallel import DistributedDataParallel as ApexDDP
52 | from apex.parallel import convert_syncbn_model
53 |
54 | has_apex = True
55 | except ImportError:
56 | has_apex = False
57 |
58 | has_native_amp = False
59 | try:
60 | if getattr(torch.cuda.amp, 'autocast') is not None:
61 | has_native_amp = True
62 | except AttributeError:
63 | pass
64 |
65 | try:
66 | import wandb
67 |
68 | has_wandb = True
69 | except ImportError:
70 | has_wandb = False
71 |
72 | torch.backends.cudnn.benchmark = True
73 | _logger = logging.getLogger('train')
74 |
75 | # The first arg parser parses out only the --config argument, this argument is used to
76 | # load a yaml file containing key-values that override the defaults for the main parser below
77 | config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
78 | parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
79 | help='YAML config file specifying default arguments')
80 |
81 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
82 |
83 | # -------------------------------------- Modified ---------------------------------------
84 | # KD parameters (this script uses the CRD loss; see train-kd.py for DIST, DKD, and vanilla KD)
85 | parser.add_argument('--kd-loss', default='crd', type=str)
86 | parser.add_argument('--teacher', default='beitv2_base_patch16_224', type=str)
87 | parser.add_argument('--teacher-pretrained', default=None, type=str) # teacher checkpoint path
88 | parser.add_argument('--ori-loss-weight', default=1., type=float)
89 | parser.add_argument('--kd-loss-weight', default=1., type=float)
90 | parser.add_argument('--teacher-resize', default=None, type=int)
91 | parser.add_argument('--student-resize', default=None, type=int)
92 | parser.add_argument('--input-size', default=None, nargs=3, type=int)
93 |
94 | # use torch.cuda.empty_cache() to save GPU memory
95 | parser.add_argument('--economic', action='store_true')
96 |
97 | # eval every 'eval-interval' epochs before epochs * eval_interval_end
98 | parser.add_argument('--eval-interval', type=int, default=1)
99 | parser.add_argument('--eval-interval-end', type=float, default=0.75)
100 | # ---------------------------------------------------------------------------------------
101 |
102 | # Dataset parameters
103 | parser.add_argument('data_dir', metavar='DIR',
104 | help='path to dataset')
105 | parser.add_argument('--dataset', '-d', metavar='NAME', default='',
106 | help='dataset type (default: ImageFolder/ImageTar if empty)')
107 | parser.add_argument('--train-split', metavar='NAME', default='train',
108 | help='dataset train split (default: train)')
109 | parser.add_argument('--val-split', metavar='NAME', default='validation',
110 | help='dataset validation split (default: validation)')
111 | parser.add_argument('--dataset-download', action='store_true', default=False,
112 | help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
113 | parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
114 | help='path to class to idx mapping file (default: "")')
115 |
116 | # Model parameters
117 | parser.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
118 |                     help='Name of model to train (default: "resnet50")')
119 | parser.add_argument('--pretrained', action='store_true', default=False,
120 | help='Start with pretrained version of specified network (if avail)')
121 | parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
122 | help='Initialize model from this checkpoint (default: none)')
123 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
124 | help='Resume full model and optimizer state from checkpoint (default: none)')
125 | parser.add_argument('--no-resume-opt', action='store_true', default=False,
126 | help='prevent resume of optimizer state when resuming model')
127 | parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
128 | help='number of label classes (Model default if None)')
129 | parser.add_argument('--gp', default=None, type=str, metavar='POOL',
130 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
131 | parser.add_argument('--img-size', type=int, default=None, metavar='N',
132 | help='Image patch size (default: None => model default)')
133 | parser.add_argument('--crop-pct', default=None, type=float,
134 | metavar='N', help='Input image center crop percent (for validation only)')
135 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
136 | help='Override mean pixel value of dataset')
137 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
138 | help='Override std deviation of dataset')
139 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
140 | help='Image resize interpolation type (overrides model)')
141 | parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
142 | help='Input batch size for training (default: 128)')
143 | parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
144 | help='Validation batch size override (default: None)')
145 | parser.add_argument('--channels-last', action='store_true', default=False,
146 | help='Use channels_last memory layout')
147 | parser.add_argument('--torchscript', dest='torchscript', action='store_true',
148 | help='torch.jit.script the full model')
149 | parser.add_argument('--fuser', default='', type=str,
150 | help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
151 | parser.add_argument('--grad-checkpointing', action='store_true', default=False,
152 | help='Enable gradient checkpointing through model blocks/stages')
153 |
154 | # Optimizer parameters
155 | parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
156 |                     help='Optimizer (default: "sgd")')
157 | parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
158 | help='Optimizer Epsilon (default: None, use opt default)')
159 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
160 | help='Optimizer Betas (default: None, use opt default)')
161 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
162 | help='Optimizer momentum (default: 0.9)')
163 | parser.add_argument('--weight-decay', type=float, default=2e-5,
164 | help='weight decay (default: 2e-5)')
165 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
166 | help='Clip gradient norm (default: None, no clipping)')
167 | parser.add_argument('--clip-mode', type=str, default='norm',
168 | help='Gradient clipping mode. One of ("norm", "value", "agc")')
169 | parser.add_argument('--layer-decay', type=float, default=None,
170 | help='layer-wise learning rate decay (default: None)')
171 |
172 | # Learning rate schedule parameters
173 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
174 |                     help='LR scheduler (default: "cosine")')
175 | parser.add_argument('--lr', type=float, default=0.05, metavar='LR',
176 | help='learning rate (default: 0.05)')
177 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
178 | help='learning rate noise on/off epoch percentages')
179 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
180 | help='learning rate noise limit percent (default: 0.67)')
181 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
182 | help='learning rate noise std-dev (default: 1.0)')
183 | parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
184 | help='learning rate cycle len multiplier (default: 1.0)')
185 | parser.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
186 | help='amount to decay each learning rate cycle (default: 0.5)')
187 | parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
188 | help='learning rate cycle limit, cycles enabled if > 1')
189 | parser.add_argument('--lr-k-decay', type=float, default=1.0,
190 | help='learning rate k-decay for cosine/poly (default: 1.0)')
191 | parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
192 | help='warmup learning rate (default: 0.0001)')
193 | parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
194 |                     help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)')
195 | parser.add_argument('--epochs', type=int, default=300, metavar='N',
196 | help='number of epochs to train (default: 300)')
197 | parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
198 | help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
199 | parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
200 | help='manual epoch number (useful on restarts)')
201 | parser.add_argument('--decay-epochs', type=float, default=100, metavar='N',
202 | help='epoch interval to decay LR')
203 | parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
204 | help='epochs to warmup LR, if scheduler supports')
205 | parser.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',
206 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
207 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
208 |                     help='patience epochs for Plateau LR scheduler (default: 10)')
209 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
210 | help='LR decay rate (default: 0.1)')
211 |
212 | # Augmentation & regularization parameters
213 | parser.add_argument('--no-aug', action='store_true', default=False,
214 | help='Disable all training augmentation, override other train aug args')
215 | parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
216 | help='Random resize scale (default: 0.08 1.0)')
217 | parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
218 | help='Random resize aspect ratio (default: 0.75 1.33)')
219 | parser.add_argument('--hflip', type=float, default=0.5,
220 | help='Horizontal flip training aug probability')
221 | parser.add_argument('--vflip', type=float, default=0.,
222 | help='Vertical flip training aug probability')
223 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
224 | help='Color jitter factor (default: 0.4)')
225 | parser.add_argument('--aa', type=str, default=None, metavar='NAME',
226 |                     help='Use AutoAugment policy. "v0" or "original". (default: None)')
227 | parser.add_argument('--aug-repeats', type=float, default=0,
228 | help='Number of augmentation repetitions (distributed training only) (default: 0)')
229 | parser.add_argument('--aug-splits', type=int, default=0,
230 | help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
231 | parser.add_argument('--jsd-loss', action='store_true', default=False,
232 | help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
233 | parser.add_argument('--bce-loss', action='store_true', default=False,
234 | help='Enable BCE loss w/ Mixup/CutMix use.')
235 | parser.add_argument('--bce-target-thresh', type=float, default=None,
236 | help='Threshold for binarizing softened BCE targets (default: None, disabled)')
237 | parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
238 | help='Random erase prob (default: 0.)')
239 | parser.add_argument('--remode', type=str, default='pixel',
240 | help='Random erase mode (default: "pixel")')
241 | parser.add_argument('--recount', type=int, default=1,
242 | help='Random erase count (default: 1)')
243 | parser.add_argument('--resplit', action='store_true', default=False,
244 | help='Do not random erase first (clean) augmentation split')
245 | parser.add_argument('--mixup', type=float, default=0.0,
246 | help='mixup alpha, mixup enabled if > 0. (default: 0.)')
247 | parser.add_argument('--cutmix', type=float, default=0.0,
248 | help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
249 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
250 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
251 | parser.add_argument('--mixup-prob', type=float, default=1.0,
252 | help='Probability of performing mixup or cutmix when either/both is enabled')
253 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
254 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
255 | parser.add_argument('--mixup-mode', type=str, default='batch',
256 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
257 | parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
258 | help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
259 | parser.add_argument('--smoothing', type=float, default=0.1,
260 | help='Label smoothing (default: 0.1)')
261 | parser.add_argument('--train-interpolation', type=str, default='random',
262 | help='Training interpolation (random, bilinear, bicubic default: "random")')
263 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
264 | help='Dropout rate (default: 0.)')
265 | parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
266 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
267 | parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
268 | help='Drop path rate (default: None)')
269 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
270 | help='Drop block rate (default: None)')
271 |
272 | # Batch norm parameters (only works with gen_efficientnet based models currently)
273 | parser.add_argument('--bn-momentum', type=float, default=None,
274 | help='BatchNorm momentum override (if not None)')
275 | parser.add_argument('--bn-eps', type=float, default=None,
276 | help='BatchNorm epsilon override (if not None)')
277 | parser.add_argument('--sync-bn', action='store_true',
278 | help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
279 | parser.add_argument('--dist-bn', type=str, default='reduce',
280 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
281 | parser.add_argument('--split-bn', action='store_true',
282 | help='Enable separate BN layers per augmentation split.')
283 |
284 | # Model Exponential Moving Average
285 | parser.add_argument('--model-ema', action='store_true', default=False,
286 | help='Enable tracking moving average of model weights')
287 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
288 | help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
289 | parser.add_argument('--model-ema-decay', type=float, default=0.9998,
290 | help='decay factor for model weights moving average (default: 0.9998)')
291 |
292 | # Misc
293 | parser.add_argument('--seed', type=int, default=42, metavar='S',
294 | help='random seed (default: 42)')
295 | parser.add_argument('--worker-seeding', type=str, default='all',
296 | help='worker seed mode (default: all)')
297 | parser.add_argument('--log-interval', type=int, default=200, metavar='N',
298 | help='how many batches to wait before logging training status')
299 | parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
300 | help='how many batches to wait before writing recovery checkpoint')
301 | parser.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
302 | help='number of checkpoints to keep (default: 10)')
303 | parser.add_argument('-j', '--workers', type=int, default=8, metavar='N',
304 |                     help='how many training processes to use (default: 8)')
305 | parser.add_argument('--save-images', action='store_true', default=False,
306 |                     help='save images of input batches every log interval for debugging')
307 | parser.add_argument('--amp', action='store_true', default=False,
308 | help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
309 | parser.add_argument('--apex-amp', action='store_true', default=False,
310 | help='Use NVIDIA Apex AMP mixed precision')
311 | parser.add_argument('--native-amp', action='store_true', default=False,
312 | help='Use Native Torch AMP mixed precision')
313 | parser.add_argument('--no-ddp-bb', action='store_true', default=False,
314 | help='Force broadcast buffers for native DDP to off.')
315 | parser.add_argument('--pin-mem', action='store_true', default=False,
316 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
317 | parser.add_argument('--no-prefetcher', action='store_true', default=False,
318 | help='disable fast prefetcher')
319 | parser.set_defaults(no_prefetcher=True)
320 | parser.add_argument('--output', default='', type=str, metavar='PATH',
321 | help='path to output folder (default: none, current dir)')
322 | parser.add_argument('--experiment', default='', type=str, metavar='NAME',
323 | help='name of train experiment, name of sub-folder for output')
324 | parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
325 |                     help='Best metric (default: "top1")')
326 | parser.add_argument('--tta', type=int, default=0, metavar='N',
327 | help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
328 | parser.add_argument("--local_rank", default=0, type=int)
329 | parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
330 | help='use the multi-epochs-loader to save time at the beginning of every epoch')
331 | parser.add_argument('--log-wandb', action='store_true', default=False,
332 | help='log training and validation metrics to wandb')
333 |
334 |
335 | def _parse_args():
336 | # Do we have a config file to parse?
337 | args_config, remaining = config_parser.parse_known_args()
338 | if args_config.config:
339 | with open(args_config.config, 'r') as f:
340 | cfg = yaml.safe_load(f)
341 | parser.set_defaults(**cfg)
342 |
343 | # The main arg parser parses the rest of the args, the usual
344 | # defaults will have been overridden if config file specified.
345 | args = parser.parse_args(remaining)
346 |
347 | # Cache the args as a text string to save them in the output dir later
348 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
349 | return args, args_text
350 |
351 |
352 | def main():
353 | setup_default_logging(_logger, log_path='train.log')
354 | args, args_text = _parse_args()
355 |
356 | if args.log_wandb:
357 | if has_wandb:
358 | wandb.init(project=args.experiment, config=args)
359 | else:
360 | _logger.warning("You've requested to log metrics to wandb but package not found. "
361 | "Metrics not being logged to wandb, try `pip install wandb`")
362 |
363 | args.prefetcher = not args.no_prefetcher
364 | args.distributed = False
365 | if 'WORLD_SIZE' in os.environ:
366 | args.distributed = int(os.environ['WORLD_SIZE']) > 1
367 | args.device = 'cuda:0'
368 | args.world_size = 1
369 | args.rank = 0 # global rank
370 | if args.distributed:
371 | args.device = 'cuda:%d' % args.local_rank
372 | torch.cuda.set_device(args.local_rank)
373 | torch.distributed.init_process_group(backend='nccl', init_method='env://')
374 | args.world_size = torch.distributed.get_world_size()
375 | args.rank = torch.distributed.get_rank()
376 | _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
377 | % (args.rank, args.world_size))
378 | else:
379 |         _logger.info('Training with a single process on 1 GPU.')
380 | assert args.rank >= 0
381 |
382 | # resolve AMP arguments based on PyTorch / Apex availability
383 | use_amp = None
384 | if args.amp:
385 | # `--amp` chooses native amp before apex (APEX ver not actively maintained)
386 | if has_native_amp:
387 | args.native_amp = True
388 | elif has_apex:
389 | args.apex_amp = True
390 | if args.apex_amp and has_apex:
391 | use_amp = 'apex'
392 | elif args.native_amp and has_native_amp:
393 | use_amp = 'native'
394 | elif args.apex_amp or args.native_amp:
395 |         _logger.warning("Neither APEX nor native Torch AMP is available, using float32. "
396 |                         "Install NVIDIA apex or upgrade to PyTorch 1.6")
397 |
398 | random_seed(args.seed, args.rank)
399 |
400 | if args.fuser:
401 | set_jit_fuser(args.fuser)
402 |
403 | # ------------------------------------ Modified -------------------------------------
404 | teacher = create_model(
405 | args.teacher,
406 | checkpoint_path=args.teacher_pretrained,
407 | num_classes=args.num_classes)
408 | register_forward(teacher)
409 | teacher = teacher.cuda()
410 | teacher.eval()
411 | # ------------------------------------------------------------------------------------
412 |
413 | model = create_model(
414 | args.model,
415 | pretrained=args.pretrained,
416 | num_classes=args.num_classes,
417 | drop_rate=args.drop,
418 | drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path
419 | drop_path_rate=args.drop_path,
420 | drop_block_rate=args.drop_block,
421 | global_pool=args.gp,
422 | bn_momentum=args.bn_momentum,
423 | bn_eps=args.bn_eps,
424 | scriptable=args.torchscript,
425 | checkpoint_path=args.initial_checkpoint)
426 | register_forward(model)
427 |
428 | if args.num_classes is None:
429 | assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
430 | args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
431 |
432 | if args.grad_checkpointing:
433 | model.set_grad_checkpointing(enable=True)
434 |
435 | if args.local_rank == 0:
436 | _logger.info(
437 | f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')
438 |
439 | data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
440 |
441 | # setup augmentation batch splits for contrastive loss or split bn
442 | num_aug_splits = 0
443 | if args.aug_splits > 0:
444 | assert args.aug_splits > 1, 'A split of 1 makes no sense'
445 | num_aug_splits = args.aug_splits
446 |
447 | # enable split bn (separate bn stats per batch-portion)
448 | if args.split_bn:
449 | assert num_aug_splits > 1 or args.resplit
450 | model = convert_splitbn_model(model, max(num_aug_splits, 2))
451 |
452 | # create the train and eval datasets
453 | dataset_train = ImageNetInstanceSample(root=f'{args.data_dir}/train', name=args.dataset, class_map=args.class_map,
454 | load_bytes=False, is_sample=True, k=16384)
455 | dataset_eval = create_dataset(
456 | args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
457 | class_map=args.class_map,
458 | download=args.dataset_download,
459 | batch_size=args.batch_size)
460 |
461 | # create the KD criterion; CRD needs student/teacher feature dims and the dataset size
462 | if args.kd_loss == 'crd':
463 | kd_loss_fn = CRD(feat_s_channel=config.get_pre_logit_dim(args.model),
464 | feat_t_channel=config.get_pre_logit_dim(args.teacher),
465 | feat_dim=128, num_data=len(dataset_train), k=16384,
466 | momentum=0.5, temperature=0.07)
467 | requires_feat = True
468 | else:
469 | raise NotImplementedError('this script only supports the CRD loss; for other KD losses, please refer to train-kd.py')
470 |
471 | model.kd_loss_fn = kd_loss_fn
472 | model.cuda()
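# NOTE: because kd_loss_fn is attached as a submodule before model.cuda(), the
# CRD embedding heads move to the GPU with the student, are included in
# create_optimizer_v2(model, ...) below, and are synchronized by DDP together
# with the student weights.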
473 |
474 | if args.channels_last:
475 | model = model.to(memory_format=torch.channels_last)
476 |
477 | # setup synchronized BatchNorm for distributed training
478 | if args.distributed and args.sync_bn:
479 | assert not args.split_bn
480 | if has_apex and use_amp == 'apex':
481 | # Apex SyncBN preferred unless native amp is activated
482 | model = convert_syncbn_model(model)
483 | else:
484 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
485 | if args.local_rank == 0:
486 | _logger.info(
487 | 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
488 | 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
489 |
490 | if args.torchscript:
491 | assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
492 | assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
493 | model = torch.jit.script(model)
494 |
495 | optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
496 |
497 | # setup automatic mixed-precision (AMP) loss scaling and op casting
498 | amp_autocast = suppress # do nothing
499 | loss_scaler = None
500 | if use_amp == 'apex':
501 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
502 | loss_scaler = ApexScaler()
503 | if args.local_rank == 0:
504 | _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
505 | elif use_amp == 'native':
506 | amp_autocast = torch.cuda.amp.autocast
507 | loss_scaler = NativeScaler()
508 | if args.local_rank == 0:
509 | _logger.info('Using native Torch AMP. Training in mixed precision.')
510 | else:
511 | if args.local_rank == 0:
512 | _logger.info('AMP not enabled. Training in float32.')
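# NOTE: `suppress` here is contextlib.suppress, which with no arguments is a
# no-op context manager; the train/validate loops can therefore always write
#   with amp_autocast():
#       output = model(input)
# and get autocast under native AMP, or plain float32 execution otherwise.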
513 |
514 | # optionally resume from a checkpoint
515 | resume_epoch = None
516 | if args.resume:
517 | resume_epoch = resume_checkpoint(
518 | model, args.resume,
519 | optimizer=None if args.no_resume_opt else optimizer,
520 | loss_scaler=None if args.no_resume_opt else loss_scaler,
521 | log_info=args.local_rank == 0)
522 |
523 | # setup exponential moving average of model weights, SWA could be used here too
524 | model_ema = None
525 | if args.model_ema:
526 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
527 | model_ema = ModelEmaV2(
528 | model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
529 | if args.resume:
530 | load_checkpoint(model_ema.module, args.resume, use_ema=True)
531 |
532 | # setup distributed training
533 | if args.distributed:
534 | if has_apex and use_amp == 'apex':
535 | # Apex DDP preferred unless native amp is activated
536 | if args.local_rank == 0:
537 | _logger.info("Using NVIDIA APEX DistributedDataParallel.")
538 | model = ApexDDP(model, delay_allreduce=True)
539 | else:
540 | if args.local_rank == 0:
541 | _logger.info("Using native Torch DistributedDataParallel.")
542 | model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
543 | # NOTE: EMA model does not need to be wrapped by DDP
544 |
545 | # setup learning rate schedule and starting epoch
546 | lr_scheduler, num_epochs = create_scheduler(args, optimizer)
547 | start_epoch = 0
548 | if args.start_epoch is not None:
549 | # a specified start_epoch will always override the resume epoch
550 | start_epoch = args.start_epoch
551 | elif resume_epoch is not None:
552 | start_epoch = resume_epoch
553 | if lr_scheduler is not None and start_epoch > 0:
554 | lr_scheduler.step(start_epoch)
555 |
556 | if args.local_rank == 0:
557 | _logger.info('Scheduled epochs: {}'.format(num_epochs))
558 |
559 | # setup mixup / cutmix
560 | collate_fn = None
561 | mixup_fn = None
562 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
563 | if mixup_active:
564 | mixup_args = dict(
565 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
566 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
567 | label_smoothing=args.smoothing, num_classes=args.num_classes)
568 | if args.prefetcher:
569 | assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
570 | collate_fn = FastCollateMixup(**mixup_args)
571 | else:
572 | mixup_fn = Mixup(**mixup_args)
573 |
574 | # wrap dataset in AugMix helper
575 | if num_aug_splits > 1:
576 | dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
577 |
578 | # create data loaders w/ augmentation pipeline
579 | train_interpolation = args.train_interpolation
580 | if args.no_aug or not train_interpolation:
581 | train_interpolation = data_config['interpolation']
582 | loader_train = create_loader(
583 | dataset_train,
584 | input_size=data_config['input_size'],
585 | batch_size=args.batch_size,
586 | is_training=True,
587 | use_prefetcher=args.prefetcher,
588 | no_aug=args.no_aug,
589 | re_prob=args.reprob,
590 | re_mode=args.remode,
591 | re_count=args.recount,
592 | re_split=args.resplit,
593 | scale=args.scale,
594 | ratio=args.ratio,
595 | hflip=args.hflip,
596 | vflip=args.vflip,
597 | color_jitter=args.color_jitter,
598 | auto_augment=args.aa,
599 | num_aug_repeats=args.aug_repeats,
600 | num_aug_splits=num_aug_splits,
601 | interpolation=train_interpolation,
602 | mean=data_config['mean'],
603 | std=data_config['std'],
604 | num_workers=args.workers,
605 | distributed=args.distributed,
606 | collate_fn=collate_fn,
607 | pin_memory=args.pin_mem,
608 | use_multi_epochs_loader=args.use_multi_epochs_loader,
609 | worker_seeding=args.worker_seeding,
610 | )
611 |
612 | loader_eval = create_loader(
613 | dataset_eval,
614 | input_size=(3, 224, 224),
615 | batch_size=args.validation_batch_size or args.batch_size,
616 | is_training=False,
617 | use_prefetcher=args.prefetcher,
618 | interpolation=data_config['interpolation'],
619 | mean=data_config['mean'],
620 | std=data_config['std'],
621 | num_workers=args.workers,
622 | distributed=args.distributed,
623 | crop_pct=data_config['crop_pct'],
624 | pin_memory=args.pin_mem,
625 | )
626 |
627 | # setup loss function
628 | if args.jsd_loss:
629 | assert num_aug_splits > 1 # JSD only valid with aug splits set
630 | train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
631 | elif mixup_active:
632 | # smoothing is handled with mixup target transform which outputs sparse, soft targets
633 | if args.bce_loss:
634 | train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh)
635 | else:
636 | train_loss_fn = SoftTargetCrossEntropy()
637 | elif args.smoothing:
638 | if args.bce_loss:
639 | train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh)
640 | else:
641 | train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
642 | else:
643 | train_loss_fn = nn.CrossEntropyLoss()
644 | train_loss_fn = train_loss_fn.cuda()
645 | validate_loss_fn = nn.CrossEntropyLoss().cuda()
646 |
647 | teacher_resizer = student_resizer = None
648 | if args.teacher_resize is not None:
649 | teacher_resizer = transforms.Resize(args.teacher_resize).cuda()
650 | if args.student_resize is not None:
651 | student_resizer = transforms.Resize(args.student_resize).cuda()
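# NOTE: the optional resizers let teacher and student see different input
# resolutions from one loader, e.g. (illustrative flag values, not defaults)
#   --teacher-resize 224 --input-size 3 160 160
# would feed the student 160x160 crops while upsampling the teacher's input
# back to 224x224 on the GPU.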
652 |
653 | # setup checkpoint saver and eval metric tracking
654 | eval_metric = args.eval_metric
655 | best_metric = None
656 | best_epoch = None
657 | saver = None
658 | ema_saver = None
659 | output_dir = None
660 | if args.rank == 0:
661 | if args.experiment:
662 | exp_name = args.experiment
663 | else:
664 | exp_name = '-'.join([
665 | datetime.now().strftime("%Y%m%d-%H%M%S"),
666 | safe_model_name(args.model),
667 | str(data_config['input_size'][-1])
668 | ])
669 | output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
670 | decreasing = True if eval_metric == 'loss' else False
671 | saver_dir = os.path.join(output_dir, 'checkpoint')
672 | os.makedirs(saver_dir)
673 | saver = CheckpointSaver(
674 | model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
675 | checkpoint_dir=saver_dir, recovery_dir=saver_dir, decreasing=decreasing,
676 | max_history=args.checkpoint_hist)
677 | if model_ema is not None:
678 | ema_saver_dir = os.path.join(output_dir, 'ema_checkpoint')
679 | os.makedirs(ema_saver_dir)
680 | ema_saver = CheckpointSaver(
681 | model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
682 | checkpoint_dir=ema_saver_dir, recovery_dir=ema_saver_dir, decreasing=decreasing,
683 | max_history=args.checkpoint_hist)
684 | with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
685 | f.write(args_text)
686 |
687 | try:
688 | tp = TimePredictor(num_epochs - start_epoch)
689 | for epoch in range(start_epoch, num_epochs):
690 | if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
691 | loader_train.sampler.set_epoch(epoch)
692 |
693 | train_metrics = train_one_epoch(
694 | epoch, model, teacher, loader_train, optimizer, train_loss_fn, args,
695 | requires_feat=requires_feat, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
696 | amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn,
697 | teacher_resizer=teacher_resizer, student_resizer=student_resizer)
698 |
699 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
700 | if args.local_rank == 0:
701 | _logger.info("Distributing BatchNorm running means and vars")
702 | distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
703 |
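# Evaluation cadence: validate every args.eval_interval epochs, and every
# epoch once past eval_interval_end * epochs. E.g. with eval_interval=5,
# eval_interval_end=0.75, epochs=300: epochs 0, 5, ..., 225 are evaluated,
# then every epoch from 226 onward.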
704 | is_eval = epoch > int(args.eval_interval_end * args.epochs) or epoch % args.eval_interval == 0
705 | if is_eval:
706 | eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
707 |
708 | if saver is not None:
709 | # save proper checkpoint with eval metric
710 | save_metric = eval_metrics[eval_metric]
711 | best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
712 |
713 | if model_ema is not None and not args.model_ema_force_cpu:
714 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
715 | distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
716 | ema_eval_metrics = validate(
717 | model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast,
718 | log_suffix=' (EMA)')
719 | eval_metrics = ema_eval_metrics
720 |
721 | if ema_saver is not None:
722 | # save proper checkpoint with eval metric
723 | save_metric = eval_metrics[eval_metric]
724 | best_metric, best_epoch = ema_saver.save_checkpoint(epoch, metric=save_metric)
725 |
726 | if output_dir is not None:
727 | update_summary(
728 | epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
729 | write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
730 |
731 | metrics = eval_metrics[eval_metric]
732 | else:
733 | metrics = None
734 |
735 | if lr_scheduler is not None:
736 | # step LR for next epoch
737 | lr_scheduler.step(epoch + 1, metrics)
738 |
739 | tp.update()
740 | if args.rank == 0:
741 | print(f'Will finish at {tp.get_pred_text()}')
742 | print(f'Avg running time of latest {len(tp.time_list)} epochs: {np.mean(tp.time_list):.2f}s/ep.')
743 |
744 | except KeyboardInterrupt:
745 | pass
746 | if best_metric is not None:
747 | _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
748 |
749 | if args.rank == 0:
750 | os.system(f'mv train.log {output_dir}')
751 |
752 |
753 | def train_one_epoch(
754 | epoch, model, teacher, loader, optimizer, loss_fn, args, requires_feat=False,
755 | lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress, loss_scaler=None,
756 | model_ema=None, mixup_fn=None, teacher_resizer=None, student_resizer=None):
757 | if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
758 | if args.prefetcher and loader.mixup_enabled:
759 | loader.mixup_enabled = False
760 | elif mixup_fn is not None:
761 | mixup_fn.mixup_enabled = False
762 |
763 | second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
764 | batch_time_m = AverageMeter()
765 | data_time_m = AverageMeter()
766 | losses_m = AverageMeter()
767 | losses_ori_m = AverageMeter()
768 | losses_kd_m = AverageMeter()
769 |
770 | model.train()
771 |
772 | end = time.time()
773 | last_idx = len(loader) - 1
774 | num_updates = epoch * len(loader)
775 | for batch_idx, (input, target, index, contrastive_index) in enumerate(loader):
776 | last_batch = batch_idx == last_idx
777 | data_time_m.update(time.time() - end)
778 | if not args.prefetcher:
779 | input, target = input.cuda(), target.cuda()
780 | if mixup_fn is not None:
781 | input, target = mixup_fn(input, target)
782 | if args.channels_last:
783 | input = input.contiguous(memory_format=torch.channels_last)
784 | index = index.cuda()
785 | contrastive_index = contrastive_index.cuda()
786 |
787 | with amp_autocast():
788 |
789 | # --------------------------------- Modified --------------------------------
790 | if student_resizer is not None:
791 | student_input = student_resizer(input)
792 | else:
793 | student_input = input
794 | if teacher_resizer is not None:
795 | teacher_input = teacher_resizer(input)
796 | else:
797 | teacher_input = input
798 |
799 | if args.economic:
800 | torch.cuda.empty_cache()
801 |
802 | with torch.no_grad():
803 | output_t, feat_t = teacher(teacher_input, requires_feat=requires_feat)
804 |
805 | if args.economic:
806 | torch.cuda.empty_cache()
807 |
808 | output, feat = model(student_input, requires_feat=requires_feat)
809 |
810 | loss_ori = args.ori_loss_weight * loss_fn(output, target)
811 |
812 | try:
813 | kd_loss_fn = model.module.kd_loss_fn
814 | except AttributeError:
815 | kd_loss_fn = model.kd_loss_fn
816 | loss_kd = args.kd_loss_weight * kd_loss_fn(z_s=output, z_t=output_t,
817 | feature_student=process_feat(kd_loss_fn, feat),
818 | feature_teacher=process_feat(kd_loss_fn, feat_t),
819 | index=index, contrastive_index=contrastive_index)
820 | loss = loss_ori + loss_kd
821 | # ----------------------------------------------------------------------------
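# The objective is a weighted sum of the hard-label and distillation terms:
#   loss = args.ori_loss_weight * loss_fn(output, target)
#        + args.kd_loss_weight * kd_loss_fn(z_s=output, z_t=output_t, ...)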
822 |
823 | if not args.distributed:
824 | # --------------------------------- Modified --------------------------------
825 | losses_m.update(loss.item(), input.size(0))
826 | losses_ori_m.update(loss_ori.item(), input.size(0))
827 | losses_kd_m.update(loss_kd.item(), input.size(0))
828 | # ----------------------------------------------------------------------------
829 |
830 | optimizer.zero_grad()
831 | if loss_scaler is not None:
832 | loss_scaler(
833 | loss, optimizer,
834 | clip_grad=args.clip_grad, clip_mode=args.clip_mode,
835 | parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
836 | create_graph=second_order)
837 | else:
838 | loss.backward(create_graph=second_order)
839 | if args.clip_grad is not None:
840 | dispatch_clip_grad(
841 | model_parameters(model, exclude_head='agc' in args.clip_mode),
842 | value=args.clip_grad, mode=args.clip_mode)
843 | optimizer.step()
844 |
845 | if model_ema is not None:
846 | model_ema.update(model)
847 |
848 | torch.cuda.synchronize()
849 | num_updates += 1
850 | batch_time_m.update(time.time() - end)
851 | if last_batch or batch_idx % args.log_interval == 0:
852 | lrl = [param_group['lr'] for param_group in optimizer.param_groups]
853 | lr = sum(lrl) / len(lrl)
854 |
855 | # ----------------------------------- Modified ----------------------------------
856 | if args.distributed:
857 | reduced_loss = reduce_tensor(loss.data, args.world_size)
858 | reduced_loss_ori = reduce_tensor(loss_ori.data, args.world_size)
859 | reduced_loss_kd = reduce_tensor(loss_kd.data, args.world_size)
860 | losses_m.update(reduced_loss.item(), input.size(0))
861 | losses_ori_m.update(reduced_loss_ori.item(), input.size(0))
862 | losses_kd_m.update(reduced_loss_kd.item(), input.size(0))
863 |
864 | if args.local_rank == 0:
865 | _logger.info(
866 | 'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
867 | 'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
868 | 'Loss_ori: {loss_ori.val:#.4g} ({loss_ori.avg:#.3g}) '
869 | 'Loss_kd: {loss_kd.val:#.4g} ({loss_kd.avg:#.3g}) '
870 | 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
871 | '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
872 | 'LR: {lr:.3e} '
873 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
874 | epoch,
875 | batch_idx, len(loader),
876 | 100. * batch_idx / last_idx,
877 | loss=losses_m,
878 | loss_ori=losses_ori_m,
879 | loss_kd=losses_kd_m,
880 | batch_time=batch_time_m,
881 | rate=input.size(0) * args.world_size / batch_time_m.val,
882 | rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
883 | lr=lr,
884 | data_time=data_time_m))
885 | # --------------------------------------------------------------------------------
886 |
887 | if args.save_images and output_dir:
888 | torchvision.utils.save_image(
889 | input,
890 | os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
891 | padding=0,
892 | normalize=True)
893 |
894 | if saver is not None and args.recovery_interval and (
895 | last_batch or (batch_idx + 1) % args.recovery_interval == 0):
896 | saver.save_recovery(epoch, batch_idx=batch_idx)
897 |
898 | if lr_scheduler is not None:
899 | lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
900 |
901 | end = time.time()
902 | # end for
903 |
904 | if hasattr(optimizer, 'sync_lookahead'):
905 | optimizer.sync_lookahead()
906 |
907 | return OrderedDict([('loss', losses_m.avg)])
908 |
909 |
910 | def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
911 | batch_time_m = AverageMeter()
912 | losses_m = AverageMeter()
913 | top1_m = AverageMeter()
914 | top5_m = AverageMeter()
915 |
916 | model.eval()
917 |
918 | end = time.time()
919 | last_idx = len(loader) - 1
920 | with torch.no_grad():
921 | for batch_idx, (input, target) in enumerate(loader):
922 | last_batch = batch_idx == last_idx
923 | if not args.prefetcher:
924 | input = input.cuda()
925 | target = target.cuda()
926 | if args.channels_last:
927 | input = input.contiguous(memory_format=torch.channels_last)
928 |
929 | with amp_autocast():
930 | output = model(input)
931 | if isinstance(output, (tuple, list)):
932 | output = output[0]
933 |
934 | # augmentation reduction
935 | reduce_factor = args.tta
936 | if reduce_factor > 1:
937 | output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
938 | target = target[0:target.size(0):reduce_factor]
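# TTA reduction sketch (assuming the loader repeats each image `tta` times
# consecutively): for output of shape (B, C), unfold(0, N, N) groups the N
# augmented views into (B/N, C, N), .mean(dim=2) averages their logits, and
# the target is subsampled back to one label per original image.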
939 |
940 | loss = loss_fn(output, target)
941 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
942 |
943 | if args.distributed:
944 | reduced_loss = reduce_tensor(loss.data, args.world_size)
945 | acc1 = reduce_tensor(acc1, args.world_size)
946 | acc5 = reduce_tensor(acc5, args.world_size)
947 | else:
948 | reduced_loss = loss.data
949 |
950 | torch.cuda.synchronize()
951 |
952 | losses_m.update(reduced_loss.item(), input.size(0))
953 | top1_m.update(acc1.item(), output.size(0))
954 | top5_m.update(acc5.item(), output.size(0))
955 |
956 | batch_time_m.update(time.time() - end)
957 | end = time.time()
958 | if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
959 | log_name = 'Test' + log_suffix
960 | _logger.info(
961 | '{0}: [{1:>4d}/{2}] '
962 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
963 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
964 | 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
965 | 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
966 | log_name, batch_idx, last_idx, batch_time=batch_time_m,
967 | loss=losses_m, top1=top1_m, top5=top5_m))
968 |
969 | metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
970 |
971 | return metrics
972 |
973 |
974 | if __name__ == '__main__':
975 | main()
976 |
--------------------------------------------------------------------------------
/train-fd.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """ ImageNet Training Script
3 |
4 | This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
5 | training results with some of the latest networks and training techniques. It favours canonical PyTorch
6 | and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
7 | and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.
8 |
9 | This script was started from an early version of the PyTorch ImageNet example
10 | (https://github.com/pytorch/examples/tree/master/imagenet)
11 |
12 | NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
13 | (https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
14 |
15 | Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
16 |
17 | Modifications by Zhiwei Hao (haozhw@bit.edu.cn) and Jianyuan Guo (jianyuan_guo@outlook.com)
18 | """
19 | import argparse
20 | import logging
21 | import logging.handlers
22 | import os
23 | import time
24 | from collections import OrderedDict
25 | from contextlib import suppress
26 | from copy import deepcopy
27 | from datetime import datetime
28 |
29 | import numpy as np
30 | import torch
31 | import torch.nn as nn
32 | import yaml
33 | from torch.nn.parallel import DistributedDataParallel as NativeDDP
34 | from torchvision import transforms
35 |
36 | from losses import Correlation, ReviewKD, RKD
37 | from register import config, register_forward
38 | from timm.data import AugMixDataset, create_dataset, create_loader, FastCollateMixup, resolve_data_config
39 | from timm.loss import *
40 | from timm.models import convert_splitbn_model, create_model, model_parameters, safe_model_name
41 | from timm.optim import create_optimizer_v2
42 | from timm.scheduler import create_scheduler
43 | from timm.utils import *
44 | from timm.utils import ApexScaler, NativeScaler
45 | from utils import CheckpointSaverWithLogger, MultiSmoothingMixup, process_feat, setup_default_logging, TimePredictor
46 | import models
47 |
48 | try:
49 | from apex import amp
50 | from apex.parallel import DistributedDataParallel as ApexDDP
51 | from apex.parallel import convert_syncbn_model
52 |
53 | has_apex = True
54 | except ImportError:
55 | has_apex = False
56 |
57 | has_native_amp = False
58 | try:
59 | if getattr(torch.cuda.amp, 'autocast') is not None:
60 | has_native_amp = True
61 | except AttributeError:
62 | pass
63 |
64 | try:
65 | import wandb
66 |
67 | has_wandb = True
68 | except ImportError:
69 | has_wandb = False
70 |
71 | torch.backends.cudnn.benchmark = True
72 | _logger = logging.getLogger('train')
73 |
74 | # The first arg parser parses out only the --config argument, this argument is used to
75 | # load a yaml file containing key-values that override the defaults for the main parser below
76 | config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
77 | parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
78 | help='YAML config file specifying default arguments')
79 |
80 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
81 |
82 | # ---------------------------------------------------------------------------------------
83 | # KD parameters: CC, review, and RKD
84 | parser.add_argument('--teacher', default='beitv2_base_patch16_224', type=str)
85 | parser.add_argument('--teacher-pretrained', default=None, type=str) # teacher checkpoint path
86 | parser.add_argument('--teacher-resize', default=None, type=int)
87 | parser.add_argument('--student-resize', default=None, type=int)
88 | parser.add_argument('--input-size', default=None, nargs=3, type=int)
89 |
90 | # use torch.cuda.empty_cache() to save GPU memory
91 | parser.add_argument('--economic', action='store_true')
92 |
93 | # eval every 'eval-interval' epochs until epochs * eval_interval_end, then every epoch
94 | parser.add_argument('--eval-interval', type=int, default=1)
95 | parser.add_argument('--eval-interval-end', type=float, default=0.75)
96 |
97 | # one teacher forward pass, multiple students trained simultaneously to save time
98 | # e.g. --kd_loss kd dist --kd_loss_weight 1 2
99 | # yields 2 settings, (kd_loss=kd, kd_loss_weight=1) and (kd_loss=dist, kd_loss_weight=2),
100 | # not 2 * 2 = 4 in total
101 | _nargs_attrs = ['kd_loss', 'kd_loss_weight', 'ori_loss_weight', 'model', 'opt', 'clip_grad', 'lr',
102 | 'weight_decay', 'drop', 'drop_path', 'model_ema_decay', 'smoothing', 'bce_loss']
103 |
104 | parser.add_argument('--kd-loss', default=['kd'], type=str, nargs='+')
105 | parser.add_argument('--kd-loss-weight', default=[1.], type=float, nargs='+')
106 | parser.add_argument('--ori-loss-weight', default=[1.], type=float, nargs='+')
107 | parser.add_argument('--model', default=['resnet50'], type=str, nargs='+')
108 | parser.add_argument('--opt', default=['sgd'], type=str, nargs='+')
109 | parser.add_argument('--clip-grad', type=float, default=[None], nargs='+')
110 | parser.add_argument('--lr', default=[0.05], type=float, nargs='+')
111 | parser.add_argument('--weight-decay', type=float, default=[2e-5], nargs='+')
112 | parser.add_argument('--drop', type=float, default=[0.0], nargs='+')
113 | parser.add_argument('--drop-path', type=float, default=[None], nargs='+')
114 | parser.add_argument('--model-ema-decay', type=float, default=[0.9998], nargs='+')
115 | parser.add_argument('--smoothing', type=float, default=[0.1], nargs='+')
116 | parser.add_argument('--bce-loss', type=int, default=[0], nargs='+') # 0: disable; others: enable
117 | # ---------------------------------------------------------------------------------------
118 |
119 | # Dataset parameters
120 | parser.add_argument('data_dir', metavar='DIR',
121 | help='path to dataset')
122 | parser.add_argument('--dataset', '-d', metavar='NAME', default='',
123 | help='dataset type (default: ImageFolder/ImageTar if empty)')
124 | parser.add_argument('--train-split', metavar='NAME', default='train',
125 | help='dataset train split (default: train)')
126 | parser.add_argument('--val-split', metavar='NAME', default='validation',
127 | help='dataset validation split (default: validation)')
128 | parser.add_argument('--dataset-download', action='store_true', default=False,
129 | help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
130 | parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
131 | help='path to class to idx mapping file (default: "")')
132 |
133 | # Model parameters
134 | parser.add_argument('--pretrained', action='store_true', default=False,
135 | help='Start with pretrained version of specified network (if avail)')
136 | parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
137 | help='Initialize model from this checkpoint (default: none)')
138 | parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
139 | help='number of label classes (Model default if None)')
140 | parser.add_argument('--gp', default=None, type=str, metavar='POOL',
141 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
142 | parser.add_argument('--img-size', type=int, default=None, metavar='N',
143 | help='Image patch size (default: None => model default)')
144 | parser.add_argument('--crop-pct', default=None, type=float,
145 | metavar='N', help='Input image center crop percent (for validation only)')
146 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
147 | help='Override mean pixel value of dataset')
148 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
149 | help='Override std deviation of dataset')
150 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
151 | help='Image resize interpolation type (overrides model)')
152 | parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
153 | help='Input batch size for training (default: 128)')
154 | parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
155 | help='Validation batch size override (default: None)')
156 | parser.add_argument('--channels-last', action='store_true', default=False,
157 | help='Use channels_last memory layout')
158 | parser.add_argument('--torchscript', dest='torchscript', action='store_true',
159 | help='torch.jit.script the full model')
160 | parser.add_argument('--fuser', default='', type=str,
161 | help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
162 |
163 | # Optimizer parameters
164 | parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
165 | help='Optimizer Epsilon (default: None, use opt default)')
166 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
167 | help='Optimizer Betas (default: None, use opt default)')
168 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
169 | help='Optimizer momentum (default: 0.9)')
170 | parser.add_argument('--clip-mode', type=str, default='norm',
171 | help='Gradient clipping mode. One of ("norm", "value", "agc")')
172 | parser.add_argument('--layer-decay', type=float, default=None,
173 | help='layer-wise learning rate decay (default: None)')
174 |
175 | # Learning rate schedule parameters
176 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
177 | help='LR scheduler (default: "cosine")')
178 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
179 | help='learning rate noise on/off epoch percentages')
180 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
181 | help='learning rate noise limit percent (default: 0.67)')
182 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
183 | help='learning rate noise std-dev (default: 1.0)')
184 | parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
185 | help='learning rate cycle len multiplier (default: 1.0)')
186 | parser.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
187 | help='amount to decay each learning rate cycle (default: 0.5)')
188 | parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
189 | help='learning rate cycle limit, cycles enabled if > 1')
190 | parser.add_argument('--lr-k-decay', type=float, default=1.0,
191 | help='learning rate k-decay for cosine/poly (default: 1.0)')
192 | parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
193 | help='warmup learning rate (default: 0.0001)')
194 | parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
195 | help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)')
196 | parser.add_argument('--epochs', type=int, default=300, metavar='N',
197 | help='number of epochs to train (default: 300)')
198 | parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
199 | help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
200 | parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
201 | help='manual epoch number (useful on restarts)')
202 | parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
203 | help='epochs to warmup LR, if scheduler supports')
204 | parser.add_argument('--decay-epochs', type=float, default=100, metavar='N',
205 | help='epoch interval to decay LR')
206 | parser.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',
207 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
208 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
209 | help='patience epochs for Plateau LR scheduler (default: 10)')
210 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
211 | help='LR decay rate (default: 0.1)')
212 |
213 | # Augmentation & regularization parameters
214 | parser.add_argument('--no-aug', action='store_true', default=False,
215 | help='Disable all training augmentation, override other train aug args')
216 | parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
217 | help='Random resize scale (default: 0.08 1.0)')
218 | parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
219 | help='Random resize aspect ratio (default: 0.75 1.33)')
220 | parser.add_argument('--hflip', type=float, default=0.5,
221 | help='Horizontal flip training aug probability')
222 | parser.add_argument('--vflip', type=float, default=0.,
223 | help='Vertical flip training aug probability')
224 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
225 | help='Color jitter factor (default: 0.4)')
226 | parser.add_argument('--aa', type=str, default=None, metavar='NAME',
227 | help='Use AutoAugment policy. "v0" or "original". (default: None)')
228 | parser.add_argument('--aug-repeats', type=float, default=0,
229 | help='Number of augmentation repetitions (distributed training only) (default: 0)')
230 | parser.add_argument('--aug-splits', type=int, default=0,
231 | help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
232 | parser.add_argument('--jsd-loss', action='store_true', default=False,
233 | help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
234 | parser.add_argument('--bce-target-thresh', type=float, default=None,
235 | help='Threshold for binarizing softened BCE targets (default: None, disabled)')
236 | parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
237 | help='Random erase prob (default: 0.)')
238 | parser.add_argument('--remode', type=str, default='pixel',
239 | help='Random erase mode (default: "pixel")')
240 | parser.add_argument('--recount', type=int, default=1,
241 | help='Random erase count (default: 1)')
242 | parser.add_argument('--resplit', action='store_true', default=False,
243 | help='Do not random erase first (clean) augmentation split')
244 | parser.add_argument('--mixup', type=float, default=0.0,
245 | help='mixup alpha, mixup enabled if > 0. (default: 0.)')
246 | parser.add_argument('--cutmix', type=float, default=0.0,
247 | help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
248 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
249 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
250 | parser.add_argument('--mixup-prob', type=float, default=1.0,
251 | help='Probability of performing mixup or cutmix when either/both is enabled')
252 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
253 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
254 | parser.add_argument('--mixup-mode', type=str, default='batch',
255 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
256 | parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
257 | help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
258 | parser.add_argument('--train-interpolation', type=str, default='random',
259 | help='Training interpolation (random, bilinear, bicubic default: "random")')
260 | parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
261 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
262 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
263 | help='Drop block rate (default: None)')
264 |
265 | # Batch norm parameters (only works with gen_efficientnet based models currently)
266 | parser.add_argument('--bn-momentum', type=float, default=None,
267 | help='BatchNorm momentum override (if not None)')
268 | parser.add_argument('--bn-eps', type=float, default=None,
269 | help='BatchNorm epsilon override (if not None)')
270 | parser.add_argument('--sync-bn', action='store_true',
271 | help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
272 | parser.add_argument('--dist-bn', type=str, default='reduce',
273 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
274 | parser.add_argument('--split-bn', action='store_true',
275 | help='Enable separate BN layers per augmentation split.')
276 |
277 | # Model Exponential Moving Average
278 | parser.add_argument('--model-ema', action='store_true', default=False,
279 | help='Enable tracking moving average of model weights')
280 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
281 | help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
282 |
283 | # Misc
284 | parser.add_argument('--seed', type=int, default=42, metavar='S',
285 | help='random seed (default: 42)')
286 | parser.add_argument('--worker-seeding', type=str, default='all',
287 | help='worker seed mode (default: all)')
288 | parser.add_argument('--log-interval', type=int, default=200, metavar='N',
289 | help='how many batches to wait before logging training status')
290 | parser.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
291 | help='number of checkpoints to keep (default: 10)')
292 | parser.add_argument('-j', '--workers', type=int, default=8, metavar='N',
293 | help='how many training processes to use (default: 8)')
294 | parser.add_argument('--amp', action='store_true', default=False,
295 | help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
296 | parser.add_argument('--apex-amp', action='store_true', default=False,
297 | help='Use NVIDIA Apex AMP mixed precision')
298 | parser.add_argument('--native-amp', action='store_true', default=False,
299 | help='Use Native Torch AMP mixed precision')
300 | parser.add_argument('--no-ddp-bb', action='store_true', default=False,
301 | help='Force broadcast buffers for native DDP to off.')
302 | parser.add_argument('--pin-mem', action='store_true', default=False,
303 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
304 | parser.add_argument('--no-prefetcher', action='store_true', default=False,
305 | help='disable fast prefetcher')
306 | parser.add_argument('--output', default='', type=str, metavar='PATH',
307 | help='path to output folder (default: none, current dir)')
308 | parser.add_argument('--experiment', default='', type=str, metavar='NAME',
309 | help='name of train experiment, name of sub-folder for output')
310 | parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
311 | help='Best metric (default: "top1"')
312 | parser.add_argument('--tta', type=int, default=0, metavar='N',
313 | help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
314 | parser.add_argument("--local_rank", default=0, type=int)
315 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
316 | parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
317 | help='use the multi-epochs-loader to save time at the beginning of every epoch')
318 | parser.add_argument('--log-wandb', action='store_true', default=False,
319 | help='log training and validation metrics to wandb')
320 |
321 |
322 | def _parse_args():
323 | # Do we have a config file to parse?
324 | args_config, remaining = config_parser.parse_known_args()
325 | if args_config.config:
326 | with open(args_config.config, 'r') as f:
327 | cfg = yaml.safe_load(f)
328 | parser.set_defaults(**cfg)
329 |
330 | # The main arg parser parses the rest of the args, the usual
331 | # defaults will have been overridden if config file specified.
332 | args = parser.parse_args(remaining)
333 |
334 | # Cache the args as a text string to save them in the output dir later
335 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
336 | return args, args_text
337 |
338 |
339 | def main():
340 | _logger.parent = None
341 | setup_default_logging(_logger, log_path='train.log')
342 | args, args_text = _parse_args()
343 |
344 | # parse nargs
345 | setting_num = 1
346 | setting_dicts = [dict()]
347 | for attr in _nargs_attrs:
348 | value = getattr(args, attr)
349 | if isinstance(value, list):
350 | if len(value) == 1:
351 | for d in setting_dicts:
352 | d[attr] = value[0]
353 | else:
354 | if setting_num == 1:
355 | setting_num = len(value)
356 | setting_dicts = [deepcopy(setting_dicts[0]) for _ in range(setting_num)]
357 | else: # ensure that args with multiple values have the same length
358 | assert setting_num == len(value)
359 | for i, v in enumerate(value):
360 | setting_dicts[i][attr] = v
361 | else:
362 | for d in setting_dicts:
363 | d[attr] = value
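# Worked example of the expansion above: `--model resnet50 convnext_tiny
# --lr 0.05` (all other nargs left scalar) yields two settings differing only
# in 'model'; the single lr is broadcast to both. Any argument given more than
# one value must match the others in length (asserted above).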
364 |
365 | # merge duplicating settings, only for non-nested dict
366 | setting_dicts = [dict(t) for t in sorted(list({tuple(sorted(d.items())) for d in setting_dicts}))]
367 |
368 | # merge settings with only different 'model_ema_decay'
369 | model_ema_decay_list = []
370 | assist_dict = dict()
371 | for i, d in enumerate(deepcopy(setting_dicts)):
372 | model_ema_decay_list.append(d['model_ema_decay'])
373 | del d['model_ema_decay']
374 | h = hash(tuple(sorted(d.items())))
375 | if h not in assist_dict:
376 | assist_dict[h] = [i]
377 | else:
378 | assist_dict[h].append(i)
379 |
380 | merged_setting_dict_list = []
381 | for v in assist_dict.values():
382 | d = setting_dicts[v[0]]
383 | d['model_ema_decay'] = tuple([model_ema_decay_list[index] for index in v])
384 | merged_setting_dict_list.append(d)
385 |
386 | # update
387 | setting_dicts = merged_setting_dict_list
388 | setting_num = len(setting_dicts)
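# Example of the merge above: two settings identical except for
# model_ema_decay 0.9998 vs 0.9999 collapse into one setting with
# model_ema_decay=(0.9998, 0.9999): one student is trained, and one EMA
# shadow per decay is tracked for it (see the ModelEmaV2 setup below).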
389 |
390 | logger_list = [_logger]
391 | if setting_num > 1:
392 | if args.local_rank == 0:
393 | _logger.info(f'there are {setting_num} settings in total. creating individual logger for each setting')
394 |
395 | logger_list = []
396 | for i in range(setting_num):
397 | logger = logging.getLogger(f'train-setting-{i}')
398 | logger.parent = None
399 | setup_default_logging(logger, log_path=f'train-setting-{i}.log')
400 | logger_list.append(logger)
401 | if args.local_rank == 0:
402 | logger.info(f'settings of index {i}: ' +
403 | ', '.join(f"{k}={setting_dicts[i][k]}" for k in setting_dicts[i]))
404 | else:
405 | _logger.info('settings: ' + ', '.join(f"{k}={setting_dicts[0][k]}" for k in setting_dicts[0]))
406 |
407 | if args.log_wandb:
408 | if has_wandb:
409 | wandb.init(project=args.experiment, config=args)
410 | else:
411 | _logger.warning("You've requested to log metrics to wandb but package not found. "
412 | "Metrics not being logged to wandb, try `pip install wandb`")
413 |
414 | args.prefetcher = not args.no_prefetcher
415 | args.distributed = False
416 | if 'WORLD_SIZE' in os.environ:
417 | args.distributed = int(os.environ['WORLD_SIZE']) > 1
418 | args.device = 'cuda:0'
419 | args.world_size = 1
420 | args.rank = 0 # global rank
421 | if args.distributed:
422 | assert 'RANK' in os.environ and 'WORLD_SIZE' in os.environ
423 | args.rank = int(os.environ["RANK"])
424 | args.world_size = int(os.environ['WORLD_SIZE'])
425 | args.device = int(os.environ['LOCAL_RANK'])
426 |
427 | torch.cuda.set_device(args.device)
428 |
429 | torch.distributed.init_process_group(backend='nccl', init_method=args.dist_url,
430 | world_size=args.world_size, rank=args.rank)
431 | torch.distributed.barrier()
432 |
433 | _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
434 | % (args.rank, args.world_size))
435 |
436 | else:
437 | _logger.info('Training with a single process on 1 GPU.')
438 | assert args.rank >= 0
439 |
440 | # resolve AMP arguments based on PyTorch / Apex availability
441 | use_amp = None
442 | if args.amp:
443 | # `--amp` chooses native amp before apex (APEX ver not actively maintained)
444 | if has_native_amp:
445 | args.native_amp = True
446 | elif has_apex:
447 | args.apex_amp = True
448 | if args.apex_amp and has_apex:
449 | use_amp = 'apex'
450 | elif args.native_amp and has_native_amp:
451 | use_amp = 'native'
452 | elif args.apex_amp or args.native_amp:
453 | _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
454 | "Install NVIDA apex or upgrade to PyTorch 1.6")
455 |
456 | random_seed(args.seed, args.rank)
457 |
458 | if args.fuser:
459 | set_jit_fuser(args.fuser)
460 |
461 | teacher = create_model(
462 | args.teacher,
463 | checkpoint_path=args.teacher_pretrained,
464 | num_classes=args.num_classes)
465 | register_forward(teacher)
466 | teacher = teacher.cuda()
467 | teacher.eval()
468 |
469 | model_list = []
470 | for i in range(setting_num):
471 | model = create_model(
472 | setting_dicts[i]['model'],
473 | pretrained=args.pretrained,
474 | num_classes=args.num_classes,
475 | drop_rate=setting_dicts[i]['drop'],
476 | drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path
477 | drop_path_rate=setting_dicts[i]['drop_path'],
478 | drop_block_rate=args.drop_block,
479 | global_pool=args.gp,
480 | bn_momentum=args.bn_momentum,
481 | bn_eps=args.bn_eps,
482 | scriptable=args.torchscript,
483 | checkpoint_path=args.initial_checkpoint)
484 | register_forward(model)
485 | model_list.append(model)
486 |
487 | if args.local_rank == 0:
488 | logger_list[i].info(f'Model {safe_model_name(setting_dicts[i]["model"])} created, '
489 | f'param count:{sum([m.numel() for m in model.parameters()])}')
490 |
491 | # all settings must share the same data config to ensure consistent teacher predictions
492 | if args.num_classes is None:
493 | assert hasattr(model_list[0],
494 | 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
495 | args.num_classes = model_list[0].num_classes # FIXME handle model default vs config num_classes more elegantly
496 |
497 | data_config = resolve_data_config(vars(args), model=model_list[0], verbose=args.local_rank == 0)
498 |
499 | # setup augmentation batch splits for contrastive loss or split bn
500 | num_aug_splits = 0
501 | if args.aug_splits > 0:
502 | assert args.aug_splits > 1, 'A split of 1 makes no sense'
503 | num_aug_splits = args.aug_splits
504 |
505 | for i, model in enumerate(model_list):
506 | # enable split bn (separate bn stats per batch-portion)
507 | if args.split_bn:
508 | assert num_aug_splits > 1 or args.resplit
509 | model = convert_splitbn_model(model, max(num_aug_splits, 2))
510 |
511 | # create the KD loss for this setting and attach it to the student (moved to GPU below)
512 | if setting_dicts[i]['kd_loss'] == 'correlation':
513 | kd_loss_fn = Correlation(feat_s_channel=config.get_pre_logit_dim(setting_dicts[i]['model']),
514 | feat_t_channel=config.get_pre_logit_dim(args.teacher))
515 | elif setting_dicts[i]['kd_loss'] == 'review':
516 | feat_index_s = config.get_used_feature_index(setting_dicts[i]['model'])
517 | feat_index_t = config.get_used_feature_index(args.teacher)
518 | in_channels = [config.get_feature_size_by_index(setting_dicts[i]['model'], j)[0] for j in feat_index_s]
519 | out_channels = [config.get_feature_size_by_index(args.teacher, j)[0] for j in feat_index_t]
520 | in_channels = in_channels + [config.get_pre_logit_dim(setting_dicts[i]['model'])]
521 | out_channels = out_channels + [config.get_pre_logit_dim(args.teacher)]
522 |
523 | kd_loss_fn = ReviewKD(feat_index_s, feat_index_t, in_channels, out_channels)
524 | elif setting_dicts[i]['kd_loss'] == 'rkd':
525 | kd_loss_fn = RKD()
526 | else:
527 | raise NotImplementedError(f"unknown kd loss {setting_dicts[i]['kd_loss']}")
528 |
529 | model.kd_loss_fn = kd_loss_fn
530 | model.cuda()
531 | if args.channels_last:
532 | model = model.to(memory_format=torch.channels_last)
533 |
534 | # setup synchronized BatchNorm for distributed training
535 | if args.distributed and args.sync_bn:
536 | assert not args.split_bn
537 | if has_apex and use_amp == 'apex':
538 | # Apex SyncBN preferred unless native amp is activated
539 | model = convert_syncbn_model(model)
540 | else:
541 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
542 | if args.local_rank == 0:
543 | _logger.info(
544 | 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
545 | 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
546 |
547 | if args.torchscript:
548 | assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
549 | assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
550 | model = torch.jit.script(model)
551 |
552 | optimizer_list = []
553 | # if setting_num == 1:
554 | # optimizer_list.append(create_optimizer_v2(model_list[0], **optimizer_kwargs(cfg=args)))
555 | # else:
556 | for i in range(setting_num):
557 | optimizer_list.append(create_optimizer_v2(model_list[i],
558 | opt=setting_dicts[i]['opt'],
559 | lr=setting_dicts[i]['lr'],
560 | weight_decay=setting_dicts[i]['weight_decay'],
561 | momentum=args.momentum))
562 |
563 | # setup automatic mixed-precision (AMP) loss scaling and op casting
564 | amp_autocast = suppress # do nothing
565 | loss_scaler = None
566 | if use_amp == 'apex':
567 | new_model_list = []
568 | new_optimizer_list = []
569 | for model, optimizer in zip(model_list, optimizer_list):
570 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
571 | new_model_list.append(model)
572 | new_optimizer_list.append(optimizer)
573 | model_list = new_model_list
574 | optimizer_list = new_optimizer_list
575 | loss_scaler = ApexScaler()
576 | if args.local_rank == 0:
577 | _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
578 | elif use_amp == 'native':
579 | amp_autocast = torch.cuda.amp.autocast
580 | loss_scaler = NativeScaler()
581 | if args.local_rank == 0:
582 | _logger.info('Using native Torch AMP. Training in mixed precision.')
583 | else:
584 | if args.local_rank == 0:
585 | _logger.info('AMP not enabled. Training in float32.')
586 |
587 | # setup exponential moving average of model weights, SWA could be used here too
588 | model_ema_list = [(None,) for _ in range(setting_num)]
589 | if args.model_ema:
590 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
591 | for i in range(setting_num):
592 | model_emas = []
593 | for decay in setting_dicts[i]['model_ema_decay']:
594 | model_ema = ModelEmaV2(model_list[i], decay=decay,
595 | device='cpu' if args.model_ema_force_cpu else None)
596 | model_emas.append((model_ema, decay))
597 | model_ema_list[i] = tuple(model_emas)
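# NOTE: timm's ModelEmaV2 maintains a shadow copy updated after each optimizer
# step as
#   shadow = decay * shadow + (1 - decay) * param
# so a larger decay gives a smoother, slower-moving average; tracking several
# decays per student costs one extra model copy each.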
598 |
599 | # setup distributed training
600 | if args.distributed:
601 | new_model_list = []
602 | for i in range(setting_num):
603 | if has_apex and use_amp == 'apex':
604 | # Apex DDP preferred unless native amp is activated
605 | if args.local_rank == 0:
606 | _logger.info("Using NVIDIA APEX DistributedDataParallel.")
607 | model = ApexDDP(model_list[i], delay_allreduce=True)
608 | else:
609 | if args.local_rank == 0:
610 | _logger.info("Using native Torch DistributedDataParallel.")
611 | model = NativeDDP(model_list[i], device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
612 | # NOTE: EMA model does not need to be wrapped by DDP
613 | new_model_list.append(model)
614 | model_list = new_model_list
615 |
616 | start_epoch = 0
617 | if args.start_epoch is not None:
618 | # a specified start_epoch will always override the resume epoch
619 | start_epoch = args.start_epoch
620 |
621 | # setup learning rate schedule and starting epoch
622 | lr_scheduler_list = []
623 | for i in range(setting_num):
624 | lr_scheduler, num_epochs = create_scheduler(args, optimizer_list[i])
625 | if lr_scheduler is not None and start_epoch > 0:
626 | lr_scheduler.step(start_epoch)
627 |
628 | lr_scheduler_list.append(lr_scheduler)
629 |
630 | if args.local_rank == 0:
631 | logger_list[i].info('Scheduled epochs: {}'.format(num_epochs))
632 |
633 | # create the train and eval datasets
634 | dataset_train = create_dataset(
635 | args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
636 | class_map=args.class_map,
637 | download=args.dataset_download,
638 | batch_size=args.batch_size,
639 | repeats=args.epoch_repeats)
640 | dataset_eval = create_dataset(
641 | args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
642 | class_map=args.class_map,
643 | download=args.dataset_download,
644 | batch_size=args.batch_size)
645 |
646 | # setup mixup / cutmix
647 | # smoothing is implemented in data loader when prefetcher=True,
648 | # so prefetcher should be turned off when multiple smoothing settings are used
649 | smoothing_setting_num = len(set([d['smoothing'] for d in setting_dicts]))
650 |
651 | collate_fn = None
652 | mixup_fn = None
653 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
654 | if mixup_active:
655 | if smoothing_setting_num == 1 and args.prefetcher:
656 | mixup_args = dict(
657 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
658 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
659 | label_smoothing=setting_dicts[0]['smoothing'], num_classes=args.num_classes)
660 | assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
661 | collate_fn = FastCollateMixup(**mixup_args)
662 | else:
663 | smoothings = tuple([d['smoothing'] for d in setting_dicts])
664 | mixup_args = dict(
665 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
666 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
667 | smoothings=smoothings, num_classes=args.num_classes)
668 | mixup_fn = MultiSmoothingMixup(**mixup_args)
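# NOTE: MultiSmoothingMixup (from utils) is assumed to act like timm's Mixup
# but to build one soft-target tensor per smoothing value, so each setting's
# loss consumes targets smoothed with its own factor from a single mixed batch.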
669 |
670 | # wrap dataset in AugMix helper
671 | if num_aug_splits > 1:
672 | dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
673 |
674 | # create data loaders w/ augmentation pipeline
675 | train_interpolation = args.train_interpolation
676 | if args.no_aug or not train_interpolation:
677 | train_interpolation = data_config['interpolation']
678 | loader_train = create_loader(
679 | dataset_train,
680 | input_size=data_config['input_size'],
681 | batch_size=args.batch_size,
682 | is_training=True,
683 | use_prefetcher=args.prefetcher,
684 | no_aug=args.no_aug,
685 | re_prob=args.reprob,
686 | re_mode=args.remode,
687 | re_count=args.recount,
688 | re_split=args.resplit,
689 | scale=args.scale,
690 | ratio=args.ratio,
691 | hflip=args.hflip,
692 | vflip=args.vflip,
693 | color_jitter=args.color_jitter,
694 | auto_augment=args.aa,
695 | num_aug_repeats=args.aug_repeats,
696 | num_aug_splits=num_aug_splits,
697 | interpolation=train_interpolation,
698 | mean=data_config['mean'],
699 | std=data_config['std'],
700 | num_workers=args.workers,
701 | distributed=args.distributed,
702 | collate_fn=collate_fn,
703 | pin_memory=args.pin_mem,
704 | use_multi_epochs_loader=args.use_multi_epochs_loader,
705 | worker_seeding=args.worker_seeding,
706 | )
707 |
708 | loader_eval = create_loader(
709 | dataset_eval,
710 | input_size=(3, 224, 224),
711 | batch_size=args.validation_batch_size or args.batch_size,
712 | is_training=False,
713 | use_prefetcher=args.prefetcher,
714 | interpolation=data_config['interpolation'],
715 | mean=data_config['mean'],
716 | std=data_config['std'],
717 | num_workers=args.workers,
718 | distributed=args.distributed,
719 | crop_pct=data_config['crop_pct'],
720 | pin_memory=args.pin_mem,
721 | )
722 |
723 | # setup loss function
724 | validate_loss_fn = nn.CrossEntropyLoss().cuda()
725 |
726 | train_loss_fn_list = []
727 | for i in range(setting_num):
728 | if args.jsd_loss:
729 | assert num_aug_splits > 1 # JSD only valid with aug splits set
730 | train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=setting_dicts[i]['smoothing'])
731 | elif mixup_active:
732 | # smoothing is handled with mixup target transform which outputs sparse, soft targets
733 | if setting_dicts[i]['bce_loss']:
734 | train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh)
735 | else:
736 | train_loss_fn = SoftTargetCrossEntropy()
737 | elif setting_dicts[i]['smoothing']:
738 | if setting_dicts[i]['bce_loss']:
739 | train_loss_fn = BinaryCrossEntropy(smoothing=setting_dicts[i]['smoothing'],
740 | target_threshold=args.bce_target_thresh)
741 | else:
742 | train_loss_fn = LabelSmoothingCrossEntropy(smoothing=setting_dicts[i]['smoothing'])
743 | else:
744 | train_loss_fn = nn.CrossEntropyLoss()
745 | train_loss_fn = train_loss_fn.cuda()
746 | train_loss_fn_list.append(train_loss_fn)
747 |
748 | teacher_resizer = student_resizer = None
749 | if args.teacher_resize is not None:
750 | teacher_resizer = transforms.Resize(args.teacher_resize).cuda()
751 | if args.student_resize is not None:
752 | student_resizer = transforms.Resize(args.student_resize).cuda()
753 |
754 | # setup checkpoint saver and eval metric tracking
755 | eval_metric = args.eval_metric
756 | saver_list = [None for _ in range(setting_num)]
757 | ema_saver_list = [tuple([None for _ in range(len(model_ema_list[i]))]) for i in range(setting_num)]
758 | for ema, saver in zip(model_ema_list, ema_saver_list):
759 | assert len(ema) == len(saver)
760 | output_dir_list = [None for _ in range(setting_num)]
761 | for i in range(setting_num):
762 | if args.rank == 0:
763 | if args.experiment:
764 | exp_name = args.experiment + f'-setting-{i}'
765 | else:
766 | exp_name = '-'.join([
767 | datetime.now().strftime("%Y%m%d-%H%M%S"),
768 | safe_model_name(setting_dicts[i]['model']),
769 | str(data_config['input_size'][-1]),
770 |                     f'setting-{i}'
771 | ])
772 | output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
773 | output_dir_list[i] = output_dir
774 |             decreasing = eval_metric == 'loss'
775 | saver_dir = os.path.join(output_dir, 'checkpoint')
776 |             os.makedirs(saver_dir, exist_ok=True)
777 | saver = CheckpointSaverWithLogger(
778 | logger=logger_list[i], model=model_list[i], optimizer=optimizer_list[i], args=args,
779 | amp_scaler=loss_scaler, checkpoint_dir=saver_dir, recovery_dir=saver_dir,
780 | decreasing=decreasing, max_history=args.checkpoint_hist)
781 | saver_list[i] = saver
782 | if model_ema_list[i][0] is not None:
783 | ema_savers = []
784 | for ema, decay in model_ema_list[i]:
785 | ema_saver_dir = os.path.join(output_dir, f'ema{decay}_checkpoint')
786 |                     os.makedirs(ema_saver_dir, exist_ok=True)
787 | ema_saver = CheckpointSaverWithLogger(
788 | logger=logger_list[i], model=model_list[i], optimizer=optimizer_list[i], args=args,
789 | model_ema=ema, amp_scaler=loss_scaler, checkpoint_dir=ema_saver_dir,
790 | recovery_dir=ema_saver_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
791 | ema_savers.append(ema_saver)
792 | ema_saver_list[i] = tuple(ema_savers)
793 | with open(os.path.join(get_outdir(args.output if args.output else './output/train'),
794 | 'args.yaml'), 'w') as f:
795 | f.write(args_text)
796 |
797 | best_metric_list = [None for _ in range(setting_num)]
798 | best_epoch_list = [None for _ in range(setting_num)]
799 | try:
800 | tp = TimePredictor(num_epochs - start_epoch)
801 | for epoch in range(start_epoch, num_epochs):
802 | if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
803 | loader_train.sampler.set_epoch(epoch)
804 |
805 | ori_loss_weight_list = tuple([d['ori_loss_weight'] for d in setting_dicts])
806 | kd_loss_weight_list = tuple([d['kd_loss_weight'] for d in setting_dicts])
807 | clip_grad_list = tuple([d['clip_grad'] for d in setting_dicts])
808 | train_metrics_list = train_one_epoch(
809 | epoch, model_list, teacher, loader_train, optimizer_list, train_loss_fn_list, args,
810 | lr_scheduler_list=lr_scheduler_list, amp_autocast=amp_autocast,
811 | loss_scaler=loss_scaler, model_ema_list=model_ema_list, mixup_fn=mixup_fn,
812 | teacher_resizer=teacher_resizer, student_resizer=student_resizer,
813 | ori_loss_weight_list=ori_loss_weight_list, kd_loss_weight_list=kd_loss_weight_list,
814 | clip_grad_list=clip_grad_list, logger_list=logger_list)
815 |
816 | for i in range(setting_num):
817 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
818 | if args.local_rank == 0:
819 | logger_list[i].info("Distributing BatchNorm running means and vars")
820 | distribute_bn(model_list[i], args.world_size, args.dist_bn == 'reduce')
821 |
822 | is_eval = epoch > int(args.eval_interval_end * args.epochs) or epoch % args.eval_interval == 0
823 | if is_eval:
824 | eval_metrics = validate(model_list[i], loader_eval, validate_loss_fn, args,
825 | logger=logger_list[i], amp_autocast=amp_autocast)
826 |
827 | if saver_list[i] is not None:
828 | # save proper checkpoint with eval metric
829 | save_metric = eval_metrics[eval_metric]
830 | best_metric, best_epoch = saver_list[i].save_checkpoint(epoch, metric=save_metric)
831 | best_metric_list[i] = best_metric
832 | best_epoch_list[i] = best_epoch
833 |
834 | if model_ema_list[i][0] is not None and not args.model_ema_force_cpu:
835 |                     for (ema, decay), saver in zip(model_ema_list[i], ema_saver_list[i]):
836 |
837 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
838 | distribute_bn(ema, args.world_size, args.dist_bn == 'reduce')
839 |
840 | ema_eval_metrics = validate(ema.module, loader_eval, validate_loss_fn, args,
841 | logger=logger_list[i], amp_autocast=amp_autocast,
842 | log_suffix=f' (EMA {decay:.5f})')
843 |
844 | if saver is not None:
845 | # save proper checkpoint with eval metric
846 | save_metric = ema_eval_metrics[eval_metric]
847 | saver.save_checkpoint(epoch, metric=save_metric)
848 |
849 | if output_dir_list[i] is not None:
850 | update_summary(
851 | epoch, train_metrics_list[i], eval_metrics, os.path.join(output_dir_list[i], 'summary.csv'),
852 | write_header=best_metric_list[i] is None, log_wandb=args.log_wandb and has_wandb)
853 |
854 | metrics = eval_metrics[eval_metric]
855 | else:
856 | metrics = None
857 |
858 | if lr_scheduler_list[i] is not None:
859 | # step LR for next epoch
860 | lr_scheduler_list[i].step(epoch + 1, metrics)
861 |
862 | tp.update()
863 | if args.rank == 0:
864 | print(f'Will finish at {tp.get_pred_text()}')
865 | print(f'Avg running time of latest {len(tp.time_list)} epochs: {np.mean(tp.time_list):.2f}s/ep.')
866 |
867 | except KeyboardInterrupt:
868 | pass
869 |
870 | for i in range(setting_num):
871 | if best_metric_list[i] is not None:
872 | logger_list[i].info('*** Best metric: {0} (epoch {1})'.format(best_metric_list[i], best_epoch_list[i]))
873 |
874 |     if args.rank == 0:
875 |         # the shared train.log is always moved; per-setting logs exist only with multiple settings
876 |         os.system(f'mv train.log {args.output}')
877 |         if setting_num > 1:
878 |             for i in range(setting_num):
879 |                 os.system(f'mv train-setting-{i}.log {args.output}')
880 |
881 |
882 |
883 | def train_one_epoch(
884 | epoch, model_list, teacher, loader, optimizer_list, loss_fn_list, args,
885 | lr_scheduler_list=(None,), amp_autocast=suppress, loss_scaler=None, model_ema_list=(None,), mixup_fn=None,
886 | teacher_resizer=None, student_resizer=None, ori_loss_weight_list=(None,),
887 | kd_loss_weight_list=(None,), clip_grad_list=(None,), logger_list=(None,)):
888 | setting_num = len(model_list)
889 |
890 | if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
891 | if args.prefetcher and loader.mixup_enabled:
892 | loader.mixup_enabled = False
893 | elif mixup_fn is not None:
894 | mixup_fn.mixup_enabled = False
895 |
896 | second_order_list = [hasattr(o, 'is_second_order') and o.is_second_order for o in optimizer_list]
897 | batch_time_m = AverageMeter()
898 | data_time_m = AverageMeter()
899 | losses_m_list = [AverageMeter() for _ in range(setting_num)]
900 | losses_ori_m_list = [AverageMeter() for _ in range(setting_num)]
901 | losses_kd_m_list = [AverageMeter() for _ in range(setting_num)]
902 |
903 | for model in model_list:
904 | model.train()
905 |
906 | end = time.time()
907 | last_idx = len(loader) - 1
908 | num_updates = epoch * len(loader)
909 | for batch_idx, (input, target) in enumerate(loader):
910 | last_batch = batch_idx == last_idx
911 | data_time_m.update(time.time() - end)
912 | if not args.prefetcher:
913 | input, target = input.cuda(), target.cuda()
914 | if mixup_fn is not None:
915 | input, targets = mixup_fn(input, target)
916 | else:
917 | targets = [target for _ in range(setting_num)]
918 | else:
919 | targets = [target for _ in range(setting_num)]
920 |
921 | if args.channels_last:
922 | input = input.contiguous(memory_format=torch.channels_last)
923 |
924 | # teacher forward
925 | with amp_autocast():
926 | if teacher_resizer is not None:
927 | teacher_input = teacher_resizer(input)
928 | else:
929 | teacher_input = input
930 |
931 | if args.economic:
932 | torch.cuda.empty_cache()
933 | with torch.no_grad():
934 | output_t, feat_t = teacher(teacher_input, requires_feat=True)
935 |
936 | if args.economic:
937 | torch.cuda.empty_cache()
938 |
939 | # student forward
940 | for i in range(setting_num):
941 |
942 | if setting_num != 1: # more than 1 model
943 | torch.cuda.empty_cache()
944 |
945 | with amp_autocast():
946 | if student_resizer is not None:
947 | student_input = student_resizer(input)
948 | else:
949 | student_input = input
950 |
951 | output, feat = model_list[i](student_input, requires_feat=True)
952 |
953 | loss_ori = ori_loss_weight_list[i] * loss_fn_list[i](output, targets[i])
954 |
955 | try:
956 | kd_loss_fn = model_list[i].module.kd_loss_fn
957 | except AttributeError:
958 | kd_loss_fn = model_list[i].kd_loss_fn
959 |
960 | loss_kd = kd_loss_weight_list[i] * kd_loss_fn(z_s=output, z_t=output_t.detach(),
961 | target=targets[i],
962 | epoch=epoch,
963 | feature_student=process_feat(kd_loss_fn, feat),
964 | feature_teacher=process_feat(kd_loss_fn, feat_t))
965 | loss = loss_ori + loss_kd
966 |
967 | if not args.distributed:
968 | losses_m_list[i].update(loss.item(), input.size(0))
969 | losses_ori_m_list[i].update(loss_ori.item(), input.size(0))
970 | losses_kd_m_list[i].update(loss_kd.item(), input.size(0))
971 |
972 | optimizer_list[i].zero_grad()
973 | if loss_scaler is not None:
974 | loss_scaler(
975 | loss, optimizer_list[i],
976 | clip_grad=clip_grad_list[i], clip_mode=args.clip_mode,
977 | parameters=model_parameters(model_list[i], exclude_head='agc' in args.clip_mode),
978 | create_graph=second_order_list[i])
979 | else:
980 | loss.backward(create_graph=second_order_list[i])
981 | if clip_grad_list[i] is not None:
982 | dispatch_clip_grad(
983 | model_parameters(model_list[i], exclude_head='agc' in args.clip_mode),
984 | value=clip_grad_list[i], mode=args.clip_mode)
985 | optimizer_list[i].step()
986 |
987 | if model_ema_list[i][0] is not None:
988 | for ema, _ in model_ema_list[i]:
989 | ema.update(model_list[i])
990 |
991 | torch.cuda.synchronize()
992 | batch_time_m.update(time.time() - end)
993 | if last_batch or batch_idx % args.log_interval == 0:
994 | lrl = [param_group['lr'] for param_group in optimizer_list[i].param_groups]
995 | lr = sum(lrl) / len(lrl)
996 |
997 | if args.distributed:
998 | reduced_loss = reduce_tensor(loss.data, args.world_size)
999 | reduced_loss_ori = reduce_tensor(loss_ori.data, args.world_size)
1000 | reduced_loss_kd = reduce_tensor(loss_kd.data, args.world_size)
1001 | losses_m_list[i].update(reduced_loss.item(), input.size(0))
1002 | losses_ori_m_list[i].update(reduced_loss_ori.item(), input.size(0))
1003 | losses_kd_m_list[i].update(reduced_loss_kd.item(), input.size(0))
1004 |
1005 | if args.local_rank == 0:
1006 | logger_list[i].info(
1007 | 'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
1008 | 'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
1009 | 'Loss_ori: {loss_ori.val:#.4g} ({loss_ori.avg:#.3g}) '
1010 | 'Loss_kd: {loss_kd.val:#.4g} ({loss_kd.avg:#.3g}) '
1011 | 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
1012 | '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
1013 | 'LR: {lr:.3e} '
1014 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
1015 | epoch,
1016 | batch_idx, len(loader),
1017 | 100. * batch_idx / last_idx,
1018 | loss=losses_m_list[i],
1019 | loss_ori=losses_ori_m_list[i],
1020 | loss_kd=losses_kd_m_list[i],
1021 | batch_time=batch_time_m,
1022 | rate=input.size(0) * args.world_size / batch_time_m.val,
1023 | rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
1024 | lr=lr,
1025 | data_time=data_time_m))
1026 |
1027 | if setting_num != 1: # more than 1 model
1028 | torch.cuda.empty_cache()
1029 |
1030 | num_updates += 1
1031 | for i in range(setting_num):
1032 | if lr_scheduler_list[i] is not None:
1033 | lr_scheduler_list[i].step_update(num_updates=num_updates, metric=losses_m_list[i].avg)
1034 |
1035 | end = time.time()
1036 | # end for
1037 |
1038 | for i in range(setting_num):
1039 | if hasattr(optimizer_list[i], 'sync_lookahead'):
1040 | optimizer_list[i].sync_lookahead()
1041 |
1042 | return [OrderedDict([('loss', losses_m.avg)]) for losses_m in losses_m_list]
1043 |
1044 |
1045 | def validate(model, loader, loss_fn, args, logger, amp_autocast=suppress, log_suffix=''):
1046 | batch_time_m = AverageMeter()
1047 | losses_m = AverageMeter()
1048 | top1_m = AverageMeter()
1049 | top5_m = AverageMeter()
1050 |
1051 | model.eval()
1052 |
1053 | end = time.time()
1054 | last_idx = len(loader) - 1
1055 | with torch.no_grad():
1056 | for batch_idx, (input, target) in enumerate(loader):
1057 | last_batch = batch_idx == last_idx
1058 | if not args.prefetcher:
1059 | input = input.cuda()
1060 | target = target.cuda()
1061 | if args.channels_last:
1062 | input = input.contiguous(memory_format=torch.channels_last)
1063 |
1064 | with amp_autocast():
1065 | output = model(input)
1066 | if isinstance(output, (tuple, list)):
1067 | output = output[0]
1068 |
1069 |                 # augmentation reduction: average logits over each group of args.tta augmented copies
1070 | reduce_factor = args.tta
1071 | if reduce_factor > 1:
1072 | output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
1073 | target = target[0:target.size(0):reduce_factor]
1074 |
1075 | loss = loss_fn(output, target)
1076 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
1077 |
1078 | if args.distributed:
1079 | reduced_loss = reduce_tensor(loss.data, args.world_size)
1080 | acc1 = reduce_tensor(acc1, args.world_size)
1081 | acc5 = reduce_tensor(acc5, args.world_size)
1082 | else:
1083 | reduced_loss = loss.data
1084 |
1085 | torch.cuda.synchronize()
1086 |
1087 | losses_m.update(reduced_loss.item(), input.size(0))
1088 | top1_m.update(acc1.item(), output.size(0))
1089 | top5_m.update(acc5.item(), output.size(0))
1090 |
1091 | batch_time_m.update(time.time() - end)
1092 | end = time.time()
1093 | if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
1094 | log_name = 'Test' + log_suffix
1095 | logger.info(
1096 | '{0}: [{1:>4d}/{2}] '
1097 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
1098 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
1099 | 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
1100 | 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
1101 | log_name, batch_idx, last_idx, batch_time=batch_time_m,
1102 | loss=losses_m, top1=top1_m, top5=top5_m))
1103 |
1104 | metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
1105 |
1106 | return metrics
1107 |
1108 |
1109 | if __name__ == '__main__':
1110 | main()
1111 |
--------------------------------------------------------------------------------
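Editor's note: the core of train_one_epoch above is a weighted sum of the original
classification loss and the distillation loss. Below is a minimal, self-contained sketch of
that objective, using a plain temperature-scaled KL term as a stand-in for the losses/
package; the weights and temperature are hypothetical defaults, not the repo's settings.

    import torch
    import torch.nn.functional as F

    def kd_objective(z_s, z_t, target, ori_w=1.0, kd_w=1.0, T=1.0):
        # hard-label loss, weighted like ori_loss_weight_list[i]
        loss_ori = ori_w * F.cross_entropy(z_s, target)
        # vanilla KD: KL(student || teacher) on temperature-softened logits,
        # weighted like kd_loss_weight_list[i]; teacher logits are detached as in the script
        loss_kd = kd_w * (T ** 2) * F.kl_div(
            F.log_softmax(z_s / T, dim=1),
            F.softmax(z_t.detach() / T, dim=1),
            reduction='batchmean')
        return loss_ori + loss_kd

    z_s = torch.randn(4, 1000, requires_grad=True)  # student logits
    z_t = torch.randn(4, 1000)                      # teacher logits
    target = torch.randint(0, 1000, (4,))
    kd_objective(z_s, z_t, target).backward()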
/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import time
4 | from datetime import datetime
5 |
6 | import numpy as np
7 |
8 | from timm.data import ImageDataset
9 | from timm.data.mixup import Mixup, mixup_target
10 | from timm.utils import CheckpointSaver, unwrap_model
11 |
12 |
13 | class ImageNetInstanceSample(ImageDataset):
14 |     """Folder dataset that returns (img, target, index), plus an array of
15 |     contrastive sample indices when is_sample is enabled (used by CRD)."""
16 |
17 | def __init__(self, root, name, class_map, load_bytes, is_sample=False, k=4096, **kwargs):
18 | super().__init__(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
19 | self.k = k
20 | self.is_sample = is_sample
21 | if self.is_sample:
22 | print('preparing contrastive data...')
23 | num_classes = 1000
24 | num_samples = len(self.parser)
25 | label = np.zeros(num_samples, dtype=np.int32)
26 | for i in range(num_samples):
27 | _, target = self.parser[i]
28 | label[i] = target
29 |
30 | self.cls_positive = [[] for _ in range(num_classes)]
31 | for i in range(num_samples):
32 | self.cls_positive[label[i]].append(i)
33 |
34 | self.cls_negative = [[] for _ in range(num_classes)]
35 | for i in range(num_classes):
36 | for j in range(num_classes):
37 | if j == i:
38 | continue
39 | self.cls_negative[i].extend(self.cls_positive[j])
40 |
41 | self.cls_positive = [np.asarray(self.cls_positive[i], dtype=np.int32) for i in range(num_classes)]
42 | self.cls_negative = [np.asarray(self.cls_negative[i], dtype=np.int32) for i in range(num_classes)]
43 | print('done.')
44 |
45 | def __getitem__(self, index):
46 | """
47 | Args:
48 | index (int): Index
49 |         Returns:
50 |             tuple: (image, target, index), or (image, target, index, sample_idx) when is_sample is True.
51 | """
52 | img, target = super().__getitem__(index)
53 |
54 | if self.is_sample:
55 | # sample contrastive examples
56 | pos_idx = index
57 | neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=True)
58 | sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
59 | return img, target, index, sample_idx
60 | else:
61 | return img, target, index
62 |
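# Editor's sketch (hypothetical paths): with is_sample=True, the dataset also yields the
# index arrays that CRD's contrastive memory expects:
#   ds = ImageNetInstanceSample('/path/to/imagenet/train', name='', class_map='',
#                               load_bytes=False, is_sample=True, k=16384)
#   img, target, index, sample_idx = ds[0]  # sample_idx = [positive index] + k negatives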
63 |
64 | class MultiSmoothingMixup(Mixup):
65 | def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
66 | mode='batch', correct_lam=True, smoothings=(0.1,), num_classes=1000):
67 | super(MultiSmoothingMixup, self).__init__(mixup_alpha, cutmix_alpha, cutmix_minmax, prob, switch_prob,
68 | mode, correct_lam, 0, num_classes)
69 | self.smoothings = smoothings
70 |
71 | def __call__(self, x, target):
72 | assert len(x) % 2 == 0, 'Batch size should be even when using this'
73 | if self.mode == 'elem':
74 | lam = self._mix_elem(x)
75 | elif self.mode == 'pair':
76 | lam = self._mix_pair(x)
77 | else:
78 | lam = self._mix_batch(x)
79 | targets = []
80 | for smoothing in self.smoothings:
81 | targets.append(mixup_target(target, self.num_classes, lam, smoothing, x.device))
82 | return x, targets
83 |
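# Editor's sketch: MultiSmoothingMixup applies a single mixup/cutmix draw to the batch but
# returns one soft-target tensor per smoothing value, so settings trained side by side can
# differ only in label smoothing:
#   mixup = MultiSmoothingMixup(mixup_alpha=0.8, smoothings=(0.0, 0.1), num_classes=1000)
#   x, (t0, t1) = mixup(images, labels)  # t0 built with smoothing 0.0, t1 with 0.1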
84 |
85 | class CheckpointSaverWithLogger(CheckpointSaver):
86 | def __init__(
87 | self,
88 | logger,
89 | model,
90 | optimizer,
91 | args=None,
92 | model_ema=None,
93 | amp_scaler=None,
94 | checkpoint_prefix='checkpoint',
95 | recovery_prefix='recovery',
96 | checkpoint_dir='',
97 | recovery_dir='',
98 | decreasing=False,
99 | max_history=10,
100 | unwrap_fn=unwrap_model):
101 | super(CheckpointSaverWithLogger, self).__init__(model, optimizer, args, model_ema, amp_scaler,
102 | checkpoint_prefix, recovery_prefix, checkpoint_dir,
103 | recovery_dir, decreasing, max_history, unwrap_fn)
104 | self.logger = logger
105 |
106 | def save_checkpoint(self, epoch, metric=None):
107 | assert epoch >= 0
108 | tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
109 | last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
110 | self._save(tmp_save_path, epoch, metric)
111 | if os.path.exists(last_save_path):
112 | os.unlink(last_save_path) # required for Windows support.
113 | os.rename(tmp_save_path, last_save_path)
114 | worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
115 | if (len(self.checkpoint_files) < self.max_history
116 | or metric is None or self.cmp(metric, worst_file[1])):
117 | if len(self.checkpoint_files) >= self.max_history:
118 | self._cleanup_checkpoints(1)
119 | filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
120 | save_path = os.path.join(self.checkpoint_dir, filename)
121 | os.link(last_save_path, save_path)
122 | self.checkpoint_files.append((save_path, metric))
123 | self.checkpoint_files = sorted(
124 | self.checkpoint_files, key=lambda x: x[1],
125 | reverse=not self.decreasing) # sort in descending order if a lower metric is not better
126 |
127 | checkpoints_str = "Current checkpoints:\n"
128 | for c in self.checkpoint_files:
129 | checkpoints_str += ' {}\n'.format(c)
130 | self.logger.info(checkpoints_str)
131 |
132 | if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
133 | self.best_epoch = epoch
134 | self.best_metric = metric
135 | best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension)
136 | if os.path.exists(best_save_path):
137 | os.unlink(best_save_path)
138 | os.link(last_save_path, best_save_path)
139 |
140 | return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
141 |
142 | def _cleanup_checkpoints(self, trim=0):
143 | trim = min(len(self.checkpoint_files), trim)
144 | delete_index = self.max_history - trim
145 | if delete_index < 0 or len(self.checkpoint_files) <= delete_index:
146 | return
147 | to_delete = self.checkpoint_files[delete_index:]
148 | for d in to_delete:
149 | try:
150 | self.logger.debug("Cleaning checkpoint: {}".format(d))
151 | os.remove(d[0])
152 | except Exception as e:
153 | self.logger.error("Exception '{}' while deleting checkpoint".format(e))
154 | self.checkpoint_files = self.checkpoint_files[:delete_index]
155 |
156 | def save_recovery(self, epoch, batch_idx=0):
157 | assert epoch >= 0
158 | filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
159 | save_path = os.path.join(self.recovery_dir, filename)
160 | self._save(save_path, epoch)
161 | if os.path.exists(self.last_recovery_file):
162 | try:
163 | self.logger.debug("Cleaning recovery: {}".format(self.last_recovery_file))
164 | os.remove(self.last_recovery_file)
165 | except Exception as e:
166 | self.logger.error("Exception '{}' while removing {}".format(e, self.last_recovery_file))
167 | self.last_recovery_file = self.curr_recovery_file
168 | self.curr_recovery_file = save_path
169 |
170 |
171 | def setup_default_logging(logger, default_level=logging.INFO, log_path=''):
172 | console_handler = logging.StreamHandler()
173 | console_formatter = logging.Formatter("%(name)15s: %(message)s")
174 | console_handler.setFormatter(console_formatter)
175 | # console_handler.setFormatter(FormatterNoInfo())
176 | logger.addHandler(console_handler)
177 | logger.setLevel(default_level)
178 | if log_path:
179 | file_handler = logging.FileHandler(log_path)
180 | file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s")
181 | file_handler.setFormatter(file_formatter)
182 | logger.addHandler(file_handler)
183 |
184 |
185 | class TimePredictor:
186 | def __init__(self, steps, most_recent=30, drop_first=True):
187 | self.init_time = time.time()
188 | self.steps = steps
189 | self.most_recent = most_recent
190 | self.drop_first = drop_first # drop iter 0
191 |
192 | self.time_list = []
193 | self.temp_time = self.init_time
194 |
195 | def update(self):
196 | time_interval = time.time() - self.temp_time
197 | self.time_list.append(time_interval)
198 |
199 | if self.drop_first and len(self.time_list) > 1:
200 | self.time_list = self.time_list[1:]
201 | self.drop_first = False
202 |
203 | self.time_list = self.time_list[-self.most_recent:]
204 | self.temp_time = time.time()
205 |
206 | def get_pred_text(self):
207 | single_step_time = np.mean(self.time_list)
208 | end_timestamp = self.init_time + single_step_time * self.steps
209 | return datetime.fromtimestamp(end_timestamp).strftime('%Y-%m-%d %H:%M:%S')
210 |
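# Editor's usage sketch: TimePredictor estimates the wall-clock finish time from the mean
# duration of the most recent epochs (the typically slower first epoch is dropped):
#   tp = TimePredictor(steps=num_epochs)
#   for _ in range(num_epochs):
#       train_one_epoch(...)
#       tp.update()
#       print('ETA:', tp.get_pred_text())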
211 |
212 | def process_feat(distiller, source_feat):
213 | if getattr(distiller, 'pre_act_feat', False):
214 | feat = source_feat[0]
215 | else:
216 | feat = source_feat[1]
217 | return feat
218 |
--------------------------------------------------------------------------------
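Editor's note: the logging and checkpoint helpers above are wired together in the training
scripts roughly as follows. A minimal sketch with a toy model; the directory name, log file
name, and metric value are hypothetical.

    import logging
    import os

    import torch.nn as nn
    from torch.optim import SGD

    from utils import CheckpointSaverWithLogger, setup_default_logging

    logger = logging.getLogger('train-setting-0')
    setup_default_logging(logger, log_path='train-setting-0.log')

    model = nn.Linear(10, 2)  # stand-in for a real student model
    os.makedirs('./ckpt', exist_ok=True)
    saver = CheckpointSaverWithLogger(
        logger=logger, model=model, optimizer=SGD(model.parameters(), lr=0.1),
        checkpoint_dir='./ckpt', recovery_dir='./ckpt',
        decreasing=False, max_history=3)  # top-1 accuracy: higher is better
    best_metric, best_epoch = saver.save_checkpoint(epoch=0, metric=75.2)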
/validate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """ ImageNet Validation Script
3 |
4 | This is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained
5 | models or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes
6 | canonical PyTorch, standard Python style, and good performance. Repurpose as you see fit.
7 |
8 | Hacked together by Ross Wightman (https://github.com/rwightman)
9 | """
10 | import argparse
11 | import os
12 | import csv
13 | import glob
14 | import json
15 | import sys
16 | import time
17 | import logging
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.parallel
21 | from collections import OrderedDict
22 | from contextlib import suppress
23 |
24 | if os.path.exists('./timm'):
25 | sys.path.insert(0, './')
26 |
27 | from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
28 | from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
29 | from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
30 | import models
31 |
32 | has_apex = False
33 | try:
34 | from apex import amp
35 | has_apex = True
36 | except ImportError:
37 | pass
38 |
39 | has_native_amp = False
40 | try:
41 | if getattr(torch.cuda.amp, 'autocast') is not None:
42 | has_native_amp = True
43 | except AttributeError:
44 | pass
45 |
46 | torch.backends.cudnn.benchmark = True
47 | _logger = logging.getLogger('validate')
48 |
49 |
50 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
51 | parser.add_argument('data', metavar='DIR',
52 | help='path to dataset')
53 | parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
54 | help='model architecture (default: dpn92)')
55 | parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
56 | help='path to latest checkpoint (default: none)')
57 | parser.add_argument('--use-ema', dest='use_ema', action='store_true',
58 | help='use ema version of weights if present')
59 |
60 |
61 | parser.add_argument('--dataset', '-d', metavar='NAME', default='',
62 | help='dataset type (default: ImageFolder/ImageTar if empty)')
63 | parser.add_argument('--split', metavar='NAME', default='validation',
64 | help='dataset split (default: validation)')
65 | parser.add_argument('--dataset-download', action='store_true', default=False,
66 | help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
67 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
68 |                     help='number of data loading workers (default: 4)')
69 | parser.add_argument('-b', '--batch-size', default=256, type=int,
70 | metavar='N', help='mini-batch size (default: 256)')
71 | parser.add_argument('--img-size', default=None, type=int,
72 | metavar='N', help='Input image dimension, uses model default if empty')
73 | parser.add_argument('--input-size', default=None, nargs=3, type=int,
74 | metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
75 | parser.add_argument('--crop-pct', default=None, type=float,
76 | metavar='N', help='Input image center crop pct')
77 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
78 | help='Override mean pixel value of dataset')
79 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
80 |                     help='Override std deviation of dataset')
81 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
82 | help='Image resize interpolation type (overrides model)')
83 | parser.add_argument('--num-classes', type=int, default=1000,
84 | help='Number classes in dataset')
85 | parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
86 | help='path to class to idx mapping file (default: "")')
87 | parser.add_argument('--gp', default=None, type=str, metavar='POOL',
88 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
89 | parser.add_argument('--log-freq', default=50, type=int,
90 |                     metavar='N', help='batch logging frequency (default: 50)')
91 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
92 | help='use pre-trained model')
93 | parser.add_argument('--num-gpu', type=int, default=1,
94 | help='Number of GPUS to use')
95 | parser.add_argument('--test-pool', dest='test_pool', action='store_true',
96 | help='enable test time pool')
97 | parser.add_argument('--no-prefetcher', action='store_true', default=False,
98 | help='disable fast prefetcher')
99 | parser.add_argument('--pin-mem', action='store_true', default=False,
100 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
101 | parser.add_argument('--channels-last', action='store_true', default=False,
102 | help='Use channels_last memory layout')
103 | parser.add_argument('--amp', action='store_true', default=False,
104 |                     help='Use AMP mixed precision. Prefers native Torch AMP, falls back to Apex.')
105 | parser.add_argument('--apex-amp', action='store_true', default=False,
106 | help='Use NVIDIA Apex AMP mixed precision')
107 | parser.add_argument('--native-amp', action='store_true', default=False,
108 | help='Use Native Torch AMP mixed precision')
109 | parser.add_argument('--tf-preprocessing', action='store_true', default=False,
110 |                     help='Use Tensorflow preprocessing pipeline (requires CPU TensorFlow installed)')
111 | parser.add_argument('--torchscript', dest='torchscript', action='store_true',
112 |                     help='convert model to torchscript for inference')
113 | parser.add_argument('--fuser', default='', type=str,
114 | help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
115 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
116 | help='Output csv file for validation results (summary)')
117 | parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
118 | help='Real labels JSON file for imagenet evaluation')
119 | parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
120 | help='Valid label indices txt file for validation of partial label space')
121 |
122 |
123 | def validate(args):
124 | # might as well try to validate something
125 | args.pretrained = args.pretrained or not args.checkpoint
126 | args.prefetcher = not args.no_prefetcher
127 | amp_autocast = suppress # do nothing
128 | if args.amp:
129 | if has_native_amp:
130 | args.native_amp = True
131 | elif has_apex:
132 | args.apex_amp = True
133 | else:
134 |             _logger.warning("Neither APEX nor native Torch AMP is available.")
135 | assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
136 | if args.native_amp:
137 | amp_autocast = torch.cuda.amp.autocast
138 | _logger.info('Validating in mixed precision with native PyTorch AMP.')
139 | elif args.apex_amp:
140 | _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
141 | else:
142 | _logger.info('Validating in float32. AMP not enabled.')
143 |
144 | # create model
145 | model = create_model(
146 | args.model,
147 | pretrained=args.pretrained,
148 | num_classes=args.num_classes,
149 | in_chans=3,
150 | global_pool=args.gp,
151 | scriptable=args.torchscript)
152 | if args.num_classes is None:
153 | assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
154 | args.num_classes = model.num_classes
155 |
156 | if args.checkpoint:
157 | incompatible_keys = load_checkpoint(model, args.checkpoint, args.use_ema, strict=False)
158 |         _logger.info(f'incompatible_keys: {incompatible_keys}')
159 |
160 | param_count = sum([m.numel() for m in model.parameters()])
161 | _logger.info('Model %s created, param count: %d' % (args.model, param_count))
162 |
163 | data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
164 | test_time_pool = False
165 | if args.test_pool:
166 | model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)
167 |
168 | if args.torchscript:
169 | torch.jit.optimized_execution(True)
170 | model = torch.jit.script(model)
171 |
172 | model = model.cuda()
173 | if args.apex_amp:
174 | model = amp.initialize(model, opt_level='O1')
175 |
176 | if args.channels_last:
177 | model = model.to(memory_format=torch.channels_last)
178 |
179 | if args.num_gpu > 1:
180 | model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
181 |
182 | criterion = nn.CrossEntropyLoss().cuda()
183 |
184 | dataset = create_dataset(
185 | root=args.data, name=args.dataset, split=args.split,
186 | download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map)
187 |
188 | if args.valid_labels:
189 | with open(args.valid_labels, 'r') as f:
190 | valid_labels = {int(line.rstrip()) for line in f}
191 | valid_labels = [i in valid_labels for i in range(args.num_classes)]
192 | else:
193 | valid_labels = None
194 |
195 | if args.real_labels:
196 | real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
197 | else:
198 | real_labels = None
199 |
200 | crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
201 | loader = create_loader(
202 | dataset,
203 | input_size=data_config['input_size'],
204 | batch_size=args.batch_size,
205 | use_prefetcher=args.prefetcher,
206 | interpolation=data_config['interpolation'],
207 | mean=data_config['mean'],
208 | std=data_config['std'],
209 | num_workers=args.workers,
210 | crop_pct=crop_pct,
211 | pin_memory=args.pin_mem,
212 | tf_preprocessing=args.tf_preprocessing)
213 |
214 | batch_time = AverageMeter()
215 | losses = AverageMeter()
216 | top1 = AverageMeter()
217 | top5 = AverageMeter()
218 |
219 | model.eval()
220 | with torch.no_grad():
221 | # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
222 | input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
223 | if args.channels_last:
224 | input = input.contiguous(memory_format=torch.channels_last)
225 | with amp_autocast():
226 | model(input)
227 |
228 | end = time.time()
229 | for batch_idx, (input, target) in enumerate(loader):
230 | if args.no_prefetcher:
231 | target = target.cuda()
232 | input = input.cuda()
233 | if args.channels_last:
234 | input = input.contiguous(memory_format=torch.channels_last)
235 |
236 | # compute output
237 | with amp_autocast():
238 | output = model(input)
239 |
240 | if valid_labels is not None:
241 | output = output[:, valid_labels]
242 | loss = criterion(output, target)
243 |
244 | if real_labels is not None:
245 | real_labels.add_result(output)
246 |
247 | # measure accuracy and record loss
248 | acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
249 | losses.update(loss.item(), input.size(0))
250 | top1.update(acc1.item(), input.size(0))
251 | top5.update(acc5.item(), input.size(0))
252 |
253 | # measure elapsed time
254 | batch_time.update(time.time() - end)
255 | end = time.time()
256 |
257 | if batch_idx % args.log_freq == 0:
258 | _logger.info(
259 | 'Test: [{0:>4d}/{1}] '
260 | 'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
261 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
262 | 'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
263 | 'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
264 | batch_idx, len(loader), batch_time=batch_time,
265 | rate_avg=input.size(0) / batch_time.avg,
266 | loss=losses, top1=top1, top5=top5))
267 |
268 | if real_labels is not None:
269 | # real labels mode replaces topk values at the end
270 | top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
271 | else:
272 | top1a, top5a = top1.avg, top5.avg
273 | results = OrderedDict(
274 | model=args.model,
275 | top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
276 | top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
277 | param_count=round(param_count / 1e6, 2),
278 | img_size=data_config['input_size'][-1],
279 |         crop_pct=crop_pct,
280 | interpolation=data_config['interpolation'])
281 |
282 | _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
283 | results['top1'], results['top1_err'], results['top5'], results['top5_err']))
284 |
285 | return results
286 |
287 |
288 | def _try_run(args, initial_batch_size):
289 | batch_size = initial_batch_size
290 | results = OrderedDict()
291 | error_str = 'Unknown'
292 | while batch_size >= 1:
293 | args.batch_size = batch_size
294 | torch.cuda.empty_cache()
295 | try:
296 | results = validate(args)
297 | return results
298 | except RuntimeError as e:
299 | error_str = str(e)
300 | if 'channels_last' in error_str:
301 | break
302 |             batch_size = batch_size // 2
303 |             _logger.warning(f'"{error_str}" while running validation. Reducing batch size to {batch_size} for retry.')
304 | results['error'] = error_str
305 | _logger.error(f'{args.model} failed to validate ({error_str}).')
306 | return results
307 |
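# Editor's note: _try_run implements a simple OOM backoff -- on a RuntimeError it halves
# the batch size and retries until validation succeeds or the batch size reaches zero.
# The same pattern in isolation (run_eval is a hypothetical stand-in for validate):
#   bs = 256
#   while bs >= 1:
#       try:
#           return run_eval(bs)
#       except RuntimeError:
#           torch.cuda.empty_cache()
#           bs //= 2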
308 |
309 | def main():
310 | setup_default_logging()
311 | args = parser.parse_args()
312 | model_cfgs = []
313 | model_names = []
314 | if os.path.isdir(args.checkpoint):
315 | # validate all checkpoints in a path with same model
316 | checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
317 | checkpoints += glob.glob(args.checkpoint + '/*.pth')
318 | model_names = list_models(args.model)
319 | model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
320 | else:
321 | if args.model == 'all':
322 | # validate all models in a list of names with pretrained checkpoints
323 | args.pretrained = True
324 | model_names = list_models(pretrained=True, exclude_filters=['*_in21k', '*_in22k', '*_dino'])
325 | model_cfgs = [(n, '') for n in model_names]
326 | elif not is_model(args.model):
327 | # model name doesn't exist, try as wildcard filter
328 | model_names = list_models(args.model)
329 | model_cfgs = [(n, '') for n in model_names]
330 |
331 | if not model_cfgs and os.path.isfile(args.model):
332 | with open(args.model) as f:
333 | model_names = [line.rstrip() for line in f]
334 | model_cfgs = [(n, None) for n in model_names if n]
335 |
336 | if len(model_cfgs):
337 | results_file = args.results_file or './results-all.csv'
338 | _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
339 | results = []
340 | try:
341 | initial_batch_size = args.batch_size
342 | for m, c in model_cfgs:
343 | args.model = m
344 | args.checkpoint = c
345 | r = _try_run(args, initial_batch_size)
346 | if 'error' in r:
347 | continue
348 | if args.checkpoint:
349 | r['checkpoint'] = args.checkpoint
350 | results.append(r)
351 |         except KeyboardInterrupt:
352 | pass
353 | results = sorted(results, key=lambda x: x['top1'], reverse=True)
354 | if len(results):
355 | write_results(results_file, results)
356 | else:
357 | results = validate(args)
358 | # output results in JSON to stdout w/ delimiter for runner script
359 | print(f'--result\n{json.dumps(results, indent=4)}')
360 |
361 |
362 | def write_results(results_file, results):
363 | with open(results_file, mode='w') as cf:
364 | dw = csv.DictWriter(cf, fieldnames=results[0].keys())
365 | dw.writeheader()
366 | for r in results:
367 | dw.writerow(r)
368 | cf.flush()
369 |
370 |
371 | if __name__ == '__main__':
372 | main()
373 |
--------------------------------------------------------------------------------
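Editor's note: validate.py is invoked directly from the command line. An illustrative run,
using only flags defined above (the dataset path, model name, and checkpoint are placeholders):

    python validate.py /path/to/imagenet --model resnet50 \
        --checkpoint ./output/train/exp/checkpoint/model_best.pth.tar \
        --batch-size 256 --amp --results-file results.csv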