├── Datasets.py ├── LossFunctions.py ├── MeNets.py ├── MeNets_NAS.py ├── Metrics.py ├── ModelTrain_Final.py ├── PrepareData_LOSO_CD.py ├── RCNs.py └── README.md /Datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import matplotlib.pyplot as plt 6 | 7 | class MEGC2019(torch.utils.data.Dataset): 8 | """MEGC2019 dataset class with 3 categories""" 9 | 10 | def __init__(self, imgList, transform=None): 11 | self.imgPath = [] 12 | self.label = [] 13 | self.dbtype = [] 14 | with open(imgList,'r') as f: 15 | for textline in f: 16 | texts = textline.strip('\n').split(' ') 17 | self.imgPath.append(texts[0]) 18 | self.label.append(int(texts[1])) 19 | self.dbtype.append(int(texts[2])) 20 | self.transform = transform 21 | 22 | def __getitem__(self, idx): 23 | img = Image.open(self.imgPath[idx], 'r').convert('RGB') 24 | # plt.imshow(img) 25 | # plt.show() 26 | if self.transform is not None: 27 | img = self.transform(img) 28 | return img, self.label[idx] 29 | 30 | def __len__(self): 31 | return len(self.imgPath) 32 | 33 | class MEGC2019_SI(torch.utils.data.Dataset): 34 | """MEGC2019_SI dataset class with 3 categories and other side information""" 35 | 36 | def __init__(self, imgList, transform=None): 37 | self.imgPath = [] 38 | self.label = [] 39 | self.dbtype = [] 40 | with open(imgList,'r') as f: 41 | for textline in f: 42 | texts = textline.strip('\n').split(' ') 43 | self.imgPath.append(texts[0]) 44 | self.label.append(int(texts[1])) 45 | self.dbtype.append(int(texts[2])) 46 | self.transform = transform 47 | 48 | def __getitem__(self, idx): 49 | img = Image.open(self.imgPath[idx], 'r').convert('RGB') 50 | # plt.imshow(img) 51 | # plt.show() 52 | if self.transform is not None: 53 | img = self.transform(img) 54 | return {"data":img, "class_label":self.label[idx], 'db_label':self.dbtype[idx]} 55 | 56 | def __len__(self): 57 | return len(self.imgPath) 58 | 59 | class MEGC2019_FOLDER(torch.utils.data.Dataset): 60 | """MEGC2019 dataset class with 3 categories, organized in folders""" 61 | 62 | def __init__(self, rootDir, transform=None): 63 | labels = os.listdir(rootDir) 64 | labels.sort() 65 | self.fileList = [] 66 | self.label = [] 67 | self.imgPath = [] 68 | for subfolder in labels: 69 | label = [] 70 | imgPath = [] 71 | files = os.listdir(os.path.join(rootDir, subfolder)) 72 | files.sort() 73 | self.fileList.extend(files) 74 | label = [int(subfolder) for file in files] 75 | imgPath = [os.path.join(rootDir, subfolder, file) for file in files] 76 | self.label.extend(label) 77 | self.imgPath.extend(imgPath) 78 | self.transform = transform 79 | 80 | def __getitem__(self, idx): 81 | img = Image.open(self.imgPath[idx], 'r').convert('RGB') 82 | # plt.imshow(img) 83 | # plt.show() 84 | if self.transform is not None: 85 | img = self.transform(img) 86 | return {"data":img, "class_label":self.label[idx]} 87 | 88 | def __len__(self): 89 | return len(self.fileList) --------------------------------------------------------------------------------
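A minimal usage sketch for the dataset classes above, assuming a list file whose rows are "<image_path> <class_label> <db_label>" (the file name and the 96x96 input size are assumptions):

import torchvision.transforms as transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((96, 96)),  # assumed input size
    transforms.ToTensor(),
])
dataset = MEGC2019_SI(imgList='train_list.txt', transform=transform)  # 'train_list.txt' is hypothetical
loader = DataLoader(dataset, batch_size=32, shuffle=True)
batch = next(iter(loader))
print(batch['data'].shape, batch['class_label'].shape)  # e.g. torch.Size([32, 3, 96, 96]) torch.Size([32])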
/LossFunctions.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules import Module 2 | import torch.nn._reduction as _Reduction 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import torch 7 | 8 | class _Loss(Module): 9 | def __init__(self, size_average=None, reduce=None, reduction='mean'): 10 | super(_Loss, self).__init__() 11 | if size_average is not None or reduce is not None: 12 | self.reduction = _Reduction.legacy_get_string(size_average, reduce) 13 | else: 14 | self.reduction = reduction 15 | 16 | class HHLoss(_Loss): 17 | r"""Creates a criterion that measures the hierarchy label error (squared L2 norm) between 18 | each element in the input :math:`x` and target :math:`y`. 19 | """ 20 | def __init__(self, size_average=None, reduce=None, reduction='mean'): 21 | super(HHLoss, self).__init__(size_average, reduce, reduction) 22 | self.norm_p = 2 23 | self.param_lambda = 0.001 24 | self.param_beta = 0.002 25 | 26 | def forward(self, input, target): 27 | input_s = input.mm(input.t()) 28 | target_s = target.mm(target.t()) 29 | uncorr_M = input_s.mm(input_s.t())/input_s.shape[0] 30 | I = torch.eye(input_s.shape[1]).type_as(uncorr_M) 31 | loss = torch.norm(input_s-target_s,p=self.norm_p)\ 32 | + self.param_lambda*torch.norm(torch.sum(input_s,0),p=self.norm_p)\ 33 | + self.param_beta*torch.norm(uncorr_M-I,p=self.norm_p) 34 | return loss 35 | # input_s = input.mm(input.t()) 36 | # target_s = target.mm(target.t()) 37 | # return F.mse_loss(input_s, target_s, reduction=self.reduction) 38 | 39 | class HHLoss_bin(_Loss): 40 | r"""Creates a criterion that measures the hierarchy label error (squared L2 norm) between 41 | each element in the input :math:`x` and target :math:`y`. 42 | """ 43 | def __init__(self, size_average=None, reduce=None, reduction='mean'): 44 | super(HHLoss_bin, self).__init__(size_average, reduce, reduction) 45 | self.th = 0.0 46 | self.norm_p = 2 47 | self.param_lambda = 0.001 48 | self.param_beta = 0.002 49 | self.mu = 0.001 50 | 51 | def forward(self, input, target): 52 | input_b = torch.sign(input) 53 | input_s = input.mm(input.t()) 54 | # input_s = input_b.mm(input_b.t()) 55 | target_s = target.mm(target.t()) 56 | uncorr_M = input_s.mm(input_s.t()) / input_s.shape[0] 57 | I = torch.eye(input_s.shape[1]).type_as(uncorr_M) 58 | loss = torch.norm(input_s - target_s, p=self.norm_p) \ 59 | + self.param_lambda * torch.norm(torch.sum(input, 0), p=self.norm_p) \ 60 | + self.param_beta * torch.norm(uncorr_M - I, p=self.norm_p) \ 61 | + self.mu*torch.norm(torch.sum(input_b-input, 0), p=self.norm_p) 62 | return loss 63 | # input = (torch.sign(input)+1)/2 64 | # input = torch.sign(input) 65 | # input_s = input.mm(input.t()) 66 | # target_s = target.mm(target.t()) 67 | # uncorr_M = input_s.mm(input_s.t()) / input_s.shape[0] 68 | # I = torch.eye(input_s.shape[1]).type_as(uncorr_M) 69 | # loss = torch.norm(input_s - target_s, p=self.norm_p) \ 70 | # + self.param_lambda * torch.norm(torch.sum(input, 0), p=self.norm_p) \ 71 | # + self.param_beta * torch.norm(uncorr_M - I, p=self.norm_p) \ 72 | # + self.mu*torch.norm(torch.sum(input, 0), p=self.norm_p) 73 | # return loss 74 | 75 | class MSELoss(_Loss): 76 | r"""Creates a criterion that measures the mean squared error (squared L2 norm) between 77 | each element in the input :math:`x` and target :math:`y`. 78 | """ 79 | __constants__ = ['reduction'] 80 | 81 | def __init__(self, size_average=None, reduce=None, reduction='mean'): 82 | super(MSELoss, self).__init__(size_average, reduce, reduction) 83 | 84 | def forward(self, input, target): 85 | return F.mse_loss(input, target, reduction=self.reduction) 86 | 
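A sketch of how HHLoss appears to be meant to be called, assuming the input is a batch of feature vectors and the target a one-hot label matrix (so both .mm(.t()) products are batch-similarity matrices); all sizes are assumptions:

import torch
criterion = HHLoss()
features = torch.randn(4, 16, requires_grad=True)   # N=4 samples, d=16 features (assumed)
onehot = torch.eye(3)[torch.tensor([0, 1, 2, 0])]   # (4, 3) one-hot labels
loss = criterion(features, onehot)                  # compares the two 4x4 similarity matrices
loss.backward()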
87 | class FocalLoss(nn.Module): 88 | r""" 89 | This criterion is an implementation of Focal Loss, proposed in 90 | Focal Loss for Dense Object Detection. 91 | 92 | Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class]) 93 | 94 | The losses are averaged across observations for each minibatch. 95 | 96 | Args: 97 | alpha(1D Tensor, Variable) : the scalar factor for this criterion 98 | gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5), 99 | putting more focus on hard, misclassified examples 100 | size_average(bool): By default, the losses are averaged over observations for each minibatch. 101 | However, if the field size_average is set to False, the losses are 102 | instead summed for each minibatch. 103 | 104 | 105 | """ 106 | def __init__(self, class_num, alpha=None, gamma=2, size_average=True, device='cpu'): 107 | super(FocalLoss, self).__init__() 108 | if alpha is None: 109 | self.alpha = Variable(torch.ones(class_num, 1)) 110 | else: 111 | if isinstance(alpha, Variable): 112 | self.alpha = alpha 113 | else: 114 | self.alpha = Variable(alpha) 115 | self.gamma = gamma 116 | self.class_num = class_num 117 | self.size_average = size_average 118 | self.device = device 119 | 120 | def forward(self, inputs, targets): 121 | N = inputs.size(0) 122 | C = inputs.size(1) 123 | P = F.softmax(inputs,dim=1) 124 | 125 | class_mask = inputs.data.new(N, C).fill_(0) 126 | class_mask = Variable(class_mask) 127 | ids = targets.view(-1, 1) 128 | class_mask.scatter_(1, ids.data, 1.) 129 | 130 | if inputs.is_cuda and not self.alpha.is_cuda: 131 | self.alpha = self.alpha.to(self.device) 132 | alpha = self.alpha[ids.data.view(-1)] 133 | probs = (P*class_mask).sum(1).view(-1,1) 134 | log_p = probs.log() 135 | batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p 136 | if self.size_average: 137 | loss = batch_loss.mean() 138 | else: 139 | loss = batch_loss.sum() 140 | return loss 141 | 
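A minimal FocalLoss usage sketch on raw logits (batch size and class count are assumptions):

import torch
criterion = FocalLoss(class_num=3, gamma=2)
logits = torch.randn(8, 3, requires_grad=True)   # (batch, classes)
targets = torch.randint(0, 3, (8,))
loss = criterion(logits, targets)                # focal-weighted NLL, averaged over the batch
loss.backward()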
142 | class BalancedLoss(nn.Module): 143 | r""" 144 | This criterion is a class-balanced variant of Focal Loss, proposed in 145 | Focal Loss for Dense Object Detection. 146 | 147 | Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class]) 148 | 149 | Unlike FocalLoss above, \alpha is recomputed each forward pass from the minibatch label histogram, down-weighting frequent classes. 150 | 151 | Args: 152 | alpha(1D Tensor, Variable) : the scalar factor for this criterion 153 | gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5), 154 | putting more focus on hard, misclassified examples 155 | size_average(bool): By default, the losses are averaged over observations for each minibatch. 156 | However, if the field size_average is set to False, the losses are 157 | instead summed for each minibatch. 158 | 159 | 160 | """ 161 | def __init__(self, class_num, alpha=None, gamma=2, size_average=True, device='cpu'): 162 | super(BalancedLoss, self).__init__() 163 | if alpha is None: 164 | self.alpha = Variable(torch.ones(class_num, 1)) 165 | else: 166 | if isinstance(alpha, Variable): 167 | self.alpha = alpha 168 | else: 169 | self.alpha = Variable(alpha) 170 | self.gamma = gamma 171 | self.class_num = class_num 172 | self.size_average = size_average 173 | self.device = device 174 | 175 | def forward(self, inputs, targets): 176 | N = inputs.size(0) 177 | C = inputs.size(1) 178 | P = F.softmax(inputs,dim=1) 179 | 180 | class_mask = inputs.data.new(N, C).fill_(0) 181 | class_mask = Variable(class_mask) 182 | ids = targets.view(-1, 1) 183 | class_mask.scatter_(1, ids.data, 1.) 184 | 185 | self.alpha = torch.histc(ids.float(), bins=self.class_num, min=0, max=self.class_num-1)/float(ids.shape[0]) # histc requires a floating tensor; integer label ids would raise 186 | # self.alpha = 1.0*self.alpha.reciprocal() # 10, 100 187 | self.alpha = 1.0 - self.alpha/10.0 188 | alpha_c = self.alpha[ids.data.view(-1)] 189 | if inputs.is_cuda and not alpha_c.is_cuda: 190 | alpha_c = alpha_c.to(self.device) 191 | probs = (P*class_mask).sum(1).view(-1,1) 192 | log_p = probs.log() 193 | batch_loss = -alpha_c*(torch.pow((1-probs), self.gamma))*log_p 194 | if self.size_average: 195 | loss = batch_loss.mean() 196 | else: 197 | loss = batch_loss.sum() 198 | return loss 199 | 200 | class CosineLoss(nn.Module): 201 | r""" 202 | This criterion is an implementation of the Cosine Loss 203 | """ 204 | def __init__(self, size_average=True): 205 | super(CosineLoss, self).__init__() 206 | 207 | self.size_average = size_average 208 | 209 | def forward(self, inputs, targets): 210 | N = inputs.size(0) 211 | C = inputs.size(1) 212 | P = torch.div(inputs,inputs.norm(dim=1,keepdim=True)) 213 | # one-hot coding 214 | class_mask = inputs.data.new(N, C).fill_(0) 215 | class_mask = Variable(class_mask) 216 | ids = targets.view(-1, 1) 217 | class_mask.scatter_(1, ids.data, 1.) 218 | 219 | probs = (P*class_mask).sum(1).view(-1,1) 220 | log_p = 1.0-probs # 1 - cosine similarity (not a log-probability, despite the name) 221 | 222 | if self.size_average: 223 | loss = log_p.mean() 224 | else: 225 | loss = log_p.sum() 226 | return loss --------------------------------------------------------------------------------
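A worked sketch of the per-batch weighting inside BalancedLoss.forward (the batch composition is an assumption): for labels [0,0,0,0,1,1,2,2] the class frequencies are [0.5, 0.25, 0.25], so alpha = 1 - freq/10 = [0.95, 0.975, 0.975], i.e. slightly larger weights for the rarer classes:

import torch
ids = torch.tensor([0, 0, 0, 0, 1, 1, 2, 2]).view(-1, 1)
freq = torch.histc(ids.float(), bins=3, min=0, max=2) / float(ids.shape[0])
alpha = 1.0 - freq / 10.0
print(alpha)   # tensor([0.9500, 0.9750, 0.9750])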
/MeNets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | __all__ = ['RCN_A', 'RCN_S', 'RCN_W', 'RCN_P', 'RCN_C', 'RCN_F'] 6 | 7 | class ConvBlock(nn.Module): 8 | """convolutional layer blocks for sequential convolution operations""" 9 | def __init__(self, in_features, out_features, num_conv, pool=False): 10 | super(ConvBlock, self).__init__() 11 | features = [in_features] + [out_features for i in range(num_conv)] 12 | layers = [] 13 | for i in range(len(features)-1): 14 | layers.append(nn.Conv2d(in_channels=features[i], out_channels=features[i+1], kernel_size=3, padding=1, bias=True)) 15 | layers.append(nn.BatchNorm2d(num_features=features[i+1], affine=True, track_running_stats=True)) 16 | layers.append(nn.ReLU()) 17 | if pool: 18 | layers.append(nn.MaxPool2d(kernel_size=2, stride=2, padding=0)) 19 | self.op = nn.Sequential(*layers) 20 | def forward(self, x): 21 | return self.op(x) 22 | 23 | class RclBlock(nn.Module): 24 | """recurrent convolutional blocks""" 25 | def __init__(self, inplanes, planes): 26 | super(RclBlock, self).__init__() 27 | self.ffconv = nn.Sequential( 28 | nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0), 29 | nn.BatchNorm2d(planes), 30 | nn.ReLU(inplace=True) 31 | ) 32 | self.rrconv = nn.Sequential( 33 | nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, padding=1), 34 | nn.BatchNorm2d(planes), 35 | nn.ReLU(inplace=True) 36 | ) 37 | self.downsample = nn.Sequential( 38 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0), 39 | nn.Dropout() 40 | ) 41 | 42 | def forward(self, x): 43 | y = self.ffconv(x) 44 | y = self.rrconv(x + y) 45 | y = self.rrconv(x + y) 46 | out = self.downsample(y) 47 | return out 48 | 49 | class DenseBlock(nn.Module): 50 | """densely connected convolutional blocks""" 51 | def __init__(self, inplanes, planes): 52 | super(DenseBlock, self).__init__() 53 | self.conv1 = nn.Sequential( 54 | nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0), 55 | nn.BatchNorm2d(planes), 56 | nn.ReLU(inplace=True) 57 | ) 58 | self.conv2 = nn.Sequential( 59 | nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, padding=1), 60 | nn.BatchNorm2d(planes), 61 | nn.ReLU(inplace=True) 62 | ) 63 | self.conv3 = nn.Sequential( # note: defined but unused; forward below reuses conv2 64 | nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, padding=1), 65 | nn.BatchNorm2d(planes), 66 | nn.ReLU(inplace=True) 67 | ) 68 | self.downsample = nn.Sequential( 69 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0), 70 | nn.Dropout() 71 | ) 72 | 73 | def forward(self, x): 74 | y = self.conv1(x) 75 | z = self.conv2(x + y) 76 | # out = self.conv2(x + y + z) 77 | e = self.conv2(x + y + z) 78 | out = self.conv2(x + y + z + e) 79 | out = self.downsample(out) 80 | return out 81 | 82 | class EmbeddingBlock(nn.Module): 83 | """densely connected convolutional blocks for embedding""" 84 | def __init__(self, inplanes, planes): 85 | super(EmbeddingBlock, self).__init__() 86 | self.conv1 = nn.Sequential( 87 | nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0), 88 | nn.BatchNorm2d(planes), 89 | nn.ReLU(inplace=True) 90 | ) 91 | self.conv2 = nn.Sequential( 92 | nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, padding=1), 93 | nn.BatchNorm2d(planes), 94 | nn.ReLU(inplace=True) 95 | ) 96 | self.attenmap = SpatialAttentionBlock_P(normalize_attn=True) 97 | self.downsample = nn.Sequential( 98 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0), 99 | nn.Dropout() 100 | ) 101 | 102 | def forward(self, x, w, pool_size, classes): 103 | y = self.conv1(x) 104 | y1 = self.attenmap(F.adaptive_avg_pool2d(x, (pool_size, pool_size)), w, classes) 105 | y = torch.mul(F.interpolate(y1, (y.shape[2], y.shape[3])), y) 106 | z = self.conv2(x+y) 107 | e = self.conv2(x + y + z) 108 | out = self.conv2(x + y + z + e) 109 | out = self.downsample(out) 110 | return out 111 | 112 | class EmbeddingBlock_M(nn.Module): 113 | """densely connected convolutional blocks for embedding with multiple attentions""" 114 | def __init__(self, inplanes, planes): 115 | super(EmbeddingBlock_M, self).__init__() 116 | self.conv1 = nn.Sequential( 117 | nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0), 118 | nn.BatchNorm2d(planes), 119 | nn.ReLU(inplace=True) 120 | ) 121 | self.conv2 = nn.Sequential( 122 | nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, padding=1), 123 | nn.BatchNorm2d(planes), 124 | nn.ReLU(inplace=True) 125 | ) 126 | self.attenmap = SpatialAttentionBlock_P(normalize_attn=True) 127 | self.downsample = nn.Sequential( 128 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0), 129 | nn.Dropout() 130 | ) 131 | 132 | def forward(self, x, w, pool_size, classes): 133 | y = self.conv1(x) 134 | z = self.conv2(x + y) 135 | y1 = self.attenmap(F.adaptive_avg_pool2d(x, (pool_size, pool_size)), w, classes) 136 | z = torch.mul(F.interpolate(y1, (z.shape[2], z.shape[3])), z) 137 | e = self.conv2(x + y + z) 138 | out = self.conv2(x + y + z + e) 139 | y2 = self.attenmap(F.adaptive_avg_pool2d(z, (pool_size, pool_size)), w, classes) 140 | out = torch.mul(F.interpolate(y2, (out.shape[2], out.shape[3])), out) 141 | out = self.downsample(out) 142 | return out 143 | 
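A quick shape check for DenseBlock (sizes are assumptions): the residual-style additions require inplanes == planes, and the MaxPool halves the spatial size:

import torch
block = DenseBlock(32, 32)
out = block(torch.randn(1, 32, 28, 28))
print(out.shape)   # torch.Size([1, 32, 14, 14])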
144 | class SpatialAttentionBlock_A(nn.Module): 145 | """learned linear attention block for any layer""" 146 | def __init__(self, in_features, normalize_attn=True): 147 | super(SpatialAttentionBlock_A, self).__init__() 148 | self.normalize_attn = normalize_attn 149 | self.op = nn.Conv2d(in_channels=in_features, out_channels=1, kernel_size=1, padding=0, bias=False) 150 | 151 | def forward(self, l): 152 | N, C, W, H = l.size() 153 | c = self.op(l) # batch_size x 1 x W x H 154 | if self.normalize_attn: 155 | a = F.softmax(c.view(N,1,-1), dim=2).view(N,1,W,H) 156 | else: 157 | a = torch.sigmoid(c) 158 | g = torch.mul(a.expand_as(l), l) 159 | return g 160 | 161 | class SpatialAttentionBlock_P(nn.Module): 162 | """CAM-style attention block built from the classifier weights, for any layer""" 163 | def __init__(self, normalize_attn=True): 164 | super(SpatialAttentionBlock_P, self).__init__() 165 | self.normalize_attn = normalize_attn 166 | 167 | def forward(self, l, w, classes): 168 | output_cam = [] 169 | for idx in range(0,classes): 170 | weights = w[idx,:].reshape((l.shape[1], l.shape[2], l.shape[3])) 171 | cam = weights * l 172 | cam = cam.mean(dim=1,keepdim=True) 173 | cam = cam - torch.min(torch.min(cam,3,True)[0],2,True)[0] 174 | cam = cam / torch.max(torch.max(cam,3,True)[0],2,True)[0] 175 | output_cam.append(cam) 176 | output = torch.cat(output_cam, dim=1) 177 | output = output.mean(dim=1,keepdim=True) 178 | return output 179 | 180 | class SpatialAttentionBlock_F(nn.Module): 181 | """CAM-style attention block for the first layer (classifier weights averaged over channels)""" 182 | def __init__(self, normalize_attn=True): 183 | super(SpatialAttentionBlock_F, self).__init__() 184 | self.normalize_attn = normalize_attn 185 | 186 | def forward(self, l, w, classes): 187 | output_cam = [] 188 | for idx in range(0,classes): 189 | weights = w[idx,:].reshape((-1, l.shape[2], l.shape[3])) 190 | weights = weights.mean(dim=0,keepdim=True) 191 | cam = weights * l 192 | cam = cam.mean(dim=1,keepdim=True) 193 | cam = cam - torch.min(torch.min(cam,3,True)[0],2,True)[0] 194 | cam = cam / torch.max(torch.max(cam,3,True)[0],2,True)[0] 195 | output_cam.append(cam) 196 | output = torch.cat(output_cam, dim=1) 197 | output = output.mean(dim=1,keepdim=True) 198 | return output 199 | 200 | def MakeLayer(block, planes, blocks): 201 | layers = [] 202 | for _ in range(0, blocks): 203 | layers.append(block(planes, planes)) 204 | return nn.Sequential(*layers) 205 | 206 | class RCN_A(nn.Module): 207 | """menet networks with an added attention unit 208 | """ 209 | def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5, model_version=3): 210 | super(RCN_A, self).__init__() 211 | self.version = model_version 212 | self.classes = num_classes 213 | self.conv1 = nn.Sequential( 214 | nn.Conv2d(num_input, featuremaps, kernel_size=5, stride=1, padding=0), 215 | nn.BatchNorm2d(featuremaps), 216 | nn.ReLU(inplace=True), 217 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0), 218 | nn.Dropout(), 219 | ) 220 | self.rcls = MakeLayer(RclBlock, featuremaps, num_layers) 221 | self.attenmap = SpatialAttentionBlock_P(normalize_attn=True) 222 | self.downsampling = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 223 | self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 224 | self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes) 225 | 226 | for m in self.modules(): 227 | if isinstance(m, nn.Conv2d): 228 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 229 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 230 | nn.init.constant_(m.weight, 1) 231 | nn.init.constant_(m.bias, 0) 232 | elif isinstance(m, nn.Linear): 233 | nn.init.normal_(m.weight, 0, 0.01) 234 | nn.init.constant_(m.bias, 0) 235 | 236 | def forward(self, x): 237 | if self.version == 1: 238 | x = self.conv1(x) 239 | x = self.attenmap(x) # note: this one-argument call matches SpatialAttentionBlock_A, not the _P block created above; versions 1/2 appear to be legacy paths 240 | x = self.rcls(x) 241 | x = self.avgpool(x) 242 | elif self.version == 2: 243 | x = self.conv1(x) 244 | x = self.attenmap(x) 245 | x = self.rcls(x) 246 | x = 
self.avgpool(x) 247 | elif self.version == 3: 248 | x = self.conv1(x) 249 | y = self.attenmap(self.downsampling(x), self.classifier.weight, self.classes) 250 | x = self.rcls(x) 251 | x = self.avgpool(x) 252 | x = x * y 253 | x = torch.flatten(x, 1) 254 | x = self.classifier(x) 255 | return x 256 | 257 | class RCN_S(nn.Module): 258 | """menet networks with dense shortcut connection 259 | """ 260 | def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5): 261 | super(RCN_S, self).__init__() 262 | self.conv1 = nn.Sequential( 263 | nn.Conv2d(num_input, featuremaps, kernel_size=5, stride=1, padding=0), 264 | nn.BatchNorm2d(featuremaps), 265 | nn.ReLU(inplace=True), 266 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0), 267 | nn.Dropout(), 268 | ) 269 | self.dbl = MakeLayer(DenseBlock, featuremaps, num_layers) 270 | self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 271 | self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes) 272 | for m in self.modules(): 273 | if isinstance(m, nn.Conv2d): 274 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 275 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 276 | nn.init.constant_(m.weight, 1) 277 | nn.init.constant_(m.bias, 0) 278 | elif isinstance(m, nn.Linear): 279 | nn.init.normal_(m.weight, 0, 0.01) 280 | nn.init.constant_(m.bias, 0) 281 | 282 | def forward(self, x): 283 | x = self.conv1(x) 284 | x = self.dbl(x) 285 | x = self.avgpool(x) 286 | x = torch.flatten(x, 1) 287 | x = self.classifier(x) 288 | return x 289 | 290 | class RCN_W(nn.Module): 291 | """menet networks with wide expansion 292 | """ 293 | def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5): 294 | super(RCN_W, self).__init__() 295 | num_channels = int(featuremaps/2) 296 | self.stream1 = nn.Sequential( 297 | nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=1), 298 | nn.ReLU(inplace=True), 299 | nn.BatchNorm2d(num_channels), 300 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 301 | nn.Dropout(), 302 | ) 303 | self.stream2 = nn.Sequential( 304 | # nn.Conv2d(num_input, num_channels, kernel_size=5, stride=3, padding=2), 305 | nn.Conv2d(num_input, int(num_channels/2), kernel_size=3, stride=3, padding=2, dilation=2), # 5,2/ 1,0 306 | nn.ReLU(inplace=True), 307 | nn.BatchNorm2d(int(num_channels/2)), 308 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 309 | nn.Dropout(), 310 | ) 311 | self.stream3 = nn.Sequential( 312 | # nn.Conv2d(num_input, num_channels, kernel_size=5, stride=3, padding=2), 313 | nn.Conv2d(num_input, int(num_channels/2), kernel_size=3, stride=3, padding=3, dilation=3), # 5,2/ 1,0 314 | nn.ReLU(inplace=True), 315 | nn.BatchNorm2d(int(num_channels/2)), 316 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 317 | nn.Dropout(), 318 | ) 319 | self.rcls = MakeLayer(RclBlock, featuremaps, num_layers) 320 | self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 321 | self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes) 322 | for m in self.modules(): 323 | if isinstance(m, nn.Conv2d): 324 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 325 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 326 | nn.init.constant_(m.weight, 1) 327 | nn.init.constant_(m.bias, 0) 328 | elif isinstance(m, nn.Linear): 329 | nn.init.normal_(m.weight, 0, 0.01) 330 | nn.init.constant_(m.bias, 0) 331 | 332 | def forward(self, x): 333 | x1 = self.stream1(x) 334 | x2 = self.stream2(x) 335 | x3 = self.stream3(x) 336 
| x = torch.cat((x1,x2,x3),1) 337 | x = self.rcls(x) 338 | x = self.avgpool(x) 339 | x = torch.flatten(x, 1) 340 | x = self.classifier(x) 341 | return x 342 | 343 | class RCN_P(nn.Module): 344 | """menet networks with hybrid modules by NAS 345 | """ 346 | def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5): 347 | super(RCN_P, self).__init__() 348 | self.classes = num_classes 349 | num_channels = int(featuremaps/2) 350 | self.stream1 = nn.Sequential( 351 | nn.Conv2d(num_input, num_channels, kernel_size=3, stride=1, padding=1), # 1->3 352 | nn.ReLU(inplace=True), 353 | nn.BatchNorm2d(num_channels), 354 | nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 355 | nn.Dropout(), 356 | ) 357 | self.stream2 = nn.Sequential( 358 | nn.Conv2d(num_input, num_channels, kernel_size=3, stride=1, padding=3, dilation=3), # 5,2/ 1,0 359 | nn.ReLU(inplace=True), 360 | nn.BatchNorm2d(num_channels), 361 | nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 362 | nn.Dropout(), 363 | ) 364 | self.dbl = MakeLayer(DenseBlock, featuremaps, num_layers) 365 | self.rcls = MakeLayer(RclBlock, featuremaps, num_layers) # note: unused in forward below 366 | self.attenmap = SpatialAttentionBlock_P(normalize_attn=True) 367 | self.downsampling = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 368 | self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 369 | self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes) 370 | for m in self.modules(): 371 | if isinstance(m, nn.Conv2d): 372 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 373 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 374 | nn.init.constant_(m.weight, 1) 375 | nn.init.constant_(m.bias, 0) 376 | elif isinstance(m, nn.Linear): 377 | nn.init.normal_(m.weight, 0, 0.01) 378 | nn.init.constant_(m.bias, 0) 379 | 380 | def forward(self, x): 381 | x1 = self.stream1(x) 382 | x2 = self.stream2(x) 383 | x = torch.cat((x1,x2),1) 384 | y = self.attenmap(self.downsampling(x), self.classifier.weight, self.classes) 385 | x = self.dbl(x) 386 | x = self.avgpool(x) 387 | x = x * y 388 | x = torch.flatten(x, 1) 389 | x = self.classifier(x) 390 | return x 391 | 
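A smoke-test sketch for RCN_P, assuming 3-channel 96x96 inputs and 3 classes (all assumptions): two parallel streams are concatenated, a CAM mask is built from the classifier weights, and the pooled dense-block features are gated by that mask:

import torch
model = RCN_P(num_input=3, featuremaps=64, num_classes=3)
out = model(torch.randn(2, 3, 96, 96))
print(out.shape)   # torch.Size([2, 3])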
392 | class RCN_C(nn.Module): 393 | """menet networks with cascaded modules 394 | """ 395 | def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5): 396 | super(RCN_C, self).__init__() 397 | self.classes = num_classes 398 | self.poolsize = pool_size 399 | num_channels = int(featuremaps/2) 400 | # self.stream1 = nn.Sequential( 401 | # nn.Conv2d(num_input, num_channels, kernel_size=3, stride=1, padding=1), # 1->3 402 | # nn.ReLU(inplace=True), 403 | # nn.BatchNorm2d(num_channels), 404 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 405 | # nn.Dropout(), 406 | # ) 407 | # self.stream2 = nn.Sequential( 408 | # nn.Conv2d(num_input, num_channels, kernel_size=5, stride=1, padding=2), # 5,2/ 1,0 409 | # nn.ReLU(inplace=True), 410 | # nn.BatchNorm2d(num_channels), 411 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 412 | # nn.Dropout(), 413 | # ) 414 | self.stream1 = nn.Sequential( 415 | nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=1), 416 | nn.ReLU(inplace=True), 417 | nn.BatchNorm2d(num_channels), 418 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 419 | nn.Dropout(), 420 | ) 421 | self.stream2 = nn.Sequential( 422 | nn.Conv2d(num_input, int(num_channels/2), kernel_size=3, stride=3, padding=2, dilation=2), # 5,2/ 1,0 423 | nn.ReLU(inplace=True), 424 | nn.BatchNorm2d(int(num_channels/2)), 425 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 426 | nn.Dropout(), 427 | ) 428 | self.stream3 = nn.Sequential( 429 | nn.Conv2d(num_input, int(num_channels/2), kernel_size=3, stride=3, padding=3, dilation=3), # 5,2/ 1,0 430 | nn.ReLU(inplace=True), 431 | nn.BatchNorm2d(int(num_channels/2)), 432 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 433 | nn.Dropout(), 434 | ) 435 | self.dbl = MakeLayer(DenseBlock, featuremaps, num_layers) 436 | self.ebl = EmbeddingBlock(featuremaps, featuremaps) 437 | # self.attenmap = SpatialAttentionBlock_P(normalize_attn=True) 438 | self.attenmap = SpatialAttentionBlock_F(normalize_attn=True) 439 | self.downsampling = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 440 | self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 441 | self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes) 442 | for m in self.modules(): 443 | if isinstance(m, nn.Conv2d): 444 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 445 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 446 | nn.init.constant_(m.weight, 1) 447 | nn.init.constant_(m.bias, 0) 448 | elif isinstance(m, nn.Linear): 449 | nn.init.normal_(m.weight, 0, 0.01) 450 | nn.init.constant_(m.bias, 0) 451 | 452 | def forward(self, x): 453 | y = self.attenmap(self.downsampling(x), self.classifier.weight, self.classes) # attention mask computed from the raw input; SpatialAttentionBlock_F averages the classifier weights over channels 454 | x1 = self.stream1(x) 455 | x2 = self.stream2(x) 456 | x3 = self.stream3(x) 457 | x = torch.cat((x1, x2, x3), 1) 458 | # x = torch.cat((x1,x2),1) 459 | # y = self.attenmap(self.downsampling(x), self.classifier.weight, self.classes) 460 | x = torch.mul(F.interpolate(y,(x.shape[2],x.shape[3])), x) 461 | x = self.dbl(x) 462 | # x = self.ebl(x, self.classifier.weight, self.poolsize, self.classes) 463 | x = self.avgpool(x) 464 | x = torch.flatten(x, 1) 465 | x = self.classifier(x) 466 | return x 467 | 
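A small sketch of how SpatialAttentionBlock_F (used by RCN_C above) turns classifier weights into a mask; the shapes assume pool_size=5 and featuremaps=64:

import torch
atten = SpatialAttentionBlock_F()
pooled = torch.randn(2, 3, 5, 5)     # raw input pooled to 5x5
w = torch.randn(3, 5 * 5 * 64)       # classifier weights: (classes, pool*pool*featuremaps)
mask = atten(pooled, w, classes=3)   # per-class CAMs, min-max normalized, then averaged
print(mask.shape)                    # torch.Size([2, 1, 5, 5])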
468 | class RCN_F(nn.Module): 469 | """menet networks with embedded modules as final fusion way 470 | """ 471 | def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5): 472 | super(RCN_F, self).__init__() 473 | self.classes = num_classes 474 | self.poolsize = pool_size 475 | num_channels = int(featuremaps/2) 476 | # self.stream1 = nn.Sequential( 477 | # nn.Conv2d(num_input, num_channels, kernel_size=3, stride=1, padding=1), # 1->3 478 | # nn.ReLU(inplace=True), 479 | # nn.BatchNorm2d(num_channels), 480 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 481 | # nn.Dropout(), 482 | # ) 483 | # self.stream2 = nn.Sequential( 484 | # nn.Conv2d(num_input, num_channels, kernel_size=5, stride=1, padding=2), # 5,2/ 1,0 485 | # nn.ReLU(inplace=True), 486 | # nn.BatchNorm2d(num_channels), 487 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 488 | # nn.Dropout(), 489 | # ) 490 | self.stream1 = nn.Sequential( 491 | nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=1), 492 | nn.ReLU(inplace=True), 493 | nn.BatchNorm2d(num_channels), 494 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 495 | nn.Dropout(), 496 | ) 497 | self.stream2 = nn.Sequential( 498 | nn.Conv2d(num_input, int(num_channels/2), kernel_size=3, stride=3, padding=2, dilation=2), # 5,2/ 1,0 499 | nn.ReLU(inplace=True), 500 | nn.BatchNorm2d(int(num_channels/2)), 501 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 502 | nn.Dropout(), 503 | ) 504 | self.stream3 = nn.Sequential( 505 | nn.Conv2d(num_input, int(num_channels/2), kernel_size=3, stride=3, padding=3, dilation=3), # 5,2/ 1,0 506 | nn.ReLU(inplace=True), 507 | nn.BatchNorm2d(int(num_channels/2)), 508 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 509 | nn.Dropout(), 510 | ) 511 | self.ebl = EmbeddingBlock(featuremaps, featuremaps) 512 | self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 513 | self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes) 514 | for m in self.modules(): 515 | if isinstance(m, nn.Conv2d): 516 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 517 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 518 | nn.init.constant_(m.weight, 1) 519 | nn.init.constant_(m.bias, 0) 520 | elif isinstance(m, nn.Linear): 521 | nn.init.normal_(m.weight, 0, 0.01) 522 | nn.init.constant_(m.bias, 0) 523 | 524 | def forward(self, x): 525 | x1 = self.stream1(x) 526 | x2 = self.stream2(x) 527 | # x = torch.cat((x1, x2), 1) 528 | x3 = self.stream3(x) 529 | x = torch.cat((x1,x2,x3),1) 530 | x = self.ebl(x, self.classifier.weight, self.poolsize, self.classes) 531 | x = self.avgpool(x) 532 | x = torch.flatten(x, 1) 533 | x = self.classifier(x) 534 | return x --------------------------------------------------------------------------------
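A smoke-test sketch for RCN_F, assuming 3-channel 96x96 inputs and 3 classes: the three dilated stride-3 streams are concatenated and passed through the attention-embedded EmbeddingBlock:

import torch
model = RCN_F(num_input=3, featuremaps=64, num_classes=3)
out = model(torch.randn(2, 3, 96, 96))
print(out.shape)   # torch.Size([2, 3])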
/MeNets_NAS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | __all__ = ['MeNet_A', 'MeNet_D', 'MeNet_W', 'MeNet_H', 'MeNet_C', 'MeNet_E'] 6 | 7 | class ConvBlock(nn.Module): 8 | """convolutional layer blocks for sequential convolution operations""" 9 | def __init__(self, in_features, out_features, num_conv, pool=False): 10 | super(ConvBlock, self).__init__() 11 | features = [in_features] + [out_features for i in range(num_conv)] 12 | layers = [] 13 | for i in range(len(features)-1): 14 | layers.append(nn.Conv2d(in_channels=features[i], out_channels=features[i+1], kernel_size=3, padding=1, bias=True)) 15 | layers.append(nn.BatchNorm2d(num_features=features[i+1], affine=True, track_running_stats=True)) 16 | layers.append(nn.ReLU()) 17 | if pool: 18 | layers.append(nn.MaxPool2d(kernel_size=2, stride=2, padding=0)) 19 | self.op = nn.Sequential(*layers) 20 | def forward(self, x): 21 | return self.op(x) 22 | 23 | class RclBlock(nn.Module): 24 | """recurrent convolutional blocks""" 25 | def __init__(self, inplanes, planes): 26 | super(RclBlock, self).__init__() 27 | self.ffconv = nn.Sequential( 28 | nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0), 29 | nn.BatchNorm2d(planes), 30 | nn.ReLU(inplace=True) 31 | ) 32 | self.rrconv = nn.Sequential( 33 | nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, padding=1), 34 | nn.BatchNorm2d(planes), 35 | nn.ReLU(inplace=True) 36 | ) 37 | self.downsample = nn.Sequential( 38 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0), 39 | nn.Dropout() 40 | ) 41 | 42 | def forward(self, x): 43 | y = self.ffconv(x) 44 | y = self.rrconv(x + y) 45 | y = self.rrconv(x + y) 46 | out = self.downsample(y) 47 | return out 48 | 49 | class DenseBlock(nn.Module): 50 | """densely connected convolutional blocks""" 51 | def __init__(self, inplanes, planes): 52 | super(DenseBlock, self).__init__() 53 | self.conv1 = nn.Sequential( 54 | nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0), 55 | nn.BatchNorm2d(planes), 56 | nn.ReLU(inplace=True) 57 | ) 58 | self.conv2 = nn.Sequential( 59 | nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, padding=1), 60 | nn.BatchNorm2d(planes), 61 | nn.ReLU(inplace=True) 62 | ) 63 | self.conv3 = nn.Sequential( # note: defined but unused; forward below reuses conv2 64 | nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, padding=1), 65 | nn.BatchNorm2d(planes), 66 | nn.ReLU(inplace=True) 67 | ) 68 | self.downsample = nn.Sequential( 69 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0), 70 | nn.Dropout() 71 | ) 72 | 73 | def forward(self, x): 74 | y = self.conv1(x) 75 | z = self.conv2(x + y) 76 | # out = self.conv2(x + y + z) 77 | e = self.conv2(x + y + z) 78 | out = self.conv2(x + y + z + e) 79 | out = self.downsample(out) 80 | return out 81 | 82 | class EmbeddingBlock(nn.Module): 83 | """densely connected convolutional blocks for embedding""" 84 | def __init__(self, inplanes, planes): 85 | super(EmbeddingBlock, self).__init__() 86 | self.conv1 = nn.Sequential( 87 | nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0), 88 | nn.BatchNorm2d(planes), 89 | nn.ReLU(inplace=True) 90 | ) 91 | self.conv2 = nn.Sequential( 92 | nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, padding=1), 93 | nn.BatchNorm2d(planes), 94 | nn.ReLU(inplace=True) 95 | ) 96 | self.attenmap = SpatialAttentionBlock_P(normalize_attn=True) 97 | self.downsample = nn.Sequential( 98 | nn.MaxPool2d(kernel_size=2, stride=3, padding=0), # stride 3 (the MeNets.py EmbeddingBlock uses stride 2) 99 | nn.Dropout() 100 | ) 101 | 102 | def forward(self, x, w, pool_size, classes): 103 | y = self.conv1(x) 104 | y1 = self.attenmap(F.adaptive_avg_pool2d(x, (pool_size, pool_size)), w, classes) 105 | y = torch.mul(F.interpolate(y1, (y.shape[2], y.shape[3])), y) 106 | z = self.conv2(x+y) 107 | e = self.conv2(x + y + z) 108 | out = self.conv2(x + y + z + e) 109 | out = self.downsample(out) 110 | return out 111 | 112 | class EmbeddingBlock2(nn.Module): 113 | """densely connected convolutional blocks for embedding, with the attention step disabled""" 114 | def __init__(self, inplanes, planes): 115 | super(EmbeddingBlock2, self).__init__() 116 | self.conv1 = nn.Sequential( 117 | nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0), 118 | nn.BatchNorm2d(planes), 119 | nn.ReLU(inplace=True) 120 | ) 121 | self.conv2 = nn.Sequential( 122 | nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1), 123 | nn.BatchNorm2d(planes), 124 | nn.ReLU(inplace=True) 125 | ) 126 | self.attenmap = SpatialAttentionBlock_P(normalize_attn=True) 127 | self.downsample = nn.Sequential( 128 | nn.MaxPool2d(kernel_size=2, stride=3, padding=0), 129 | nn.Dropout() 130 | ) 131 | 132 | def forward(self, x, w, pool_size, classes): 133 | y = self.conv1(x) 134 | #y1 = self.attenmap(F.adaptive_avg_pool2d(x, (pool_size, pool_size)), w, classes) 135 | #y = torch.mul(F.interpolate(y1, (y.shape[2], y.shape[3])), y) 136 | z = self.conv2(y) 137 | e = self.conv2(y + z) 138 | out = self.conv2(y + z + e) 139 | out = self.downsample(out) 140 | return out 141 | 
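A shape sketch for EmbeddingBlock2 (sizes are assumptions): with its attention lines commented out, the w/pool_size/classes arguments are accepted but unused, and the stride-3 MaxPool shrinks 32x32 to 11x11:

import torch
block = EmbeddingBlock2(64, 64)
out = block(torch.randn(1, 64, 32, 32), w=None, pool_size=5, classes=3)  # placeholders are fine here
print(out.shape)   # torch.Size([1, 64, 11, 11])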
142 | class SpatialAttentionBlock_A(nn.Module): 143 | """learned linear attention block for any layer""" 144 | def __init__(self, in_features, normalize_attn=True): 145 | super(SpatialAttentionBlock_A, self).__init__() 146 | self.normalize_attn = normalize_attn 147 | self.op = nn.Conv2d(in_channels=in_features, out_channels=1, kernel_size=1, padding=0, bias=False) 148 | 149 | def forward(self, l): 150 | N, C, W, H = l.size() 151 | c = self.op(l) # batch_size x 1 x W x H 152 | if self.normalize_attn: 153 | a = F.softmax(c.view(N,1,-1), dim=2).view(N,1,W,H) 154 | else: 155 | a = torch.sigmoid(c) 156 | g = torch.mul(a.expand_as(l), l) 157 | return g 158 | 159 | class SpatialAttentionBlock_P(nn.Module): 160 | """CAM-style attention block built from the classifier weights, for any layer""" 161 | def __init__(self, normalize_attn=True): 162 | super(SpatialAttentionBlock_P, self).__init__() 163 | self.normalize_attn = normalize_attn 164 | 165 | def forward(self, l, w, classes): 166 | output_cam = [] 167 | for idx in range(0,classes): 168 | weights = w[idx,:].reshape((l.shape[1], l.shape[2], l.shape[3])) 169 | cam = weights * l 170 | cam = cam.mean(dim=1,keepdim=True) 171 | cam = cam - torch.min(torch.min(cam,3,True)[0],2,True)[0] 172 | cam = cam / torch.max(torch.max(cam,3,True)[0],2,True)[0] 173 | output_cam.append(cam) 174 | output = torch.cat(output_cam, dim=1) 175 | output = output.mean(dim=1,keepdim=True) 176 | return output 177 | 178 | class SpatialAttentionBlock_F(nn.Module): 179 | """CAM-style attention block for the first layer (classifier weights averaged over channels)""" 180 | def __init__(self, normalize_attn=True): 181 | super(SpatialAttentionBlock_F, self).__init__() 182 | self.normalize_attn = normalize_attn 183 | 184 | def forward(self, l, w, classes): 185 | output_cam = [] 186 | for idx in range(0,classes): 187 | weights = w[idx,:].reshape((-1, l.shape[2], l.shape[3])) 188 | weights = weights.mean(dim=0,keepdim=True) 189 | cam = weights * l 190 | cam = cam.mean(dim=1,keepdim=True) 191 | cam = cam - torch.min(torch.min(cam,3,True)[0],2,True)[0] 192 | cam = cam / torch.max(torch.max(cam,3,True)[0],2,True)[0] 193 | output_cam.append(cam) 194 | output = torch.cat(output_cam, dim=1) 195 | output = output.mean(dim=1,keepdim=True) 196 | return output 197 | 198 | def MakeLayer(block, planes, blocks): 199 | layers = [] 200 | for _ in range(0, blocks): 201 | layers.append(block(planes, planes)) 202 | return nn.Sequential(*layers) 203 | 204 | class MeNet_A(nn.Module): 205 | """menet networks with an added attention unit 206 | """ 207 | def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5, model_version=3): 208 | super(MeNet_A, self).__init__() 209 | self.version = model_version 210 | self.classes = num_classes 211 | self.conv1 = nn.Sequential( 212 | nn.Conv2d(num_input, featuremaps, kernel_size=5, stride=1, padding=0), 213 | nn.BatchNorm2d(featuremaps), 214 | nn.ReLU(inplace=True), 215 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0), 216 | nn.Dropout(), 217 | ) 218 | self.rcls = MakeLayer(RclBlock, featuremaps, num_layers) 219 | self.attenmap = SpatialAttentionBlock_P(normalize_attn=True) 220 | self.downsampling = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 221 | self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 222 | self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes) 223 | 224 | for m in self.modules(): 225 | if isinstance(m, nn.Conv2d): 226 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 227 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 228 | nn.init.constant_(m.weight, 1) 229 | nn.init.constant_(m.bias, 0) 230 | elif isinstance(m, nn.Linear): 231 | nn.init.normal_(m.weight, 0, 0.01) 232 | nn.init.constant_(m.bias, 0) 233 | 234 | def forward(self, x): 235 | if self.version == 1: 236 | x = self.conv1(x) 237 | x = self.attenmap(x) # note: this one-argument call matches SpatialAttentionBlock_A, not the _P block created above; versions 1/2 appear to be legacy paths 238 | x = self.rcls(x) 239 | x = self.avgpool(x) 240 | elif self.version == 2: 241 | x = self.conv1(x) 242 | x = self.attenmap(x) 243 | x = self.rcls(x) 244 | x = self.avgpool(x) 245 | elif self.version == 3: 246 | x = self.conv1(x) 247 | y = self.attenmap(self.downsampling(x), self.classifier.weight, self.classes) 248 | x = self.rcls(x) 249 | x = self.avgpool(x) 250 | x = x * y 251 | x = torch.flatten(x, 1) 252 | x = self.classifier(x) 253 | return x 254 | 255 | class MeNet_D(nn.Module): 256 | """menet networks with dense connection 257 | """ 258 | def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5): 259 | super(MeNet_D, self).__init__() 260 | self.conv1 = nn.Sequential( 261 | 
nn.Conv2d(num_input, featuremaps, kernel_size=5, stride=1, padding=0), 262 | nn.BatchNorm2d(featuremaps), 263 | nn.ReLU(inplace=True), 264 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0), 265 | nn.Dropout(), 266 | ) 267 | self.dbl = MakeLayer(DenseBlock, featuremaps, num_layers) 268 | self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 269 | self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes) 270 | for m in self.modules(): 271 | if isinstance(m, nn.Conv2d): 272 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 273 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 274 | nn.init.constant_(m.weight, 1) 275 | nn.init.constant_(m.bias, 0) 276 | elif isinstance(m, nn.Linear): 277 | nn.init.normal_(m.weight, 0, 0.01) 278 | nn.init.constant_(m.bias, 0) 279 | 280 | def forward(self, x): 281 | x = self.conv1(x) 282 | x = self.dbl(x) 283 | x = self.avgpool(x) 284 | x = torch.flatten(x, 1) 285 | x = self.classifier(x) 286 | return x 287 | 288 | class MeNet_W(nn.Module): 289 | """menet networks with wide expansion 290 | """ 291 | def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5): 292 | super(MeNet_W, self).__init__() 293 | num_channels = int(featuremaps/2) 294 | self.stream1 = nn.Sequential( 295 | nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=1), 296 | nn.ReLU(inplace=True), 297 | nn.BatchNorm2d(num_channels), 298 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 299 | nn.Dropout(), 300 | ) 301 | self.stream2 = nn.Sequential( 302 | # nn.Conv2d(num_input, num_channels, kernel_size=5, stride=3, padding=2), 303 | nn.Conv2d(num_input, int(num_channels/2), kernel_size=3, stride=3, padding=2, dilation=2), # 5,2/ 1,0 304 | nn.ReLU(inplace=True), 305 | nn.BatchNorm2d(int(num_channels/2)), 306 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 307 | nn.Dropout(), 308 | ) 309 | self.stream3 = nn.Sequential( 310 | # nn.Conv2d(num_input, num_channels, kernel_size=5, stride=3, padding=2), 311 | nn.Conv2d(num_input, int(num_channels/2), kernel_size=3, stride=3, padding=3, dilation=3), # 5,2/ 1,0 312 | nn.ReLU(inplace=True), 313 | nn.BatchNorm2d(int(num_channels/2)), 314 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 315 | nn.Dropout(), 316 | ) 317 | self.rcls = MakeLayer(RclBlock, featuremaps, num_layers) 318 | self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 319 | self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes) 320 | for m in self.modules(): 321 | if isinstance(m, nn.Conv2d): 322 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 323 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 324 | nn.init.constant_(m.weight, 1) 325 | nn.init.constant_(m.bias, 0) 326 | elif isinstance(m, nn.Linear): 327 | nn.init.normal_(m.weight, 0, 0.01) 328 | nn.init.constant_(m.bias, 0) 329 | 330 | def forward(self, x): 331 | x1 = self.stream1(x) 332 | x2 = self.stream2(x) 333 | x3 = self.stream3(x) 334 | x = torch.cat((x1,x2,x3),1) 335 | x = self.rcls(x) 336 | x = self.avgpool(x) 337 | x = torch.flatten(x, 1) 338 | x = self.classifier(x) 339 | return x 340 | 341 | class MeNet_H(nn.Module): 342 | """menet networks with hybrid modules 343 | """ 344 | def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5): 345 | super(MeNet_H, self).__init__() 346 | self.classes = num_classes 347 | num_channels = int(featuremaps/2) 348 | self.stream1 = nn.Sequential( 349 | nn.Conv2d(num_input, num_channels, kernel_size=3, stride=1, 
padding=1), # 1->3 350 | nn.ReLU(inplace=True), 351 | nn.BatchNorm2d(num_channels), 352 | nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 353 | nn.Dropout(), 354 | ) 355 | self.stream2 = nn.Sequential( 356 | nn.Conv2d(num_input, num_channels, kernel_size=3, stride=1, padding=3, dilation=3), # 5,2/ 1,0 357 | nn.ReLU(inplace=True), 358 | nn.BatchNorm2d(num_channels), 359 | nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 360 | nn.Dropout(), 361 | ) 362 | self.dbl = MakeLayer(DenseBlock, featuremaps, num_layers) 363 | self.rcls = MakeLayer(RclBlock, featuremaps, num_layers) 364 | self.attenmap = SpatialAttentionBlock_P(normalize_attn=True) 365 | self.downsampling = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 366 | self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 367 | self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes) 368 | for m in self.modules(): 369 | if isinstance(m, nn.Conv2d): 370 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 371 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 372 | nn.init.constant_(m.weight, 1) 373 | nn.init.constant_(m.bias, 0) 374 | elif isinstance(m, nn.Linear): 375 | nn.init.normal_(m.weight, 0, 0.01) 376 | nn.init.constant_(m.bias, 0) 377 | 378 | def forward(self, x): 379 | x1 = self.stream1(x) 380 | x2 = self.stream2(x) 381 | x = torch.cat((x1,x2),1) 382 | y = self.attenmap(self.downsampling(x), self.classifier.weight, self.classes) 383 | x = self.dbl(x) 384 | x = self.avgpool(x) 385 | x = x * y 386 | x = torch.flatten(x, 1) 387 | x = self.classifier(x) 388 | return x 389 | 390 | 391 | 392 | class MeNet_CS(nn.Module): 393 | """menet networks with cascaded modules with searching 394 | """ 395 | def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5): 396 | super(MeNet_CS, self).__init__() 397 | self.classes = num_classes 398 | num_channels = int(featuremaps/2) 399 | self.archi = nn.Parameter(torch.randn(2,2)) 400 | # self.stream1 = nn.Sequential( 401 | # nn.Conv2d(num_input, num_channels, kernel_size=3, stride=1, padding=1), # 1->3 402 | # nn.ReLU(inplace=True), 403 | # nn.BatchNorm2d(num_channels), 404 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 405 | # nn.Dropout(), 406 | # ) 407 | # self.stream2 = nn.Sequential( 408 | # nn.Conv2d(num_input, num_channels, kernel_size=5, stride=1, padding=2), # 5,2/ 1,0 409 | # nn.ReLU(inplace=True), 410 | # nn.BatchNorm2d(num_channels), 411 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 412 | # nn.Dropout(), 413 | # ) 414 | self.stream1 = nn.Sequential( 415 | nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=1), # 1->3 416 | nn.ReLU(inplace=True), 417 | nn.BatchNorm2d(num_channels), 418 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 419 | nn.Dropout(), 420 | ) 421 | self.stream2 = nn.Sequential( 422 | nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=3, dilation=3), # 5,2/ 1,0 423 | nn.ReLU(inplace=True), 424 | nn.BatchNorm2d(num_channels), 425 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 426 | nn.Dropout(), 427 | ) 428 | self.conv1 = nn.Sequential( 429 | nn.Conv2d(featuremaps, featuremaps, kernel_size=1, stride=1, padding=0), 430 | nn.BatchNorm2d(featuremaps), 431 | nn.ReLU(inplace=True) 432 | ) 433 | self.conv2 = nn.Sequential( 434 | nn.Conv2d(featuremaps, featuremaps, kernel_size=3, stride=1, padding=1), 435 | nn.BatchNorm2d(featuremaps), 436 | nn.ReLU(inplace=True) 437 | ) 438 | self.downsample = nn.Sequential( 439 | nn.MaxPool2d(kernel_size=2, stride=2, 
padding=0), 440 | nn.Dropout() 441 | ) 442 | 443 | 444 | self.softmax = nn.Softmax(0) 445 | self.attenmap = SpatialAttentionBlock_F(normalize_attn=True) 446 | self.downsampling = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 447 | self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 448 | self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes) 449 | 450 | nn.init.constant_(self.archi, 0.5) 451 | 452 | for m in self.modules(): 453 | if isinstance(m, nn.Conv2d): 454 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 455 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 456 | nn.init.constant_(m.weight, 1) 457 | nn.init.constant_(m.bias, 0) 458 | elif isinstance(m, nn.Linear): 459 | nn.init.normal_(m.weight, 0, 0.01) 460 | nn.init.constant_(m.bias, 0) 461 | 462 | def forward(self, x): 463 | W = self.softmax(self.archi) 464 | #W = self.archi 465 | #M for attention mask 466 | M1 = self.attenmap(self.downsampling(x), self.classifier.weight, self.classes) 467 | 468 | x1 = self.stream1(x) 469 | x2 = self.stream2(x) 470 | x = torch.cat((x1,x2),1) 471 | 472 | x1 = torch.mul(F.interpolate(M1,(x.shape[2],x.shape[3])), x) 473 | 474 | #x = W[0][0]*x+W[0][1]*x1 475 | x = x+W[0][1]*x1 476 | #Second Attention 477 | M2 = self.attenmap(self.downsampling(x), self.classifier.weight, self.classes) 478 | y = self.conv1(x) 479 | y1 = torch.mul(F.interpolate(M2,(y.shape[2],y.shape[3])), y) 480 | 481 | #y = W[1][0]*y + W[1][1]*y1 482 | y = y + W[1][1]*y1 483 | 484 | #Third Attention 485 | M3 = self.attenmap(self.downsampling(y), self.classifier.weight, self.classes) 486 | z = self.conv2(x+y) 487 | z1 = torch.mul(F.interpolate(M3,(z.shape[2],z.shape[3])), z) 488 | 489 | #z = W[2][0]*z + W[2][1]*z1 490 | z = z #+ W[2][1]*z1 491 | #Fourth Attention 492 | M4 = self.attenmap(self.downsampling(z), self.classifier.weight, self.classes) 493 | e = self.conv2(x+y+z) 494 | e1 = torch.mul(F.interpolate(M4,(e.shape[2],e.shape[3])), e) 495 | 496 | e = e #+W[3][1]*e1 497 | #e = W[3][0]*e+W[3][1]*e1 498 | 499 | #Fifth Attention 500 | M5 = self.attenmap(self.downsampling(e), self.classifier.weight, self.classes) 501 | out = self.conv2(x+y+z+e) 502 | out1 = torch.mul(F.interpolate(M5,(out.shape[2],out.shape[3])), out) 503 | 504 | #out = W[4][0]*out+W[4][1]*out1 505 | out = out #+W[4][1]*out1 506 | 507 | out = self.downsample(out) 508 | 509 | x = self.avgpool(out) 510 | x = torch.flatten(x, 1) 511 | x = self.classifier(x) 512 | return x 513 | 514 | 
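A small sketch of the DARTS-style gating above: with archi filled with 0.5, the dim-0 softmax in MeNet_CS makes every entry 0.5, so each attended branch initially contributes at half weight (x = x + 0.5*x1):

import torch
archi = torch.full((2, 2), 0.5)
W = torch.softmax(archi, dim=0)   # as in nn.Softmax(0)
print(W)                          # all entries 0.5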
515 | class MeNet_CS2(nn.Module): 516 | """menet networks with cascaded modules with searching 517 | """ 518 | def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5): 519 | super(MeNet_CS2, self).__init__() 520 | self.classes = num_classes 521 | num_channels = int(featuremaps/2) 522 | self.archi = nn.Parameter(torch.randn(4,2)) 523 | # self.stream1 = nn.Sequential( 524 | # nn.Conv2d(num_input, num_channels, kernel_size=3, stride=1, padding=1), # 1->3 525 | # nn.ReLU(inplace=True), 526 | # nn.BatchNorm2d(num_channels), 527 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 528 | # nn.Dropout(), 529 | # ) 530 | # self.stream2 = nn.Sequential( 531 | # nn.Conv2d(num_input, num_channels, kernel_size=5, stride=1, padding=2), # 5,2/ 1,0 532 | # nn.ReLU(inplace=True), 533 | # nn.BatchNorm2d(num_channels), 534 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 535 | # nn.Dropout(), 536 | # ) 537 | self.stream1 = nn.Sequential( 538 | nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=1), # 1->3 539 | nn.ReLU(inplace=True), 540 | nn.BatchNorm2d(num_channels), 541 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 542 | nn.Dropout(), 543 | ) 544 | self.stream2 = nn.Sequential( 545 | nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=3, dilation=3), # 5,2/ 1,0 546 | nn.ReLU(inplace=True), 547 | nn.BatchNorm2d(num_channels), 548 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 549 | nn.Dropout(), 550 | ) 551 | self.conv1 = nn.Sequential( 552 | nn.Conv2d(featuremaps, featuremaps, kernel_size=1, stride=1, padding=0), 553 | nn.BatchNorm2d(featuremaps), 554 | nn.ReLU(inplace=True) 555 | ) 556 | self.conv2 = nn.Sequential( 557 | nn.Conv2d(featuremaps, featuremaps, kernel_size=3, stride=1, padding=1), 558 | nn.BatchNorm2d(featuremaps), 559 | nn.ReLU(inplace=True) 560 | ) 561 | self.downsample = nn.Sequential( 562 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0), 563 | nn.Dropout() 564 | ) 565 | 566 | 567 | self.softmax = nn.Softmax(-1) 568 | self.attenmap = SpatialAttentionBlock_F(normalize_attn=True) 569 | self.downsampling = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 570 | self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 571 | self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes) 572 | 573 | nn.init.constant_(self.archi, 0.5) 574 | 575 | for m in self.modules(): 576 | if isinstance(m, nn.Conv2d): 577 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 578 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 579 | nn.init.constant_(m.weight, 1) 580 | nn.init.constant_(m.bias, 0) 581 | elif isinstance(m, nn.Linear): 582 | nn.init.normal_(m.weight, 0, 0.01) 583 | nn.init.constant_(m.bias, 0) 584 | 585 | def forward(self, x): 586 | W = self.softmax(self.archi) 587 | #W = self.archi 588 | #M for attention mask 589 | M1 = self.attenmap(self.downsampling(x), self.classifier.weight, self.classes) 590 | 591 | x1 = self.stream1(x) 592 | x2 = self.stream2(x) 593 | x = torch.cat((x1,x2),1) 594 | 595 | #Second Attention 596 | M2 = self.attenmap(self.downsampling(x), self.classifier.weight, self.classes) 597 | y = self.conv1(x) 598 | y1 = torch.mul(F.interpolate(M1,(y.shape[2],y.shape[3])), y) # Here we use M1 599 | 600 | y = W[0][0]*y + W[0][1]*y1 601 | 602 | #Third Attention 603 | M3 = self.attenmap(self.downsampling(y), self.classifier.weight, self.classes) 604 | z = self.conv2(x+y) 605 | z1 = torch.mul(F.interpolate(M2,(z.shape[2],z.shape[3])), z) 606 | 607 | z = W[1][0]*z + W[1][1]*z1 608 | #z = z #+ W[2][1]*z1 609 | #Fourth Attention 610 | M4 = self.attenmap(self.downsampling(z), self.classifier.weight, self.classes) 611 | e = self.conv2(x+y+z) 612 | e1 = torch.mul(F.interpolate(M3,(e.shape[2],e.shape[3])), e) 613 | 614 | e = W[2][0]*e+W[2][1]*e1 615 | 616 | #Fifth Attention 617 | #M5 = self.attenmap(self.downsampling(e), self.classifier.weight, self.classes) 618 | out = self.conv2(x+y+z+e) 619 | out1 = torch.mul(F.interpolate(M4,(out.shape[2],out.shape[3])), out) 620 | 621 | out = W[3][0]*out+W[3][1]*out1 622 | #out = out #+W[4][1]*out1 623 | 624 | out = self.downsample(out) 625 | 626 | x = self.avgpool(out) 627 | x = torch.flatten(x, 1) 628 | x = self.classifier(x) 629 | return x 630 | 631 | 
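In MeNet_CS2 the softmax runs over the last dimension instead, so each row of W is a convex pair mixing the plain and attended features, e.g. y = W[0][0]*y + W[0][1]*y1 (a sketch):

import torch
archi = torch.full((4, 2), 0.5)
W = torch.softmax(archi, dim=-1)
print(W.sum(dim=-1))   # tensor([1., 1., 1., 1.]) -- each gate is a convex combination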
632 | class MeNet_CS3(nn.Module): 633 | """menet networks with cascaded modules with searching 634 | """ 635 | def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5): 636 | super(MeNet_CS3, self).__init__() 637 | self.classes = num_classes 638 | num_channels = int(featuremaps/2) 639 | self.archi = nn.Parameter(torch.randn(3,2)) 640 | # self.stream1 = nn.Sequential( 641 | # nn.Conv2d(num_input, num_channels, kernel_size=3, stride=1, padding=1), # 1->3 642 | # nn.ReLU(inplace=True), 643 | # nn.BatchNorm2d(num_channels), 644 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 645 | # nn.Dropout(), 646 | # ) 647 | # self.stream2 = nn.Sequential( 648 | # nn.Conv2d(num_input, num_channels, kernel_size=5, stride=1, padding=2), # 5,2/ 1,0 649 | # nn.ReLU(inplace=True), 650 | # nn.BatchNorm2d(num_channels), 651 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 652 | # nn.Dropout(), 653 | # ) 654 | self.stream1 = nn.Sequential( 655 | nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=1), # 1->3 656 | nn.ReLU(inplace=True), 657 | nn.BatchNorm2d(num_channels), 658 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 659 | nn.Dropout(), 660 | ) 661 | self.stream2 = nn.Sequential( 662 | nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=3, dilation=3), # 5,2/ 1,0 663 | nn.ReLU(inplace=True), 664 | nn.BatchNorm2d(num_channels), 665 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 666 | nn.Dropout(), 667 | ) 668 | self.conv1 = nn.Sequential( 669 | nn.Conv2d(featuremaps, featuremaps, kernel_size=1, stride=1, padding=0), 670 | nn.BatchNorm2d(featuremaps), 671 | nn.ReLU(inplace=True) 672 | ) 673 | self.conv2 = nn.Sequential( 674 | nn.Conv2d(featuremaps, featuremaps, kernel_size=3, stride=1, padding=1), 675 | nn.BatchNorm2d(featuremaps), 676 | nn.ReLU(inplace=True) 677 | ) 678 | self.downsample = nn.Sequential( 679 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0), 680 | nn.Dropout() 681 | ) 682 | 683 | 684 | self.softmax = nn.Softmax(-1) 685 | self.attenmap = SpatialAttentionBlock_F(normalize_attn=True) 686 | self.downsampling = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 687 | self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 688 | self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes) 689 | 690 | nn.init.constant_(self.archi, 0.5) 691 | 692 | for m in self.modules(): 693 | if isinstance(m, nn.Conv2d): 694 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 695 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 696 | nn.init.constant_(m.weight, 1) 697 | nn.init.constant_(m.bias, 0) 698 | elif isinstance(m, nn.Linear): 699 | nn.init.normal_(m.weight, 0, 0.01) 700 | nn.init.constant_(m.bias, 0) 701 | 702 | def forward(self, x): 703 | W = self.softmax(self.archi) 704 | #W = self.archi 705 | #M for attention mask 706 | M1 = self.attenmap(self.downsampling(x), self.classifier.weight, self.classes) 707 | 708 | x1 = self.stream1(x) 709 | x2 = self.stream2(x) 710 | x = torch.cat((x1,x2),1) 711 | 712 | #Second Attention 713 | M2 = self.attenmap(self.downsampling(x), self.classifier.weight, self.classes) 714 | y = self.conv1(x) 715 | 716 | #Third Attention 717 | M3 = self.attenmap(self.downsampling(y), self.classifier.weight, self.classes) 718 | z = self.conv2(x+y) 719 | z1 = torch.mul(F.interpolate(M1,(z.shape[2],z.shape[3])), z) 720 | 721 | z = W[0][0]*z + W[0][1]*z1 722 | #z = z #+ W[2][1]*z1 723 | #Fourth Attention 724 | #M4 = self.attenmap(self.downsampling(z), self.classifier.weight, self.classes) 725 | e = self.conv2(x+y+z) 726 | e1 = torch.mul(F.interpolate(M2,(e.shape[2],e.shape[3])), e) 727 | 728 | e = W[1][0]*e+W[1][1]*e1 729 | 730 | #Fifth Attention 731 | #M5 = self.attenmap(self.downsampling(e), self.classifier.weight, self.classes) 732 | out = self.conv2(x+y+z+e) 733 | out1 = 
torch.mul(F.interpolate(M3,(out.shape[2],out.shape[3])), out) 734 | 735 | out = W[2][0]*out+W[2][1]*out1 736 | #out = out #+W[4][1]*out1 737 | 738 | out = self.downsample(out) 739 | 740 | x = self.avgpool(out) 741 | x = torch.flatten(x, 1) 742 | x = self.classifier(x) 743 | return x 744 | 745 | class MeNet_C(nn.Module): 746 | """menet networks with cascaded modules 747 | """ 748 | def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5): 749 | super(MeNet_C, self).__init__() 750 | self.classes = num_classes 751 | num_channels = int(featuremaps/2) 752 | self.stream1 = nn.Sequential( 753 | nn.Conv2d(num_input, num_channels, kernel_size=3, stride=1, padding=1), # 1->3 754 | nn.ReLU(inplace=True), 755 | nn.BatchNorm2d(num_channels), 756 | nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 757 | nn.Dropout(), 758 | ) 759 | self.stream2 = nn.Sequential( 760 | nn.Conv2d(num_input, num_channels, kernel_size=5, stride=1, padding=2), # 5,2/ 1,0 761 | nn.ReLU(inplace=True), 762 | nn.BatchNorm2d(num_channels), 763 | nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 764 | nn.Dropout(), 765 | ) 766 | self.dbl = MakeLayer(DenseBlock, featuremaps, num_layers) 767 | # self.attenmap = SpatialAttentionBlock_P(normalize_attn=True) 768 | self.attenmap = SpatialAttentionBlock_F(normalize_attn=True) 769 | self.downsampling = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 770 | self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 771 | self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes) 772 | for m in self.modules(): 773 | if isinstance(m, nn.Conv2d): 774 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 775 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 776 | nn.init.constant_(m.weight, 1) 777 | nn.init.constant_(m.bias, 0) 778 | elif isinstance(m, nn.Linear): 779 | nn.init.normal_(m.weight, 0, 0.01) 780 | nn.init.constant_(m.bias, 0) 781 | 782 | def forward(self, x): 783 | y = self.attenmap(self.downsampling(x), self.classifier.weight, self.classes) 784 | x1 = self.stream1(x) 785 | x2 = self.stream2(x) 786 | x = torch.cat((x1,x2),1) 787 | # y = self.attenmap(self.downsampling(x), self.classifier.weight, self.classes) 788 | x = torch.mul(F.interpolate(y,(x.shape[2],x.shape[3])), x) 789 | x = self.dbl(x) 790 | x = self.avgpool(x) 791 | x = torch.flatten(x, 1) 792 | x = self.classifier(x) 793 | return x 794 | 795 | class MeNet_E(nn.Module): 796 | """menet networks with embedded modules 797 | """ 798 | def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5): 799 | super(MeNet_E, self).__init__() 800 | self.classes = num_classes 801 | self.poolsize = pool_size 802 | num_channels = int(featuremaps/2) 803 | self.stream1 = nn.Sequential( 804 | nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=1), # 1->3 805 | nn.ReLU(inplace=True), 806 | nn.BatchNorm2d(num_channels), 807 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 808 | nn.Dropout(), 809 | ) 810 | self.stream2 = nn.Sequential( 811 | nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=3, dilation=3), # 5,2/ 1,0 812 | nn.ReLU(inplace=True), 813 | nn.BatchNorm2d(num_channels), 814 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 815 | nn.Dropout(), 816 | ) 817 | self.ebl = EmbeddingBlock(featuremaps, featuremaps) 818 | self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 819 | self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes) 820 | for m in self.modules(): 821 | if isinstance(m, 
nn.Conv2d): 822 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 823 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 824 | nn.init.constant_(m.weight, 1) 825 | nn.init.constant_(m.bias, 0) 826 | elif isinstance(m, nn.Linear): 827 | nn.init.normal_(m.weight, 0, 0.01) 828 | nn.init.constant_(m.bias, 0) 829 | 830 | def forward(self, x): 831 | x1 = self.stream1(x) 832 | x2 = self.stream2(x) 833 | x = torch.cat((x1,x2),1) 834 | x = self.ebl(x, self.classifier.weight, self.poolsize, self.classes) 835 | x = self.avgpool(x) 836 | x = torch.flatten(x, 1) 837 | x = self.classifier(x) 838 | return x 839 | 840 | 841 | # class MeNet_ES(nn.Module): 842 | # """menet networks with embedded modules by searching 843 | # """ 844 | # def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5): 845 | # super(MeNet_ES, self).__init__() 846 | # self.classes = num_classes 847 | # self.poolsize = pool_size 848 | # num_channels = featuremaps 849 | # self.archi = nn.Parameter(torch.randn(3)) 850 | # self.stream1 = nn.Sequential( 851 | # nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=1), # 1->3 852 | # nn.ReLU(inplace=True), 853 | # nn.BatchNorm2d(num_channels), 854 | # # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 855 | # nn.Dropout(), 856 | # ) 857 | # self.stream2 = nn.Sequential( 858 | # nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=1, dilation=2), # 5,2/ 1,0 859 | # nn.ReLU(inplace=True), 860 | # nn.BatchNorm2d(num_channels), 861 | # # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 862 | # nn.Dropout(), 863 | # ) 864 | # self.stream3 = nn.Sequential( 865 | # nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=3, dilation=3), # 5,2/ 1,0 866 | # nn.ReLU(inplace=True), 867 | # nn.BatchNorm2d(num_channels), 868 | # # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 869 | # nn.Dropout(), 870 | # ) 871 | 872 | # self.bn = nn.BatchNorm2d(num_channels) 873 | # self.softmax = nn.Softmax(0) 874 | # self.ebl = EmbeddingBlock(featuremaps, featuremaps) 875 | # self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 876 | # self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes) 877 | # nn.init.constant(self.archi, 0.333) 878 | # for m in self.modules(): 879 | # if isinstance(m, nn.Conv2d): 880 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 881 | # elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 882 | # nn.init.constant_(m.weight, 1) 883 | # nn.init.constant_(m.bias, 0) 884 | # elif isinstance(m, nn.Linear): 885 | # nn.init.normal_(m.weight, 0, 0.01) 886 | # nn.init.constant_(m.bias, 0) 887 | 888 | # def forward(self, x): 889 | # x1 = self.stream1(x) 890 | # x2 = self.stream2(x) 891 | # x3 = self.stream3(x) 892 | # #x = torch.cat((x1,x2,x3),1) 893 | # W = self.softmax(self.archi) 894 | # #print(W) 895 | # x = W[0]*x1+ W[1]*x2+ W[2]*x3 896 | # x = self.bn(x) 897 | # x = self.ebl(x, self.classifier.weight, self.poolsize, self.classes) 898 | # x = self.avgpool(x) 899 | # x = torch.flatten(x, 1) 900 | # x = self.classifier(x) 901 | # return x 902 | 903 | class MeNet_ES(nn.Module): 904 | """menet networks with embedded modules by searching 905 | """ 906 | def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5): 907 | super(MeNet_ES, self).__init__() 908 | self.classes = num_classes 909 | self.poolsize = pool_size 910 | num_channels = int(featuremaps/2) 911 | self.archi = nn.Parameter(torch.randn(2)) 912 | self.stream1 = 
nn.Sequential( 913 | nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=1), # 1->3 914 | nn.ReLU(inplace=True), 915 | nn.BatchNorm2d(num_channels), 916 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 917 | nn.Dropout(), 918 | ) 919 | self.stream2 = nn.Sequential( 920 | nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=3, dilation=3), # 5,2/ 1,0 921 | nn.ReLU(inplace=True), 922 | nn.BatchNorm2d(num_channels), 923 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 924 | nn.Dropout(), 925 | ) 926 | 927 | self.stream21 = nn.Sequential( 928 | nn.Conv2d(featuremaps, num_channels, kernel_size=3, stride=3, padding=1), # 1->3 929 | nn.ReLU(inplace=True), 930 | nn.BatchNorm2d(num_channels), 931 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 932 | nn.Dropout(), 933 | ) 934 | self.stream22 = nn.Sequential( 935 | nn.Conv2d(featuremaps, num_channels, kernel_size=3, stride=3, padding=3, dilation=3), # 5,2/ 1,0 936 | nn.ReLU(inplace=True), 937 | nn.BatchNorm2d(num_channels), 938 | # nn.MaxPool2d(kernel_size=3, stride=3, padding=1), 939 | nn.Dropout(), 940 | ) 941 | 942 | #self.bn = nn.BatchNorm2d(num_channels) 943 | self.softmax = nn.Softmax(0) 944 | self.ebl = EmbeddingBlock2(num_input, featuremaps) 945 | self.ebl2 = EmbeddingBlock(featuremaps, featuremaps) 946 | self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size)) 947 | self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes) 948 | nn.init.constant_(self.archi, 0.5) 949 | for m in self.modules(): 950 | if isinstance(m, nn.Conv2d): 951 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 952 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 953 | nn.init.constant_(m.weight, 1) 954 | nn.init.constant_(m.bias, 0) 955 | elif isinstance(m, nn.Linear): 956 | nn.init.normal_(m.weight, 0, 0.01) 957 | nn.init.constant_(m.bias, 0) 958 | 959 | def forward(self, x): 960 | x1 = self.stream1(x) 961 | x2 = self.stream2(x) 962 | 963 | x = self.ebl(x, self.classifier.weight, self.poolsize, self.classes) 964 | x1 = torch.cat((x1,x2),1) 965 | 966 | W = self.softmax(self.archi) 967 | x = W[0]*x + W[1]*x1 968 | 969 | 970 | # 971 | y = self.ebl2(x, self.classifier.weight, self.poolsize, self.classes) 972 | y1 = self.stream21(x) 973 | y2 = self.stream22(x) 974 | y1 = torch.cat((y1,y2),1) 975 | 976 | x = W[1]*y + W[0]*y1 # note: the two weights are applied in reversed order at this second stage 977 | x = self.avgpool(x) 978 | x = torch.flatten(x, 1) 979 | x = self.classifier(x) 980 | # print(W) # debug: inspect the searched architecture weights 981 | return x -------------------------------------------------------------------------------- /Metrics.py: 1 | import torch 2 | import sklearn.metrics as metrics 3 | 4 | class mAP: 5 | """mean average precision: mAP = mean over classes of AP, with AP = (1/K) * sum_k( P@k ) 6 | """ 7 | def __init__(self): 8 | self.type = 0 9 | 10 | def eval_scalar(self,pred_s, true_s): 11 | if pred_s.shape[1] > 1 or true_s.shape[1] > 1: 12 | print('Inputs need to be a torch scalar!') 13 | 14 | def eval_vector(self,pred_mat, true_mat, bins = None ): 15 | """mean average precision with input of vectored labels: 16 | pred_mat: N*C matrix 17 | true_mat: N*C matrix 18 | """ 19 | if not (torch.is_tensor(pred_mat) and torch.is_tensor(true_mat)): 20 | print('Inputs need to be a torch tensor!') 21 | 22 | num_classes = pred_mat.shape[1] 23 | num_samples = pred_mat.shape[0] 24 | if bins is None: 25 | K = num_samples 26 | else: 27 | K = bins 28 | pred_sorted, idx_mat = torch.sort(pred_mat,dim=0, descending=True) 29 | precisions = torch.zeros(num_classes)
30 | for i in range(num_classes): 31 | idx = idx_mat[:,i] 32 | # x = true_mat[idx,i] 33 | x = torch.index_select(true_mat[:,i],0,idx) 34 | y = torch.cumsum(x,dim=0) 35 | num = torch.FloatTensor(range(num_samples))+1 36 | y = y / num # running precision P@k down the ranked list 37 | precisions[i] = torch.mean(y[:K]) 38 | 39 | map = torch.mean(precisions) 40 | 41 | return map 42 | 43 | def eval_matrix(self,pred_mat, true_mat, bins = None ): 44 | """mean average precision with input of vectored labels: 45 | pred_mat: N*C_1 matrix 46 | true_mat: N*C_2 matrix 47 | """ 48 | if not (torch.is_tensor(pred_mat) and torch.is_tensor(true_mat)): 49 | print('Inputs need to be a torch tensor!') 50 | 51 | num_bins = pred_mat.shape[1] 52 | num_samples = pred_mat.shape[0] 53 | num_classes = true_mat.shape[1] 54 | if bins is None: 55 | K = num_samples 56 | else: 57 | K = bins 58 | 59 | # calculating similarity matrix 60 | pred_s = pred_mat.mm(pred_mat.t()) 61 | pred_s = torch.div(pred_s,torch.diag(pred_s)) 62 | true_s = true_mat.mm(true_mat.t()) 63 | idx_rm = [i for i, v in enumerate(torch.diag(true_s)) if v == 0] 64 | # np.savetxt(os.path.join('data', 'tmp.csv'), torch.diag(true_s).numpy(), fmt="%d") 65 | true_s = torch.div(true_s, torch.diag(true_s)) 66 | pred_sorted, idx_mat = torch.sort(pred_s,dim=0, descending=True) 67 | 68 | precisions = torch.zeros(num_samples) 69 | for i in set(range(num_samples))-set(idx_rm): 70 | idx = idx_mat[:,i] 71 | x = torch.index_select(true_s[:,i],0,idx) 72 | y = torch.cumsum(x,dim=0) 73 | num = torch.FloatTensor(range(num_samples))+1 74 | y = y / num # running precision P@k down the ranked list 75 | precisions[i] = torch.mean(y[:K]) 76 | 77 | map = torch.mean(precisions) 78 | 79 | return map 80 | 81 | class accuracy: 82 | """accuracy: 83 | """ 84 | def __init__(self): 85 | self.type = 0 86 | 87 | def eval(self,pred_v, true_v): 88 | # calculate the weighted accuracy or unbalanced accuracy 89 | idx_a = [i for i, value in enumerate(pred_v) if pred_v[i] == true_v[i]] 90 | acc_weighted = float(len(idx_a))/float(len(pred_v)) 91 | # calculate the unweighted accuracy or balanced accuracy 92 | labels = torch.unique(true_v) 93 | acc = torch.zeros(len(labels)) 94 | for i in range(len(labels)): 95 | idx_c = [j for j in range(len(true_v)) if true_v[j] == labels[i]] 96 | acc[i] = torch.sum(pred_v[idx_c] == true_v[idx_c]).double()/float(len(idx_c)) 97 | acc_unweighted = torch.mean(acc) 98 | return acc_weighted, acc_unweighted 99 | 100 | class f1score: 101 | """f1score: micro-averaged and macro-averaged (unweighted) 102 | """ 103 | 104 | def __init__(self): 105 | self.type = 0 106 | 107 | def eval(self, pred_v, true_v): 108 | # calculate the micro-averaged and macro-averaged (unweighted) F1 scores 109 | f1_micro = metrics.f1_score(true_v.float(), pred_v.float(), average='micro') 110 | f1_macro = metrics.f1_score(true_v.float(), pred_v.float(), average='macro') 111 | return f1_micro, f1_macro -------------------------------------------------------------------------------- /ModelTrain_Final.py: 1 | import os, sys, datetime, time, random, argparse, copy 2 | 3 | import torch 4 | import torch.optim as optim 5 | import torch.nn as nn 6 | import torchvision 7 | from torchvision import transforms 8 | 9 | from Datasets import MEGC2019_SI as MEGC2019 10 | import MeNets, LossFunctions 11 | import Metrics as metrics 12 | 13 | def arg_process(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--dataversion', type=int, default=66, help='the version of input data') 16 | parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs') 17 | 
parser.add_argument('--gpuid', default='cuda:0', help='the gpu index for training') 18 | parser.add_argument('--learningrate', type=float, default=0.0005, help='the learning rate for training') 19 | parser.add_argument('--modelname', default='rcn_a', help='the model architecture') 20 | parser.add_argument('--modelversion', type=int, default=3, help='the version of created model') 21 | parser.add_argument('--batchsize', type=int, default=64, help='the batch size for training') 22 | parser.add_argument('--featuremap', type=int, default=64, help='the feature map size') 23 | parser.add_argument('--poolsize', type=int, default=7, help='the average pooling size') 24 | parser.add_argument('--lossfunction', default='crossentropy', help='the loss functions') 25 | args = parser.parse_args() 26 | return args 27 | 28 | def train_model(model, dataloaders, criterion, optimizer, device='cpu', num_epochs=25): 29 | since = time.time() 30 | # best_model_wts = copy.deepcopy(model.state_dict()) 31 | best_acc = 0.0 32 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) 33 | for epoch in range(num_epochs): 34 | print('\tEpoch {}/{}'.format(epoch, num_epochs - 1)) 35 | print('\t'+'-' * 10) 36 | # Each epoch has a training phase only 37 | model.train() # Set model to training mode 38 | running_loss = 0.0 39 | running_corrects = 0 40 | # Iterate over data 41 | for j, samples in enumerate(dataloaders): 42 | inputs, class_labels = samples["data"], samples["class_label"] 43 | inputs = inputs.float().to(device) 44 | class_labels = class_labels.to(device) 45 | # zero the parameter gradients 46 | optimizer.zero_grad() 47 | # forward to get model outputs and calculate loss 48 | output_class = model(inputs) 49 | loss = criterion(output_class, class_labels) 50 | # backward 51 | loss.backward() 52 | optimizer.step() 53 | # statistics 54 | _, predicted = torch.max(output_class.data,1) 55 | running_loss += loss.item() * inputs.size(0) 56 | running_corrects += torch.sum(predicted == class_labels) 57 | 58 | epoch_loss = running_loss / len(dataloaders.dataset) 59 | epoch_acc = running_corrects.double() / len(dataloaders.dataset) 60 | print('\t{} Loss: {:.4f} Acc: {:.4f}'.format('Train', epoch_loss, epoch_acc)) 61 | scheduler.step() # decay the learning rate according to the schedule 62 | time_elapsed = time.time() - since 63 | print('\tTraining complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 64 | 65 | return model 66 | 67 | def test_model(model, dataloaders, device): 68 | 69 | model.eval() 70 | num_samples = len(dataloaders.dataset) 71 | predVec = torch.zeros(num_samples) 72 | labelVec = torch.zeros(num_samples) 73 | start_idx = 0 74 | end_idx = 0 75 | for j, samples in enumerate(dataloaders): 76 | inputs, class_labels = samples['data'], samples['class_label'] 77 | inputs = inputs.float().to(device) 78 | # update the index of ending point 79 | end_idx = start_idx + inputs.shape[0] 80 | with torch.no_grad(): output_class = model(inputs) # no gradients needed at test time 81 | _, predicted = torch.max(output_class.data, 1) 82 | predVec[start_idx:end_idx] = predicted 83 | labelVec[start_idx:end_idx] = class_labels 84 | # update the starting point 85 | start_idx += inputs.shape[0] 86 | return predVec, labelVec 87 | 88 | def main(): 89 | """ 90 | Goal: process images by file lists, evaluating different data sizes with different model sizes 91 | Version: 5.0 92 | """ 93 | print('PyTorch Version: ', torch.__version__) 94 | print('Torchvision Version: ', torchvision.__version__) 95 | now = datetime.datetime.now() 96 | random.seed(1) 97 | torch.manual_seed(1) 98 | torch.backends.cudnn.deterministic = True
99 | torch.backends.cudnn.benchmark = False 100 | 101 | args = arg_process() 102 | runFileName = sys.argv[0].split('.')[0] 103 | # need to modify according to the environment 104 | version = args.dataversion 105 | gpuid = args.gpuid 106 | model_name = args.modelname 107 | num_epochs = args.epochs 108 | lr = args.learningrate 109 | batch_size = args.batchsize 110 | model_version = args.modelversion 111 | feature_map = args.featuremap 112 | loss_function = args.lossfunction 113 | pool_size = args.poolsize 114 | classes = 3 115 | 116 | # logPath = os.path.join('result', model_name+'_log.txt') 117 | logPath = os.path.join('result', runFileName + '_log_' + 'v{}'.format(args.dataversion) + '.txt') 118 | # logPath = os.path.join('result', runFileName+'_log_'+'v{}'.format(version)+'.txt') 119 | # resultPath = os.path.join('result', 'result_'+'v{}'.format(version)+'.pt') 120 | data_transforms = transforms.Compose([ 121 | transforms.ToTensor() 122 | ]) 123 | # move to GPU 124 | device = torch.device(gpuid if torch.cuda.is_available() else 'cpu') 125 | # obtain the subject information in LOSO 126 | verFolder = 'v_{}'.format(version) 127 | filePath = os.path.join('data', 'MEGC2019', verFolder, 'video442subName.txt') 128 | subjects = [] 129 | with open(filePath, 'r') as f: 130 | for textline in f: 131 | texts = textline.strip('\n') 132 | subjects.append(texts) 133 | # predicted and label vectors 134 | preds_db = {} 135 | preds_db['casme2'] = torch.tensor([]) 136 | preds_db['smic'] = torch.tensor([]) 137 | preds_db['samm'] = torch.tensor([]) 138 | preds_db['all'] = torch.tensor([]) 139 | labels_db = {} 140 | labels_db['casme2'] = torch.tensor([]) 141 | labels_db['smic'] = torch.tensor([]) 142 | labels_db['samm'] = torch.tensor([]) 143 | labels_db['all'] = torch.tensor([]) 144 | # open the log file and begin to record 145 | log_f = open(logPath,'a') 146 | log_f.write('{}\n'.format(now)) 147 | log_f.write('-' * 80 + '\n') 148 | log_f.write('-' * 80 + '\n') 149 | log_f.write('Results:\n') 150 | time_s = time.time() 151 | for subject in subjects: 152 | print('Subject Name: {}'.format(subject)) 153 | print('---------------------------') 154 | # random.seed(1) 155 | # setup a dataloader for training 156 | imgDir = os.path.join('data', 'MEGC2019', verFolder, '{}_train.txt'.format(subject)) 157 | image_db_train = MEGC2019(imgList=imgDir,transform=data_transforms) 158 | dataloader_train = torch.utils.data.DataLoader(image_db_train, batch_size=batch_size, shuffle=True, num_workers=1) 159 | # Initialize the model 160 | print('\tCreating deep model....') 161 | if model_name == 'rcn_a': 162 | model_ft = MeNets.RCN_A(num_input=3, featuremaps=feature_map, num_classes=classes, num_layers=1, pool_size=pool_size, model_version=model_version) 163 | elif model_name == 'rcn_s': 164 | model_ft = MeNets.RCN_S(num_input=3, featuremaps=feature_map, num_classes=classes, num_layers=1, pool_size=pool_size) 165 | elif model_name == 'rcn_w': 166 | model_ft = MeNets.RCN_W(num_input=3, featuremaps=feature_map, num_classes=classes, num_layers=1, pool_size=pool_size) 167 | elif model_name == 'rcn_p': 168 | model_ft = MeNets.RCN_P(num_input=3, featuremaps=feature_map, num_classes=classes, num_layers=1, pool_size=pool_size) 169 | elif model_name == 'rcn_c': 170 | model_ft = MeNets.RCN_C(num_input=3, featuremaps=feature_map, num_classes=classes, num_layers=1, pool_size=pool_size) 171 | elif model_name == 'rcn_f': 172 | model_ft = MeNets.RCN_F(num_input=3, featuremaps=feature_map, num_classes=classes, num_layers=1, pool_size=pool_size)
173 | params_to_update = model_ft.parameters() 174 | optimizer_ft = optim.SGD(params_to_update, lr=lr, momentum=0.9) 175 | # optimizer_ft = optim.Adam(params_to_update, lr=lr) 176 | if loss_function == 'crossentropy': 177 | criterion = nn.CrossEntropyLoss() 178 | elif loss_function == 'focal': 179 | criterion = LossFunctions.FocalLoss(class_num=classes,device=device) 180 | elif loss_function == 'balanced': 181 | criterion = LossFunctions.BalancedLoss(class_num=classes, device=device) 182 | elif loss_function == 'cosine': 183 | criterion = LossFunctions.CosineLoss() 184 | model_ft = model_ft.to(device) # from cpu to gpu 185 | # Train and evaluate 186 | model_ft = train_model(model_ft, dataloader_train, criterion, optimizer_ft, device, num_epochs=num_epochs) 187 | # torch.save(model_ft, os.path.join('data', 'model_s{}.pth').format(subject)) 188 | 189 | # Test model 190 | imgDir = os.path.join('data', 'MEGC2019', verFolder, '{}_test.txt'.format(subject)) 191 | image_db_test = MEGC2019(imgList=imgDir,transform=data_transforms) 192 | dataloaders_test = torch.utils.data.DataLoader(image_db_test, batch_size=batch_size, shuffle=False, 193 | num_workers=1) 194 | 195 | preds, labels = test_model(model_ft, dataloaders_test, device) 196 | acc = torch.sum(preds == labels).double()/len(preds) 197 | print('\tSubject {} has the accuracy:{:.4f}\n'.format(subject,acc)) 198 | print('---------------------------\n') 199 | log_f.write('\tSubject {} has the accuracy:{:.4f}\n'.format(subject,acc)) 200 | 201 | # saving the subject results 202 | preds_db['all'] = torch.cat((preds_db['all'], preds), 0) 203 | labels_db['all'] = torch.cat((labels_db['all'], labels), 0) 204 | if subject.find('sub')!= -1: 205 | preds_db['casme2'] = torch.cat((preds_db['casme2'], preds), 0) 206 | labels_db['casme2'] = torch.cat((labels_db['casme2'], labels), 0) 207 | else: 208 | if subject.find('s')!= -1: 209 | preds_db['smic'] = torch.cat((preds_db['smic'], preds), 0) 210 | labels_db['smic'] = torch.cat((labels_db['smic'], labels), 0) 211 | else: 212 | preds_db['samm'] = torch.cat((preds_db['samm'], preds), 0) 213 | labels_db['samm'] = torch.cat((labels_db['samm'], labels), 0) 214 | time_e = time.time() 215 | hours, rem = divmod(time_e-time_s,3600) 216 | minutes, seconds = divmod(rem,60) 217 | # evaluate all data 218 | eval_acc = metrics.accuracy() 219 | eval_f1 = metrics.f1score() 220 | acc_w, acc_uw = eval_acc.eval(preds_db['all'], labels_db['all']) 221 | f1_w, f1_uw = eval_f1.eval(preds_db['all'], labels_db['all']) 222 | print('\nThe dataset has the UAR and UF1:{:.4f} and {:.4f}'.format(acc_uw, f1_uw)) 223 | log_f.write('\nOverall:\n\tthe UAR and UF1 of all data are {:.4f} and {:.4f}\n'.format(acc_uw, f1_uw)) 224 | # casme2 225 | if preds_db['casme2'].nelement() != 0: 226 | acc_w, acc_uw = eval_acc.eval(preds_db['casme2'], labels_db['casme2']) 227 | f1_w, f1_uw = eval_f1.eval(preds_db['casme2'], labels_db['casme2']) 228 | print('\nThe casme2 dataset has the UAR and UF1:{:.4f} and {:.4f}'.format(acc_uw, f1_uw)) 229 | log_f.write('\tthe UAR and UF1 of casme2 are {:.4f} and {:.4f}\n'.format(acc_uw, f1_uw)) 230 | # smic 231 | if preds_db['smic'].nelement() != 0: 232 | acc_w, acc_uw = eval_acc.eval(preds_db['smic'], labels_db['smic']) 233 | f1_w, f1_uw = eval_f1.eval(preds_db['smic'], labels_db['smic']) 234 | print('\nThe smic dataset has the UAR and UF1:{:.4f} and {:.4f}'.format(acc_uw, f1_uw)) 235 | log_f.write('\tthe UAR and UF1 of smic are {:.4f} and {:.4f}\n'.format(acc_uw, f1_uw)) 236 | # samm 237 | if preds_db['samm'].nelement() != 0:
238 | acc_w, acc_uw = eval_acc.eval(preds_db['samm'], labels_db['samm']) 239 | f1_w, f1_uw = eval_f1.eval(preds_db['samm'], labels_db['samm']) 240 | print('\nThe samm dataset has the UAR and UF1:{:.4f} and {:.4f}'.format(acc_uw, f1_uw)) 241 | log_f.write('\tthe UAR and UF1 of samm are {:.4f} and {:.4f}\n'.format(acc_uw, f1_uw)) 242 | # writing parameters into log file 243 | print('\tNetname:{}, Dataversion:{}\n\tLearning rate:{}, Epochs:{}, Batchsize:{}.'.format(model_name,version,lr,num_epochs,batch_size)) 244 | print('\tElapsed time: {:0>2}:{:0>2}:{:05.2f}'.format(int(hours),int(minutes),seconds)) 245 | log_f.write('\nOverall:\n\tthe weighted and unweighted accuracy of all data are {:.4f} and {:.4f}\n'.format(acc_w,acc_uw)) 246 | log_f.write('\nSetting:\tNetname:{}, Dataversion:{}\n\tLearning rate:{}, Epochs:{}, Batchsize:{}.\n'.format(model_name, 247 | version, 248 | lr, 249 | num_epochs, 250 | batch_size)) 251 | # # save subject's results 252 | # torch.save({ 253 | # 'predicts':preds_db, 254 | # 'labels':labels_db, 255 | # 'weight_acc':acc_w, 256 | # 'unweight_acc':acc_uw 257 | # },resultPath) 258 | log_f.write('-' * 80 + '\n') 259 | log_f.write('-' * 80 + '\n') 260 | log_f.write('\n') 261 | log_f.close() 262 | 263 | if __name__ == '__main__': 264 | main() -------------------------------------------------------------------------------- /PrepareData_LOSO_CD.py: 1 | import os, random 2 | import numpy as np 3 | 4 | dbtype_dict = {'casme2':0, 'smic':1, 'samm':2} 5 | 6 | def main(): 7 | version = 67 # 0, 1, 2, 4 8 | verFolder = 'v_{}'.format(version) 9 | alphas = range(0,1) 10 | 11 | dataDir = os.path.join('data', 'MEGC2019', verFolder) 12 | filePath = os.path.join('dataset', 'megc_meta.csv') 13 | meta_dict = {'dbtype':[],'subject':[],'filename':[],'emotion':[],'subid':[],'dbid':[]} 14 | with open(filePath,'r') as f: 15 | for textline in f: 16 | texts = textline.strip('\n').split(',') 17 | meta_dict['dbtype'].append(texts[0]) 18 | meta_dict['subject'].append(texts[1]) 19 | meta_dict['filename'].append(texts[2]) 20 | meta_dict['emotion'].append(int(texts[3])) 21 | meta_dict['subid'].append(int(texts[4])) 22 | meta_dict['dbid'].append(int(texts[5])) 23 | subjects = list(set(meta_dict['subid'])) 24 | sampleNum = len(meta_dict['dbtype']) 25 | for subject in subjects: 26 | idx = meta_dict['subid'].index(subject) 27 | subjectName = meta_dict['subject'][idx] 28 | # open the training/val/test list file 29 | filePath = os.path.join('data','MEGC2019', verFolder, '{}_train.txt'.format(subjectName)) 30 | train_f = open(filePath,'w') 31 | filePath = os.path.join('data','MEGC2019', verFolder, '{}_test.txt'.format(subjectName)) 32 | test_f = open(filePath,'w') 33 | # traverse each item (442 samples in total) 34 | for i in range(0,sampleNum): 35 | for alpha in alphas: 36 | fileDir = os.path.join(dataDir, 'flow_alpha{}'.format(alpha)) 37 | fileName = '{}_{}_{}.png'.format(meta_dict['dbtype'][i], meta_dict['subject'][i], 38 | meta_dict['filename'][i]) 39 | filePath = os.path.join(fileDir, fileName) 40 | if int(meta_dict['subid'][i]) == int(subject): 41 | test_f.write('{} {} {}\n'.format(filePath,meta_dict['emotion'][i],meta_dict['dbid'][i])) 42 | else: 43 | train_f.write('{} {} {}\n'.format(filePath,meta_dict['emotion'][i],meta_dict['dbid'][i])) 44 | print('The {}-th subject: {}.'.format(subject,subjectName)) 45 | train_f.close() 46 | test_f.close() 47 | 48 | if __name__ == '__main__': 49 | main()
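50 |  51 | # Each line written above has the space-separated form "<imagePath> <emotion> <dbid>", 52 | # which is exactly the "path label dbtype" record parsed by the dataset classes in 53 | # Datasets.py, e.g. (hypothetical path): 54 | #   data/MEGC2019/v_67/flow_alpha0/casme2_sub01_EP02_01f.png 0 0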
-------------------------------------------------------------------------------- /RCNs.py: 1 | import torch 2 | import torch.nn as nn 3 | 4 | class RecConv(nn.Module): 5 | 6 | def __init__(self, inplanes, planes): 7 | super(RecConv, self).__init__() 8 | self.ffconv = nn.Sequential( 9 | nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0), 10 | nn.BatchNorm2d(planes), 11 | nn.ReLU(inplace=True) 12 | ) 13 | self.rrconv = nn.Sequential( 14 | nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, padding=1), 15 | nn.BatchNorm2d(planes), 16 | nn.ReLU(inplace=True) 17 | ) 18 | self.downsample = nn.Sequential( 19 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0), 20 | nn.Dropout() 21 | ) 22 | 23 | def forward(self, x): 24 | y = self.ffconv(x) 25 | y = self.rrconv(x + y) # first recurrent step 26 | y = self.rrconv(x + y) # second recurrent step (weights shared with the first) 27 | # y = self.rrconv(x + y) 28 | out = self.downsample(y) 29 | 30 | return out 31 | 32 | 33 | class RecNet(nn.Module): 34 | """Recurrent networks with single output 35 | """ 36 | def __init__(self, num_input, featuremaps, num_classes, num_layers): 37 | super(RecNet, self).__init__() 38 | 39 | self.conv1 = nn.Sequential( 40 | nn.Conv2d(num_input, featuremaps, kernel_size=5, stride=1, padding=0), 41 | nn.BatchNorm2d(featuremaps), 42 | nn.ReLU(inplace=True), 43 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0), 44 | nn.Dropout(), 45 | ) 46 | self.rcls = self._make_layer(RecConv, featuremaps, num_layers) 47 | self.avgpool = nn.AdaptiveAvgPool2d((5, 5)) 48 | self.classifier = nn.Linear(5*5*featuremaps, num_classes) 49 | 50 | for m in self.modules(): 51 | if isinstance(m, nn.Conv2d): 52 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 53 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 54 | nn.init.constant_(m.weight, 1) 55 | nn.init.constant_(m.bias, 0) 56 | elif isinstance(m, nn.Linear): 57 | nn.init.normal_(m.weight, 0, 0.01) 58 | nn.init.constant_(m.bias, 0) 59 | 60 | def _make_layer(self, block, planes, blocks): 61 | layers = [] 62 | for _ in range(0, blocks): 63 | layers.append(block(planes, planes)) 64 | 65 | return nn.Sequential(*layers) 66 | 67 | def forward(self, x): 68 | x = self.conv1(x) 69 | x = self.rcls(x) 70 | x = self.avgpool(x) 71 | x = torch.flatten(x, 1) 72 | x = self.classifier(x) 73 | 74 | return x 75 | 76 | class RecNet_v2(nn.Module): 77 | """Recurrent networks with shallower architecture 78 | """ 79 | def __init__(self, num_input, featuremaps, num_classes): 80 | super(RecNet_v2, self).__init__() 81 | 82 | self.conv1 = nn.Sequential( 83 | nn.Conv2d(num_input, featuremaps, kernel_size=5, stride=1, padding=0), 84 | nn.BatchNorm2d(featuremaps), 85 | nn.ReLU(inplace=True), 86 | nn.MaxPool2d(kernel_size=4, stride=4, padding=0), 87 | nn.Dropout(), 88 | ) 89 | self.avgpool = nn.AdaptiveAvgPool2d((5, 5)) 90 | self.classifier = nn.Linear(5*5*featuremaps, num_classes) 91 | 92 | for m in self.modules(): 93 | if isinstance(m, nn.Conv2d): 94 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 95 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 96 | nn.init.constant_(m.weight, 1) 97 | nn.init.constant_(m.bias, 0) 98 | elif isinstance(m, nn.Linear): 99 | nn.init.normal_(m.weight, 0, 0.01) 100 | nn.init.constant_(m.bias, 0) 101 | 102 | def forward(self, x): 103 | x = self.conv1(x) 104 | x = self.avgpool(x) 105 | x = torch.flatten(x, 1) 106 | x = self.classifier(x) 107 | 108 | return x 109 | 110 | class RecNet_v3(nn.Module): 111 | """Recurrent networks with multiple outputs
112 | """ 113 | def __init__(self, num_input, featuremaps, num_classes, num_layers): 114 | super(RecNet_v3, self).__init__() 115 | 116 | self.conv1 = nn.Sequential( 117 | nn.Conv2d(num_input, featuremaps, kernel_size=5, stride=1, padding=0), 118 | nn.BatchNorm2d(featuremaps), 119 | nn.ReLU(inplace=True), 120 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0), 121 | nn.Dropout(), 122 | ) 123 | self.rcls = self._make_layer(RecConv, featuremaps, num_layers) 124 | self.avgpool = nn.AdaptiveAvgPool2d((5, 5)) 125 | self.classifier = nn.Linear(5*5*featuremaps, num_classes[0]) 126 | self.classifier_aux1 = nn.Linear(5*5*featuremaps, num_classes[1]) 127 | 128 | for m in self.modules(): 129 | if isinstance(m, nn.Conv2d): 130 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 131 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 132 | nn.init.constant_(m.weight, 1) 133 | nn.init.constant_(m.bias, 0) 134 | elif isinstance(m, nn.Linear): 135 | nn.init.normal_(m.weight, 0, 0.01) 136 | nn.init.constant_(m.bias, 0) 137 | 138 | def _make_layer(self, block, planes, blocks): 139 | layers = [] 140 | for _ in range(0, blocks): 141 | layers.append(block(planes, planes)) 142 | 143 | return nn.Sequential(*layers) 144 | 145 | def forward(self, x): 146 | x = self.conv1(x) 147 | x = self.rcls(x) 148 | x = self.avgpool(x) 149 | x = torch.flatten(x, 1) 150 | y = self.classifier(x) 151 | y_1 = self.classifier_aux1(x) # auxiliary head (e.g., for side information such as the database label) 152 | 153 | return y, y_1 154 | -------------------------------------------------------------------------------- /README.md: 1 | # ParameterFreeRCNs-MicroExpressionRec 2 | Recurrent convolutional networks with parameter free modules for composite-database micro-expression recognition 3 | 4 | #### Descriptions 5 | These codes are used for micro-expression recognition on composite datasets (e.g., MEGC2019). The methods are described in the paper "Revealing the Invisible With Model and Data Shrinking for Composite-Database Micro-Expression Recognition, IEEE TIP 2020", which includes RCN-A, RCN-S, RCN-W, RCN-P, RCN-C and RCN-F. 6 | 7 | #### Dependencies 8 | The code was written in Python 3.6, and tested on Windows 10 and CentOS 7. 9 | - Pytorch: 1.1 or newer 10 | - Numpy: 1.16.3 or newer 11 | - Scikit-learn: 0.22.1 or newer 12 | 13 | #### Instructions 14 | 1. The optical flow should be extracted by your own tools before training the deep model. 15 | 2. The data can be prepared by the script "PrepareData_LOSO_CD.py". 16 | 3. Finally, the model variant is selected with the option "--modelname". For training, a command like "python ModelTrain_Final.py --dataversion 1 --epochs 20 --learningrate 0.0005 --modelname rcn_a --batchsize 64 --featuremap 32 --poolsize 5 --lossfunction crossentropy" can be used (see the example below for programmatic use).
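17 |  18 | #### Example 19 | The snippet below is a minimal sketch of creating and running a model directly in Python: the constructor arguments mirror the RCN_A call and default hyper-parameters in ModelTrain_Final.py, while the 112x112 input size is only an assumption for illustration (the networks end with adaptive pooling, so the exact spatial size is flexible). 20 |  21 | ```python 22 | import torch 23 | import MeNets 24 |  25 | # same constructor call as in ModelTrain_Final.py (default hyper-parameters) 26 | model = MeNets.RCN_A(num_input=3, featuremaps=64, num_classes=3, num_layers=1, pool_size=7, model_version=3) 27 | model.eval() 28 | x = torch.randn(1, 3, 112, 112)  # dummy optical-flow input; the spatial size is an assumption 29 | with torch.no_grad(): 30 |     logits = model(x)  # class scores of shape (1, 3) 31 | print(logits.argmax(dim=1)) 32 | ``` 33 |  --------------------------------------------------------------------------------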