├── imgs ├── model.png └── saff.png ├── .gitignore ├── results ├── .gitkeep └── feature_map.mat ├── .idea ├── .gitignore ├── vcs.xml ├── inspectionProfiles │ ├── profiles_settings.xml │ └── Project_Default.xml ├── modules.xml ├── misc.xml ├── 小样本分类 2k.iml └── deployment.xml ├── requirements.txt ├── run.sh ├── dataset ├── util.py └── dataLoader.py ├── run.py ├── README.md ├── Trainer.py └── network.py /imgs/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zh-hike/SAFF/HEAD/imgs/model.png -------------------------------------------------------------------------------- /imgs/saff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zh-hike/SAFF/HEAD/imgs/saff.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | .DS_Store 3 | __pycache__ 4 | *.ipynb 5 | *.mat -------------------------------------------------------------------------------- /results/.gitkeep: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file !.gitkeep -------------------------------------------------------------------------------- /results/feature_map.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zh-hike/SAFF/HEAD/results/feature_map.mat -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # 默认忽略的文件 2 | /shelf/ 3 | /workspace.xml 4 | # 基于编辑器的 HTTP 客户端请求 5 | /httpRequests/ 6 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | numpy==1.19.5 2 | scikit_learn==1.0.2 3 | scipy==1.5.0 4 | torch==1.11.0 5 | torchvision==0.12.0 6 | tqdm==4.49.0 7 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | python run.py --data_path /hy-tmp/data --extract --dataset SAR --ratio 0.8 4 | 5 | 6 | python run.py --data_path /hy-tmp/data --train --ratio 0.8 --dataset SAR 7 | 8 | 9 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /dataset/util.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import transforms as T 2 | 3 | 4 | def get_default_aug(dataset): 5 | size = 256 6 | if dataset == 'Aerial': 7 | size = 600 8 | aug = T.Compose([ 9 | T.Resize((size, size)), 10 | T.ToTensor(), 11 | T.Normalize(mean=[0, 0, 0], std=[0.5, 0.5, 0.5]) 12 | ]) 13 | 14 | return aug 15 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/小样本分类 2k.iml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 15 | -------------------------------------------------------------------------------- /dataset/dataLoader.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import ImageFolder 2 | import os 3 | from torch.utils.data import DataLoader 4 | from dataset.util import get_default_aug 5 | 6 | 7 | class DL: 8 | def __init__(self, args): 9 | path = args.data_path 10 | aug = get_default_aug(args.dataset) 11 | if args.dataset == 'NWPU': 12 | path = os.path.join(path, 'NWPU-RESISC45') 13 | elif args.dataset == 'UC': 14 | path = os.path.join(path, 'UCMerced_LandUse/Images') 15 | elif args.dataset == 'SAR': 16 | path = os.path.join(path, 'SAR/') 17 | 18 | data = ImageFolder(path, transform=aug) 19 | self.dl = DataLoader(data, batch_size=args.batch_size, shuffle=True, drop_last=True) 20 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from Trainer import Trainer 3 | import torch 4 | 5 | parser = argparse.ArgumentParser("classifier") 6 | 7 | parser.add_argument('--epochs', type=int, help="训练轮次", default=100) 8 | parser.add_argument('--batch_size', type=int, help="批次大小", default=200) 9 | parser.add_argument('--dataset', type=str, choices=['NWPU', 'UC', 'SAR'], default='NWPU') 10 | parser.add_argument('--data_path', type=str, help="数据集所在路径", default='/users/zhhike/desktop/dataset/') 11 | parser.add_argument('--model', type=str, choices=['vgg16'], default='vgg16') 12 | parser.add_argument('--K', type=int, help="降为参数", 
default=784) 13 | parser.add_argument('--ratio', type=float, help="训练比例", default=0.1) 14 | parser.add_argument('--extract', action='store_true', help="是否进行特征提取") 15 | parser.add_argument('--train', action='store_true', help="是否进行特征训练") 16 | 17 | args = parser.parse_args() 18 | 19 | if __name__ == "__main__": 20 | trainer = Trainer(args) 21 | 22 | with torch.no_grad(): 23 | if args.extract: 24 | trainer.feature_extract() 25 | if args.train: 26 | trainer.train() 27 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-Attention-Based Deep Feature Fusion for Remote Sensing Scene Classification 2 | Use vgg16 and SAFF for small sample classification from the [paper](https://paperswithcode.com/paper/self-attention-based-deep-feature-fusion-for). 3 | 4 | ## Introduction 5 | * Extract dataset features using pretrained vgg16 6 |

7 | 8 | * SAFF fuses the multi-layer VGG16 feature maps into a single 1D feature vector per image 9 | 10 |

11 | 12 | ## Environmental preparation 13 | ```bash 14 | conda create -n zh python=3.9 15 | conda activate zh 16 | python3 -m pip install --upgrade pip 17 | pip3 install -r requirements.txt 18 | ``` 19 | 20 | ## Run 21 | If your dataset is at path */hy-tmp/data* 22 | Suppose you want to train on the *UC* dataset. 23 | 24 | * Feature extraction 25 | ``` 26 | python run.py 27 | --data_path /hy-tmp/data 28 | --extract 29 | --dataset UC 30 | ``` 31 | 32 | * Train & verify 33 | 34 | ``` 35 | python run.py 36 | --data_path /hy-tmp/data 37 | --train 38 | --dataset UC 39 | --ratio 0.8 40 | ``` 41 | 42 | ## Experimental results 43 | 44 | | dataset | train_ratio | acc | 45 | |:-------:|:-----------:|:-----:| 46 | | NWPU | 0.1 | 66.49 | 47 | | NWPU | 0.2 | 73.13 | 48 | | UC | 0.8 | 92.5 | 49 | | SAR | 0.8 | 89.8 | 50 | -------------------------------------------------------------------------------- /Trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataset.dataLoader import DL 3 | from network import Model 4 | import torch.nn as nn 5 | from sklearn.svm import SVC 6 | import numpy as np 7 | from sklearn.metrics import accuracy_score 8 | from sklearn.decomposition import PCA 9 | from sklearn import manifold 10 | from sklearn.model_selection import train_test_split 11 | import scipy.io as io 12 | from tqdm import tqdm 13 | 14 | 15 | class Trainer: 16 | def __init__(self, args): 17 | self.args = args 18 | self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 19 | self._init_data() 20 | self._init_model() 21 | 22 | def _init_model(self): 23 | self.net = Model(self.args).to(self.device) 24 | self.opt = torch.optim.Adam(self.net.parameters(), ) 25 | self.svm = SVC(kernel='rbf') 26 | self.pca = PCA(n_components=self.args.K) 27 | # self.pca = manifold.TSNE(n_components=self.args.K, init='pca') 28 | 29 | def _init_data(self): 30 | self.data = DL(self.args) 31 | self.dl = 
self.data.dl 32 | 33 | def feature_extract(self): 34 | outputs = [] 35 | labels = [] 36 | print("进行特征提取...") 37 | for inputs, targets in tqdm(self.dl, ncols=90): 38 | inputs = inputs.to(self.device) 39 | targets = targets.numpy() 40 | output = self.net(inputs).detach().cpu().numpy() 41 | outputs.append(output) 42 | labels.append(targets) 43 | 44 | X = np.concatenate(outputs, axis=0) 45 | y = np.concatenate(labels, axis=0) 46 | 47 | data = {'X': X, 'y': y} 48 | io.savemat('results/%s.mat' % self.args.dataset, data) 49 | 50 | def train(self): 51 | print("数据集: ", self.args.dataset) 52 | print("train ratio: ", self.args.ratio) 53 | 54 | print("读取数据集...") 55 | data = io.loadmat('results/%s.mat' % self.args.dataset) 56 | X, y = data['X'], data['y'].squeeze() 57 | print("pca降维...") 58 | X = self.pca.fit_transform(X) 59 | print("划分数据集...") 60 | X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=self.args.ratio) 61 | self.svm.fit(X_train, y_train) 62 | pred = self.svm.predict(X_test) 63 | acc = accuracy_score(y_test, pred) 64 | print('val_acc: %.6f' % acc) 65 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 42 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | from torchvision.models import vgg16 2 | import torch.nn as nn 3 | import torch 4 | 5 | 6 | def get_conv(start, end, model='vgg16'): 7 | conv1, conv2, conv3 = None, None, None 8 | if model == 'vgg16': 9 | net = vgg16(pretrained=True) 10 | return net.features[start:end] 11 | 12 | return None 13 | 14 | 15 | class BackBone(nn.Module): 16 | def __init__(self, in_features, out_features): 17 | super(BackBone, self).__init__() 18 | self.net = nn.Sequential( 19 | nn.Conv2d(in_features, out_features, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), 20 | nn.ReLU(inplace=True), 21 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), 22 | ) 23 | 24 | def forward(self, x): 25 | return self.net(x) 26 | 27 | 28 | class Feature_Extraction(nn.Module): 29 | def __init__(self, args): 30 | super(Feature_Extraction, self).__init__() 31 | self.layer1 = get_conv(0, 19, args.model) 32 | self.layer2 = get_conv(19, 26, args.model) 33 | self.layer3 = get_conv(26, 31, args.model) 34 | 35 | self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=4, padding=0) 36 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 37 | 38 | def forward(self, imgs): 39 | x1 = self.layer1(imgs) 40 | x2 = self.layer2(x1) 41 | x3 = self.layer3(x2) 42 | new_x1 = x3 43 | new_x2 = self.maxpool2(x2) 44 | new_x3 = self.maxpool4(x1) 45 | x = torch.cat([new_x1, new_x2, new_x3], dim=1) 46 | 47 | return x 48 | 49 | 50 | class SAFF(nn.Module): 51 | def __init__(self, a=0.5, b=2, sigma=0.0001): 52 | super(SAFF, self).__init__() 53 | self.a = a 54 | self.b = b 55 | self.sigma = sigma 56 | 57 | def forward(self, x): 58 | """ 59 | :param x: (n, c, h, w) 60 | :return: 61 | """ 62 | n, K, h, w = x.shape 63 | S = x.sum(dim=1) # n,h,w 64 | z = torch.sum(S ** self.a, dim=[1, 2]) 65 | z = (z ** (1 / self.a)).view(n, 1, 1) 66 | S = (S / z) ** (1 / self.b) 67 | S = S.unsqueeze(1) 68 | new_x = (x * S).sum(dim=[2, 3]) 69 | omg = (x > 0).sum(dim=[2, 3]) / (256 ** 2) 70 | omg_sum = omg.sum(dim=1).unsqueeze(1) 71 | omg = (K * self.sigma + omg_sum) / (self.sigma + omg) 72 | omg = torch.log(omg) 73 | x = omg * new_x 74 | return x 75 | 76 | 77 | class Model(nn.Module): 78 | def __init__(self, args): 79 | super(Model, self).__init__() 80 | self.feature_extract = Feature_Extraction(args) 81 | self.saff = SAFF() 82 | 83 | def forward(self, img): 84 | 85 | x = self.feature_extract(img) 86 | x = self.saff(x) 87 | return x 88 | 
--------------------------------------------------------------------------------