├── README.md
└── Distiller
    ├── AT.py
    ├── FitNet.py
    ├── FRS.py
    ├── DeFeat.py
    ├── FKD.py
    ├── InsDist.py
    └── FGD.py
/README.md:
--------------------------------------------------------------------------------
# InsDist
This is a PyTorch implementation of our paper, "Instance-Aware Distillation for Efficient Object Detection in Remote Sensing Images".

# Usage
#### 1. Download the MMDetection Framework and Datasets

Please first download [mmdetection](https://github.com/open-mmlab/mmdetection) together with the DIOR and DOTA datasets, and make sure that you can train and evaluate a baseline detector successfully.
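
#### 2. Add a distiller to training (illustrative sketch)

Each file under `Distiller/` is a standalone `nn.Module` that returns a feature-imitation loss. The snippet below is only a rough sketch, not the training code released with the paper: it assumes an mmdetection 2.x style student/teacher pair, a standard data batch with `img`, `img_metas`, `gt_bboxes` and `gt_labels`, that the `Distiller` directory is importable as a package, and it uses the simple `AT` loss purely as an example.

```python
import torch
from Distiller.AT import AT

distiller = AT()

def distillation_step(student, teacher, data):
    # Standard mmdetection detection losses for the student.
    losses = student.forward_train(data['img'], data['img_metas'],
                                   data['gt_bboxes'], data['gt_labels'])
    # Neck (FPN) features of the student and the frozen teacher.
    feat_s = student.extract_feat(data['img'])
    with torch.no_grad():
        feat_t = teacher.extract_feat(data['img'])
    # Add the feature-imitation term on top of the detection losses.
    losses['loss_distill'] = distiller(feat_s, feat_t)
    return losses
```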
--------------------------------------------------------------------------------
/Distiller/AT.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class AT(nn.Module):
    """Attention Transfer: match L2-normalized spatial attention maps of student and teacher."""

    def __init__(self):
        super(AT, self).__init__()
        self.p = 2  # power applied to activations before building the attention map

    def at(self, f):
        # Collapse the channel dimension into a spatial attention map and L2-normalize it.
        return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1))

    def forward(self, feat_S, feat_T):
        loss = 0
        for i in range(len(feat_S)):

            loss += (self.at(feat_S[i]) - self.at(feat_T[i])).pow(2).mean()

        return loss / len(feat_S)
--------------------------------------------------------------------------------
/Distiller/FitNet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class FitNet(nn.Module):
    """FitNet-style hint learning: a 1x1 adaptation conv followed by an MSE between feature maps."""

    def __init__(self, s_channels, t_channels):
        super(FitNet, self).__init__()

        # adaptation layer: project student channels to the teacher's channel dimension
        self.conv = nn.Conv2d(s_channels, t_channels, kernel_size=1, bias=False)

    def forward(self, feat_S, feat_T):
        distill_feat_loss = 0
        for i in range(len(feat_S)):

            feat_s_adapted = self.conv(feat_S[i])

            distill_feat_loss += ((feat_s_adapted - feat_T[i]) ** 2).mean()

        return distill_feat_loss / len(feat_S)
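

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original file. The shapes below are
    # assumptions: a 5-level FPN with 128 student channels, 256 teacher channels,
    # and batch size 2.
    torch.manual_seed(0)
    feat_S = [torch.rand(2, 128, 64 // 2 ** i, 64 // 2 ** i) for i in range(5)]
    feat_T = [torch.rand(2, 256, 64 // 2 ** i, 64 // 2 ** i) for i in range(5)]
    distiller = FitNet(s_channels=128, t_channels=256)
    print(distiller(feat_S, feat_T))  # scalar feature-imitation loss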
--------------------------------------------------------------------------------
/Distiller/FRS.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class FRS(nn.Module):
    """FRS distillation: weight the feature-imitation loss by the teacher's
    per-location maximum classification confidence."""

    def __init__(self):
        super(FRS, self).__init__()

    def forward(self, x, y, tea_bbox_outs, stu_feature_adap):
        # x: student FPN features, y: teacher FPN features,
        # tea_bbox_outs: raw outputs of the teacher's dense head (classification scores first),
        # stu_feature_adap: per-level adaptation layers for the student features.
        tea_cls_score = tea_bbox_outs[0]

        layers = len(tea_cls_score)

        distill_feat_loss = 0

        for layer in range(layers):

            tea_cls_score_sigmoid = tea_cls_score[layer].sigmoid()

            # Per-location mask: the teacher's maximum class confidence.
            mask = torch.max(tea_cls_score_sigmoid, dim=1).values
            mask = mask.detach()

            feat_loss = torch.pow((y[layer] - stu_feature_adap[layer](x[layer])), 2)

            loss = (feat_loss * mask[:, None, :, :]).sum()
            # Guard against rare numerical blow-ups.
            if loss > 1000000:
                loss = 0

            distill_feat_loss += loss / mask.sum()

        return distill_feat_loss
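

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original file. It assumes a small
    # 256x256 input, a 5-level FPN with 256 channels, batch size 2, and a teacher
    # head predicting 20 classes with 9 anchors per location, so tea_bbox_outs[0]
    # is the list of per-level classification maps.
    torch.manual_seed(0)
    sizes = [32, 16, 8, 4, 2]
    x = [torch.rand(2, 256, s, s) for s in sizes]  # student features
    y = [torch.rand(2, 256, s, s) for s in sizes]  # teacher features
    tea_cls_scores = [torch.randn(2, 20 * 9, s, s) for s in sizes]
    tea_bbox_outs = (tea_cls_scores,)
    stu_feature_adap = nn.ModuleList(
        [nn.Conv2d(256, 256, kernel_size=1) for _ in sizes]
    )
    print(FRS()(x, y, tea_bbox_outs, stu_feature_adap))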
--------------------------------------------------------------------------------
/Distiller/DeFeat.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class DeFeat(nn.Module):
    """DeFeat-style distillation: decouple the feature-imitation loss into a
    ground-truth (foreground) term and a background term with separate weights."""

    def __init__(self, weight_gt, weight_bg):
        super(DeFeat, self).__init__()
        self.weight_gt = weight_gt
        self.weight_bg = weight_bg

    def gt_mask(self, gt_bboxes, featmap_size, featmap_stride):
        # Build a binary foreground mask on the feature map by projecting the
        # ground-truth boxes down with the feature-map stride.
        with torch.no_grad():
            mask_batch = []
            for batch in range(len(gt_bboxes)):
                h, w = featmap_size[0], featmap_size[1]
                mask_per_img = torch.zeros([h, w], dtype=torch.double).cuda()
                for ins in range(gt_bboxes[batch].shape[0]):
                    gt_level_map = gt_bboxes[batch][ins] / featmap_stride
                    lx = int(gt_level_map[0])
                    lx = min(lx, w - 1)
                    rx = int(gt_level_map[2])
                    rx = min(rx, w - 1)

                    ly = int(gt_level_map[1])
                    ly = min(ly, h - 1)
                    ry = int(gt_level_map[3])
                    ry = min(ry, h - 1)

                    if (lx == rx) or (ly == ry):
                        mask_per_img[ly, lx] += 1
                    else:
                        mask_per_img[ly:ry, lx:rx] += 1

                mask_per_img = (mask_per_img > 0).double()

                mask_batch.append(mask_per_img)

            return torch.stack(mask_batch, dim=0)

    def defeat(self, tensor_a, tensor_b, mask):
        # Split the squared difference into foreground and background parts.
        diff = (tensor_a - tensor_b) ** 2

        mask_gt = mask.unsqueeze(1).repeat(1, tensor_a.size(1), 1, 1).cuda()
        diff_gt = diff * mask_gt
        diff_gt = (torch.sum(diff_gt) + 1e-8) ** 0.5

        mask_bg = (1 - mask_gt)
        diff_bg = diff * mask_bg
        diff_bg = (torch.sum(diff_bg) + 1e-8) ** 0.5

        return diff_gt, diff_bg

    def forward(self, feat_s, feat_t, gt_bbox, adaptation_layers):
        strides = [8, 16, 32, 64, 128]
        feat_loss = 0

        for i in range(len(feat_t)):

            featmap_size = feat_t[i].shape[2:]

            mask = self.gt_mask(gt_bbox, featmap_size, strides[i])

            loss_gt, loss_bg = self.defeat(feat_t[i], adaptation_layers[i](feat_s[i]), mask)
            feat_loss += (self.weight_gt * loss_gt + self.weight_bg * loss_bg)

        return feat_loss
--------------------------------------------------------------------------------
/Distiller/FKD.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class FKD(nn.Module):
    """FKD-style feature imitation guided by combined spatial and channel attention
    masks computed from both teacher and student features."""

    def __init__(self):
        super(FKD, self).__init__()

    def dist2(self, tensor_a, tensor_b, attention_mask=None, channel_attention_mask=None):
        diff = (tensor_a - tensor_b) ** 2
        diff = diff * attention_mask
        diff = diff * channel_attention_mask
        diff = torch.sum(diff) ** 0.5

        return diff

    def forward(self, t_feats, x, stu_feature_adap):
        # t_feats: teacher FPN features, x: student FPN features,
        # stu_feature_adap: per-level adaptation layers for the student features.
        t = 0.1          # temperature for the spatial attention
        s_ratio = 1.0    # student/teacher mixing ratio for the spatial mask
        kd_feat_loss = 0

        # for channel attention
        c_t = 0.1
        c_s_ratio = 1.0

        for _i in range(len(t_feats)):
            # spatial attention of the teacher
            t_attention_mask = torch.mean(torch.abs(t_feats[_i]), [1], keepdim=True)
            size = t_attention_mask.size()
            t_attention_mask = t_attention_mask.view(x[0].size(0), -1)
            t_attention_mask = torch.softmax(t_attention_mask / t, dim=1) * size[-1] * size[-2]
            t_attention_mask = t_attention_mask.view(size)
            # spatial attention of the student
            s_attention_mask = torch.mean(torch.abs(x[_i]), [1], keepdim=True)
            size = s_attention_mask.size()
            s_attention_mask = s_attention_mask.view(x[0].size(0), -1)
            s_attention_mask = torch.softmax(s_attention_mask / t, dim=1) * size[-1] * size[-2]
            s_attention_mask = s_attention_mask.view(size)
            # channel attention of the teacher
            c_t_attention_mask = torch.mean(torch.abs(t_feats[_i]), [2, 3], keepdim=True)  # N x 256 x 1 x 1
            c_size = c_t_attention_mask.size()
            c_t_attention_mask = c_t_attention_mask.view(x[0].size(0), -1)  # N x 256
            c_t_attention_mask = torch.softmax(c_t_attention_mask / c_t, dim=1) * 256
            c_t_attention_mask = c_t_attention_mask.view(c_size)  # N x 256 -> N x 256 x 1 x 1
            # channel attention of the student
            c_s_attention_mask = torch.mean(torch.abs(x[_i]), [2, 3], keepdim=True)  # N x 256 x 1 x 1
            c_size = c_s_attention_mask.size()
            c_s_attention_mask = c_s_attention_mask.view(x[0].size(0), -1)  # N x 256
            c_s_attention_mask = torch.softmax(c_s_attention_mask / c_t, dim=1) * 256
            c_s_attention_mask = c_s_attention_mask.view(c_size)  # N x 256 -> N x 256 x 1 x 1
            # combined masks for feature imitation
            sum_attention_mask = (t_attention_mask + s_attention_mask * s_ratio) / (1 + s_ratio)
            sum_attention_mask = sum_attention_mask.detach()
            c_sum_attention_mask = (c_t_attention_mask + c_s_attention_mask * c_s_ratio) / (1 + c_s_ratio)
            c_sum_attention_mask = c_sum_attention_mask.detach()
            # feature imitation loss (7e-5 * 6 is the fixed loss weight)
            kd_feat_loss += self.dist2(t_feats[_i], stu_feature_adap[_i](x[_i]),
                                       attention_mask=sum_attention_mask,
                                       channel_attention_mask=c_sum_attention_mask) * 7e-5 * 6

        return kd_feat_loss
--------------------------------------------------------------------------------
/Distiller/InsDist.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class InsDist(nn.Module):
    """Instance-aware distillation: extend the ground-truth regions with an
    instance-conditioned attention mask before decoupling the foreground and
    background imitation losses."""

    def __init__(self, weight_gt, weight_bg, threshold):
        super(InsDist, self).__init__()
        self.weight_gt = weight_gt
        self.weight_bg = weight_bg
        self.threshold = threshold

    def mask_ours(self, gt_bboxes, backbone_feat, featmap_size, featmap_stride, threshold):
        # For every ground-truth box, pool an instance descriptor from the feature map,
        # correlate it with all spatial positions, and keep positions whose normalized
        # response exceeds `threshold`.
        avgpool = nn.AdaptiveAvgPool2d((1, 1))
        with torch.no_grad():
            mask_batch = []
            for batch in range(len(gt_bboxes)):

                h, w = featmap_size[0], featmap_size[1]
                mask_per_img = torch.zeros([h, w], dtype=torch.double).cuda()

                for ins in range(gt_bboxes[batch].shape[0]):
                    gt_level_map = gt_bboxes[batch][ins] / featmap_stride

                    lx = min(max(0, int(gt_level_map[0])), w - 1)
                    rx = min(max(0, int(gt_level_map[2])), w - 1)
                    ly = min(max(0, int(gt_level_map[1])), h - 1)
                    ry = min(max(0, int(gt_level_map[3])), h - 1)

                    if (lx == rx) or (ly == ry):
                        mask_per_img[ly, lx] += 1
                    else:
                        # Instance descriptor: average-pooled feature inside the box.
                        x = backbone_feat[batch].view(-1, h * w).permute(1, 0)
                        feature_gt = avgpool(backbone_feat[batch][:, ly:(ry + 1), lx:(rx + 1)]).squeeze(-1)
                        energy = torch.mm(x, feature_gt)

                        min_ = torch.min(energy)
                        max_ = torch.max(energy)
                        assert max_ != 0
                        energy = (energy - min_) / max_
                        attention = energy.view(h, w)

                        attention = (attention > threshold).double()
                        mask_per_img += attention
                mask_per_img = (mask_per_img > 0).double()
                mask_batch.append(mask_per_img)

            return torch.stack(mask_batch, dim=0)

    def dist(self, tensor_a, tensor_b, mask):
        diff = (tensor_a - tensor_b) ** 2

        mask_gt = mask.unsqueeze(1).repeat(1, tensor_a.size(1), 1, 1).cuda()
        diff_gt = diff * mask_gt
        diff_gt = (torch.sum(diff_gt) + 1e-8) ** 0.5

        mask_bg = (1 - mask_gt)
        diff_bg = diff * mask_bg
        diff_bg = (torch.sum(diff_bg) + 1e-8) ** 0.5

        return diff_gt, diff_bg

    def forward(self, feat_s, feat_t, gt_bbox, adaptation_layers):
        # Build the instance-aware mask on the last teacher feature map, then
        # resize it to every FPN level.
        _mask = self.mask_ours(gt_bbox, feat_t[-1], featmap_size=feat_t[-1].shape[2:],
                               featmap_stride=32, threshold=self.threshold).unsqueeze(1)
        feat_loss = 0

        for i in range(len(feat_t)):
            d_size = feat_t[i].shape[2:]
            mask = F.interpolate(_mask, d_size).squeeze(1)
            loss_gt, loss_bg = self.dist(feat_t[i], adaptation_layers[i](feat_s[i]), mask)
            feat_loss += (loss_gt * self.weight_gt + loss_bg * self.weight_bg)

        return feat_loss
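

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original file. Channel counts, image
    # size, loss weights and the threshold are assumptions, and a CUDA device is
    # required because the masks are built on the GPU. It also shows the expected
    # structure of `adaptation_layers`: one 1x1 conv per FPN level.
    torch.manual_seed(0)
    strides = [8, 16, 32, 64, 128]
    feat_t = [torch.rand(2, 256, 256 // s, 256 // s).cuda() for s in strides]
    feat_s = [torch.rand(2, 256, 256 // s, 256 // s).cuda() for s in strides]
    gt_bbox = [torch.tensor([[32., 48., 160., 200.]]).cuda(),
               torch.tensor([[100., 100., 230., 220.], [10., 20., 60., 90.]]).cuda()]
    adaptation_layers = nn.ModuleList(
        [nn.Conv2d(256, 256, kernel_size=1, bias=False) for _ in strides]
    ).cuda()
    distiller = InsDist(weight_gt=2.0, weight_bg=0.5, threshold=0.6).cuda()
    print(distiller(feat_s, feat_t, gt_bbox, adaptation_layers))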
--------------------------------------------------------------------------------
/Distiller/FGD.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import constant_init, kaiming_init


class fgdloss(nn.Module):
    """Focal and Global Distillation (FGD) loss for a single FPN level."""

    def __init__(self,
                 student_channels=256,
                 teacher_channels=256,
                 temp=0.5,
                 alpha_fgd=0.001,
                 beta_fgd=0.0005,
                 gamma_fgd=0.001,
                 lambda_fgd=0.000005,
                 ):
        super(fgdloss, self).__init__()
        self.temp = temp
        self.alpha_fgd = alpha_fgd
        self.beta_fgd = beta_fgd
        self.gamma_fgd = gamma_fgd
        self.lambda_fgd = lambda_fgd

        if student_channels != teacher_channels:
            self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.align = None

        self.conv_mask_s = nn.Conv2d(teacher_channels, 1, kernel_size=1).cuda()
        self.conv_mask_t = nn.Conv2d(teacher_channels, 1, kernel_size=1).cuda()

        self.channel_add_conv_s = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels // 2, kernel_size=1),
            nn.LayerNorm([teacher_channels // 2, 1, 1]),
            nn.ReLU(inplace=True),  # yapf: disable
            nn.Conv2d(teacher_channels // 2, teacher_channels, kernel_size=1)).cuda()

        self.channel_add_conv_t = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels // 2, kernel_size=1),
            nn.LayerNorm([teacher_channels // 2, 1, 1]),
            nn.ReLU(inplace=True),  # yapf: disable
            nn.Conv2d(teacher_channels // 2, teacher_channels, kernel_size=1)).cuda()

        self.reset_parameters()

    def forward(self,
                preds_S,
                preds_T,
                gt_bboxes,
                img_metas):
        assert preds_S.shape[-2:] == preds_T.shape[-2:], 'the output dims of teacher and student differ'

        if self.align is not None:
            preds_S = self.align(preds_S)

        N, C, H, W = preds_S.shape

        S_attention_t, C_attention_t = self.get_attention(preds_T, self.temp)
        S_attention_s, C_attention_s = self.get_attention(preds_S, self.temp)

        # Foreground/background masks scaled to the feature-map resolution.
        Mask_fg = torch.zeros_like(S_attention_t)
        Mask_bg = torch.ones_like(S_attention_t)
        wmin, wmax, hmin, hmax = [], [], [], []
        for i in range(N):
            new_boxxes = torch.ones_like(gt_bboxes[i])
            new_boxxes[:, 0] = gt_bboxes[i][:, 0] / img_metas[i]['img_shape'][1] * W
            new_boxxes[:, 2] = gt_bboxes[i][:, 2] / img_metas[i]['img_shape'][1] * W
            new_boxxes[:, 1] = gt_bboxes[i][:, 1] / img_metas[i]['img_shape'][0] * H
            new_boxxes[:, 3] = gt_bboxes[i][:, 3] / img_metas[i]['img_shape'][0] * H

            wmin.append(torch.floor(new_boxxes[:, 0]).int())
            wmax.append(torch.ceil(new_boxxes[:, 2]).int())
            hmin.append(torch.floor(new_boxxes[:, 1]).int())
            hmax.append(torch.ceil(new_boxxes[:, 3]).int())

            area = 1.0 / (hmax[i].view(1, -1) + 1 - hmin[i].view(1, -1)) / (wmax[i].view(1, -1) + 1 - wmin[i].view(1, -1))

            for j in range(len(gt_bboxes[i])):
                Mask_fg[i][hmin[i][j]:hmax[i][j] + 1, wmin[i][j]:wmax[i][j] + 1] = \
                    torch.maximum(Mask_fg[i][hmin[i][j]:hmax[i][j] + 1, wmin[i][j]:wmax[i][j] + 1], area[0][j])

            Mask_bg[i] = torch.where(Mask_fg[i] > 0, 0, 1)
            if torch.sum(Mask_bg[i]):
                Mask_bg[i] /= torch.sum(Mask_bg[i])

        fg_loss, bg_loss = self.get_fea_loss(preds_S, preds_T, Mask_fg, Mask_bg,
                                             C_attention_s, C_attention_t, S_attention_s, S_attention_t)
        mask_loss = self.get_mask_loss(C_attention_s, C_attention_t, S_attention_s, S_attention_t)
        rela_loss = self.get_rela_loss(preds_S, preds_T)

        # Only the focal (foreground/background) terms are used here; the mask and
        # relation terms are computed but left out of the final loss.
        # loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss
        loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss

        return loss

    def get_attention(self, preds, temp):
        """preds: Bs*C*H*W"""
        N, C, H, W = preds.shape

        value = torch.abs(preds)
        # Bs*H*W
        fea_map = value.mean(axis=1, keepdim=True)
        S_attention = (H * W * F.softmax((fea_map / temp).view(N, -1), dim=1)).view(N, H, W)

        # Bs*C
        channel_map = value.mean(axis=2, keepdim=False).mean(axis=2, keepdim=False)
        C_attention = C * F.softmax(channel_map / temp, dim=1)

        return S_attention, C_attention

    def get_fea_loss(self, preds_S, preds_T, Mask_fg, Mask_bg, C_s, C_t, S_s, S_t):
        loss_mse = nn.MSELoss(reduction='sum')

        Mask_fg = Mask_fg.unsqueeze(dim=1)
        Mask_bg = Mask_bg.unsqueeze(dim=1)

        C_t = C_t.unsqueeze(dim=-1)
        C_t = C_t.unsqueeze(dim=-1)

        S_t = S_t.unsqueeze(dim=1)

        fea_t = torch.mul(preds_T, torch.sqrt(S_t))
        fea_t = torch.mul(fea_t, torch.sqrt(C_t))
        fg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_fg))
        bg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_bg))

        fea_s = torch.mul(preds_S, torch.sqrt(S_t))
        fea_s = torch.mul(fea_s, torch.sqrt(C_t))
        fg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_fg))
        bg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_bg))

        fg_loss = loss_mse(fg_fea_s, fg_fea_t) / len(Mask_fg)
        bg_loss = loss_mse(bg_fea_s, bg_fea_t) / len(Mask_bg)

        return fg_loss, bg_loss

    def get_mask_loss(self, C_s, C_t, S_s, S_t):

        mask_loss = torch.sum(torch.abs((C_s - C_t))) / len(C_s) + torch.sum(torch.abs((S_s - S_t))) / len(S_s)

        return mask_loss

    def spatial_pool(self, x, in_type):
        batch, channel, width, height = x.size()
        input_x = x
        # [N, C, H * W]
        input_x = input_x.view(batch, channel, height * width)
        # [N, 1, C, H * W]
        input_x = input_x.unsqueeze(1)
        # [N, 1, H, W]
        if in_type == 0:
            context_mask = self.conv_mask_s(x)
        else:
            context_mask = self.conv_mask_t(x)
        # [N, 1, H * W]
        context_mask = context_mask.view(batch, 1, height * width)
        # [N, 1, H * W]
        context_mask = F.softmax(context_mask, dim=2)
        # [N, 1, H * W, 1]
        context_mask = context_mask.unsqueeze(-1)
        # [N, 1, C, 1]
        context = torch.matmul(input_x, context_mask)
        # [N, C, 1, 1]
        context = context.view(batch, channel, 1, 1)

        return context

    def get_rela_loss(self, preds_S, preds_T):
        loss_mse = nn.MSELoss(reduction='sum')

        context_s = self.spatial_pool(preds_S, 0)
        context_t = self.spatial_pool(preds_T, 1)

        out_s = preds_S
        out_t = preds_T

        channel_add_s = self.channel_add_conv_s(context_s)
        out_s = out_s + channel_add_s

        channel_add_t = self.channel_add_conv_t(context_t)
        out_t = out_t + channel_add_t

        rela_loss = loss_mse(out_s, out_t) / len(out_s)

        return rela_loss

    def last_zero_init(self, m):
        if isinstance(m, nn.Sequential):
            constant_init(m[-1], val=0)
        else:
            constant_init(m, val=0)

    def reset_parameters(self):
        kaiming_init(self.conv_mask_s, mode='fan_in')
        kaiming_init(self.conv_mask_t, mode='fan_in')
        self.conv_mask_s.inited = True
        self.conv_mask_t.inited = True

        self.last_zero_init(self.channel_add_conv_s)
        self.last_zero_init(self.channel_add_conv_t)
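

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original file. It feeds one FPN level
    # through the loss; batch size, channel count, image size and box coordinates are
    # assumptions. A CUDA device and mmcv are required by the module itself.
    torch.manual_seed(0)
    preds_S = torch.rand(2, 256, 40, 40).cuda()
    preds_T = torch.rand(2, 256, 40, 40).cuda()
    gt_bboxes = [torch.tensor([[32., 48., 160., 200.]]).cuda(),
                 torch.tensor([[100., 100., 300., 260.]]).cuda()]
    img_metas = [{'img_shape': (640, 640, 3)}, {'img_shape': (640, 640, 3)}]
    criterion = fgdloss(student_channels=256, teacher_channels=256)
    print(criterion(preds_S, preds_T, gt_bboxes, img_metas))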
--------------------------------------------------------------------------------