├── README.md
├── dual_protonet.py
├── feature_extractor.py
└── meta_template.py

/README.md:
--------------------------------------------------------------------------------

# SPNet

This is the code for the paper *Siamese-Prototype Network for Few-Shot Remote Sensing Image Scene Classification*.

More details will be updated later.

--------------------------------------------------------------------------------
/dual_protonet.py:
--------------------------------------------------------------------------------

# This code is modified from https://github.com/jakesnell/prototypical-networks

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
from meta_template import MetaTemplate


class DualProtoNet(MetaTemplate):
    def __init__(self, model_func, n_way, n_support):
        super(DualProtoNet, self).__init__(model_func, n_way, n_support)
        self.loss_fn = nn.CrossEntropyLoss()

    def set_forward(self, x, is_feature=False):
        # z_support: n_way*n_support*512 (5*5*512), z_query: n_way*n_query*512 (5*16*512)
        z_support, z_query = self.parse_feature(x, is_feature)

        z_support_lst = [z_support[i] for i in range(z_support.size(0))]
        z_support = z_support.contiguous()
        # class prototypes: average of the n_support embeddings of each way -> 5*512 (way, channel)
        z_proto = z_support.view(self.n_way, self.n_support, -1).mean(1)
        z_query = z_query.contiguous().view(self.n_way * self.n_query, -1)  # 80*512

        # scores2: support-to-prototype similarities, n_way*n_way (5*5)
        scores2 = []
        for z_sup in z_support_lst:  # each 5*512
            dist2 = cos_dist(z_sup, z_proto)
            scores2.append(dist2)
        scores2 = torch.cat(scores2, dim=0).view(self.n_way, self.n_way).cuda()

        # scores1: query-to-prototype similarities, (n_way*n_query)*n_way (80*5)
        scores1 = cos_dist1(z_query, z_proto)

        # scores3: support similarities to the mean query score vector, n_way*n_way (5*5)
        scores3 = []
        for z_sup in z_support_lst:  # each 5*512
            scores1_proto = scores1.mean(0)
            dist3 = cos_dist2(z_sup, scores1_proto)
            scores3.append(dist3)
        scores3 = torch.cat(scores3, dim=0).view(self.n_way, self.n_way).cuda()

        return scores1, scores2, scores3

    def set_forward_loss(self, x):
        # query labels: (0,...,0, 1,...,1, ..., 4,...,4), length n_way*n_query (80)
        y_query = torch.from_numpy(np.repeat(range(self.n_way), self.n_query))
        y_query = Variable(y_query.cuda())
        y_support = torch.from_numpy(np.array(range(self.n_way)))
        y_support = Variable(y_support.cuda())
        scores1, scores2, scores3 = self.set_forward(x)
        loss1 = self.loss_fn(scores1.float(), y_query.long())    # 80*5 scores vs 80 labels
        loss2 = self.loss_fn(scores2.float(), y_support.long())  # 5*5 scores vs 5 labels
        loss3 = self.loss_fn(scores3.float(), y_support.long())  # 5*5 scores vs 5 labels
        return loss1, loss2, loss3
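# A minimal usage sketch, assuming a 5-way, 5-shot episode with 16 queries per
# class and ResNet18 from feature_extractor.py as the backbone:
#
#   model = DualProtoNet(model_func=ResNet18, n_way=5, n_support=5).cuda()
#   model.n_query = 16
#   x = torch.randn(5, 21, 3, 224, 224)  # n_way x (n_support + n_query) images
#   scores1, scores2, scores3 = model.set_forward(x)
#   # scores1: 80*5 query scores; scores2, scores3: 5*5 calibration scores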
def cos_dist(x, y):
    # x: n*d (5*512), the support embeddings of one class
    # y: n*d (5*512), the class prototypes
    n = x.size(0)
    d = x.size(1)
    assert d == y.size(1)
    '''
    original code to calculate the euclidean distance:
    x = x.unsqueeze(1).expand(n, n, d)
    y = y.unsqueeze(0).expand(n, n, d)
    dist = torch.pow(x - y, 2).sum(2).mean(0)
    '''

    # 5-shot: average each support sample's similarity to every prototype
    x_lst = [x[i] for i in range(n)]
    cos_lst = []
    for x_c in x_lst:
        x_c = x_c.unsqueeze(0)  # 1*512
        cos = F.cosine_similarity(x_c, y, dim=1) * 50  # the factor 50 sharpens the softmax
        cos_lst.append(cos)
    cosi = torch.cat(cos_lst, dim=0).view(n, y.size(0)).mean(0)
    '''
    # 1-shot:
    cosi = F.cosine_similarity(x, y, dim=1) * 50
    '''
    return cosi


def cos_dist1(x, y):
    # x: N x D, the query embeddings (80*512)
    # y: M x D, the class prototypes (5*512)
    n = x.size(0)  # 80
    m = y.size(0)  # 5
    d = x.size(1)  # 512
    assert d == y.size(1)

    x_lst = [x[i] for i in range(n)]
    cos_lst = []
    for x_c in x_lst:
        x_c = x_c.unsqueeze(0)  # 1*512
        cos = F.cosine_similarity(x_c, y, dim=1) * 50
        cos_lst.append(cos)
    cosi = torch.cat(cos_lst, dim=0).view(n, m)
    return cosi  # 80*5


def cos_dist2(x, y):
    # x: 5*512, the support embeddings of one class
    # y: 5, the mean query score of each class
    n = x.size(0)  # 5
    d = x.size(1)  # 512

    # broadcast the score vector to 5*512 so it can be compared with the
    # embeddings; the numpy round trip also detaches y from the graph
    y = y.view(n, 1).cpu().detach().numpy()
    y = np.repeat(y, d, axis=1)  # 5*512
    y = torch.from_numpy(y).cuda()

    # 5-shot:
    x_lst = [x[i] for i in range(n)]
    cos_lst = []
    for x_c in x_lst:
        x_c = x_c.unsqueeze(0)  # 1*512
        cos = F.cosine_similarity(x_c, y, dim=1) * 50
        cos_lst.append(cos)
    cosi = torch.cat(cos_lst, dim=0).view(n, n).mean(0)
    '''
    # 1-shot:
    cosi = F.cosine_similarity(x, y, dim=1) * 50
    '''
    return cosi  # 5

--------------------------------------------------------------------------------
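All three helpers above reduce to the same core operation: a pairwise matrix of
scaled cosine similarities between two sets of row vectors. A minimal vectorized
sketch of that operation (`pairwise_cos` is an illustrative name, not a function
in this repository; the scale factor 50 matches the helpers above):

import torch
import torch.nn.functional as F

def pairwise_cos(x, y, scale=50.0):
    # x: N x D, y: M x D -> N x M matrix of scaled cosine similarities,
    # the same quantity the row-by-row loops above compute
    x = F.normalize(x, dim=1)
    y = F.normalize(y, dim=1)
    return scale * (x @ y.t())

# e.g. pairwise_cos(z_query, z_proto) matches cos_dist1(z_query, z_proto)
print(pairwise_cos(torch.randn(80, 512), torch.randn(5, 512)).shape)  # torch.Size([80, 5])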
/feature_extractor.py:
--------------------------------------------------------------------------------

import torch
from torch.autograd import Variable
import torch.nn as nn
import math
import numpy as np
import torch.nn.functional as F

# Conv2d_fw and BatchNorm2d_fw below are the fast-weight layers from the
# CloserLookFewShot backbone; they are only needed when maml = True.


def init_layer(L):
    # Initialization using fan-in
    if isinstance(L, nn.Conv2d):
        n = L.kernel_size[0] * L.kernel_size[1] * L.out_channels
        L.weight.data.normal_(0, math.sqrt(2.0 / float(n)))
    elif isinstance(L, nn.BatchNorm2d):
        L.weight.data.fill_(1)
        L.bias.data.fill_(0)


class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        return x.view(x.size(0), -1)


class SimpleBlock(nn.Module):
    maml = False  # default

    def __init__(self, indim, outdim, half_res):
        super(SimpleBlock, self).__init__()
        self.indim = indim
        self.outdim = outdim
        if self.maml:
            self.C1 = Conv2d_fw(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False)
            self.BN1 = BatchNorm2d_fw(outdim)
            self.C2 = Conv2d_fw(outdim, outdim, kernel_size=3, padding=1, bias=False)
            self.BN2 = BatchNorm2d_fw(outdim)
        else:
            self.C1 = nn.Conv2d(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False)
            self.BN1 = nn.BatchNorm2d(outdim)
            self.C2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1, bias=False)
            self.BN2 = nn.BatchNorm2d(outdim)
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)

        self.parametrized_layers = [self.C1, self.C2, self.BN1, self.BN2]

        self.half_res = half_res

        # if the number of input channels differs from the output, the shortcut needs a 1x1 convolution
        if indim != outdim:
            if self.maml:
                self.shortcut = Conv2d_fw(indim, outdim, 1, 2 if half_res else 1, bias=False)
                self.BNshortcut = BatchNorm2d_fw(outdim)
            else:
                self.shortcut = nn.Conv2d(indim, outdim, 1, 2 if half_res else 1, bias=False)
                self.BNshortcut = nn.BatchNorm2d(outdim)

            self.parametrized_layers.append(self.shortcut)
            self.parametrized_layers.append(self.BNshortcut)
            self.shortcut_type = '1x1'
        else:
            self.shortcut_type = 'identity'

        for layer in self.parametrized_layers:
            init_layer(layer)

    def forward(self, x):
        out = self.C1(x)
        out = self.BN1(out)
        out = self.relu1(out)
        out = self.C2(out)
        out = self.BN2(out)
        short_out = x if self.shortcut_type == 'identity' else self.BNshortcut(self.shortcut(x))
        out = out + short_out
        out = self.relu2(out)
        return out


class ResNet(nn.Module):
    maml = False  # default

    def __init__(self, block, list_of_num_layers, list_of_out_dims, flatten=True):
        # list_of_num_layers specifies the number of blocks in each stage
        # list_of_out_dims specifies the number of output channels of each stage
        super(ResNet, self).__init__()
        assert len(list_of_num_layers) == 4, 'Can have only four stages'
        if self.maml:
            conv1 = Conv2d_fw(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
            bn1 = BatchNorm2d_fw(64)
        else:
            conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
            bn1 = nn.BatchNorm2d(64)

        relu = nn.ReLU()
        pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        init_layer(conv1)
        init_layer(bn1)

        trunk = [conv1, bn1, relu, pool1]

        indim = 64
        for i in range(4):
            for j in range(list_of_num_layers[i]):
                # the first block of every stage after the first halves the resolution
                half_res = (i >= 1) and (j == 0)
                B = block(indim, list_of_out_dims[i], half_res)
                trunk.append(B)
                indim = list_of_out_dims[i]

        if flatten:
            avgpool = nn.AvgPool2d(7)
            trunk.append(avgpool)
            trunk.append(Flatten())
            self.final_feat_dim = indim
        else:
            self.final_feat_dim = [indim, 7, 7]

        self.trunk = nn.Sequential(*trunk)

    def forward(self, x):
        out = self.trunk(x)
        return out


def ResNet18(flatten=True):
    return ResNet(SimpleBlock, [2, 2, 2, 2], [64, 128, 256, 512], flatten)

--------------------------------------------------------------------------------
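ResNet18 above ends with a 7x7 average pooling and a Flatten, so each image is
mapped to a 512-dimensional feature vector. A minimal shape check, assuming the
224x224 RGB inputs used elsewhere in this repository:

import torch
from feature_extractor import ResNet18

model = ResNet18(flatten=True)
x = torch.randn(2, 3, 224, 224)  # a dummy batch of two RGB images
feat = model(x)
print(feat.shape)  # torch.Size([2, 512])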
/meta_template.py:
--------------------------------------------------------------------------------

# This code is modified from https://github.com/wyharveychen/CloserLookFewShot

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
from abc import abstractmethod


class MetaTemplate(nn.Module):
    def __init__(self, model_func, n_way, n_support, change_way=False):
        super(MetaTemplate, self).__init__()
        self.n_way = n_way
        self.n_support = n_support
        self.n_query = -1  # set at runtime, depending on the input episode
        self.feature = model_func()
        self.feat_dim = self.feature.final_feat_dim

    # @abstractmethod declares methods of the base class without implementing
    # them: MetaTemplate can only be inherited, not instantiated, and subclasses
    # must override set_forward and set_forward_loss.
    @abstractmethod
    def set_forward(self, x, is_feature):
        pass

    @abstractmethod
    def set_forward_loss(self, x):
        pass

    def forward(self, x):
        out = self.feature.forward(x)
        return out

    def parse_feature(self, x, is_feature):
        x = Variable(x.cuda())
        if is_feature:
            z_all = x
        else:
            # images: (n_way*(n_support+n_query))*3*224*224, e.g. (5*21)*3*224*224
            x = x.contiguous().view(self.n_way * (self.n_support + self.n_query), *x.size()[2:])
            z_all = self.feature.forward(x)  # features: (5*21)*512
        c = z_all.shape[1]
        z_all = z_all.view(self.n_way, self.n_support + self.n_query, c)  # 5*21*512
        z_support = z_all[:, :self.n_support]  # 5*5*512
        z_query = z_all[:, self.n_support:]    # 5*16*512
        return z_support, z_query

    def correct(self, x):
        scores1, scores2, scores3 = self.set_forward(x)
        y_query = np.repeat(range(self.n_way), self.n_query)

        # accuracy is measured on the query-to-prototype scores only
        topk_scores, topk_labels = scores1.data.topk(1, 1, True, True)
        topk_ind = topk_labels.cpu().numpy()
        top1_correct = np.sum(topk_ind[:, 0] == y_query)
        return float(top1_correct), len(y_query)

    def train_loop(self, epoch, train_loader, optimizer, scheduler, const1, const2):
        print_freq = 10
        avg_loss = 0
        for i, (x, _) in enumerate(train_loader):
            self.n_query = x.size(1) - self.n_support
            optimizer.zero_grad()
            # total loss = query loss plus the two weighted calibration losses
            loss1, loss2, loss3 = self.set_forward_loss(x)
            loss = loss1 + const1 * loss2 + const2 * loss3
            loss.backward()
            optimizer.step()
            avg_loss = avg_loss + loss.item()

            if i % print_freq == 0:
                print('Epoch [%d], Batch [%d/%d], Loss: %.6f, lr: %f' % (
                    epoch, i, len(train_loader), avg_loss / float(i + 1), scheduler.get_lr()[0]))
        scheduler.step(epoch)

    def test_loop(self, test_loader, record=None):
        acc_all = []

        iter_num = len(test_loader)
        for i, (x, _) in enumerate(test_loader):
            self.n_query = x.size(1) - self.n_support
            correct_this, count_this = self.correct(x)
            acc_all.append(correct_this / count_this * 100)

        acc_all = np.asarray(acc_all)
        acc_mean = np.mean(acc_all)
        acc_std = np.std(acc_all)
        # mean accuracy over episodes with a 95% confidence interval
        print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))

        return acc_mean

--------------------------------------------------------------------------------
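Putting the three files together, a minimal training sketch. The optimizer,
scheduler, loss weights const1/const2, and the dummy episode "loader" below are
illustrative assumptions, not settings from the paper:

import torch
from dual_protonet import DualProtoNet
from feature_extractor import ResNet18

model = DualProtoNet(model_func=ResNet18, n_way=5, n_support=5).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.5)

# dummy episodes standing in for the real few-shot sampler: each batch is
# (n_way, n_support + n_query, 3, 224, 224) images plus unused labels
train_loader = [(torch.randn(5, 21, 3, 224, 224), torch.zeros(5)) for _ in range(20)]

for epoch in range(2):
    model.train_loop(epoch, train_loader, optimizer, scheduler, const1=0.5, const2=0.5)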