├── imgs
├── model.png
└── saff.png
├── .gitignore
├── results
├── .gitkeep
└── feature_map.mat
├── .idea
├── .gitignore
├── vcs.xml
├── inspectionProfiles
│ ├── profiles_settings.xml
│ └── Project_Default.xml
├── modules.xml
├── misc.xml
├── 小样本分类 2k.iml
└── deployment.xml
├── requirements.txt
├── run.sh
├── dataset
├── util.py
└── dataLoader.py
├── run.py
├── README.md
├── Trainer.py
└── network.py
/imgs/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh-hike/SAFF/HEAD/imgs/model.png
--------------------------------------------------------------------------------
/imgs/saff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh-hike/SAFF/HEAD/imgs/saff.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints
2 | .DS_Store
3 | __pycache__
4 | *.ipynb
5 | *.mat
--------------------------------------------------------------------------------
/results/.gitkeep:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitkeep
5 | # NOTE: git only reads these patterns from a file named .gitignore
--------------------------------------------------------------------------------
/results/feature_map.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh-hike/SAFF/HEAD/results/feature_map.mat
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # 默认忽略的文件
2 | /shelf/
3 | /workspace.xml
4 | # 基于编辑器的 HTTP 客户端请求
5 | /httpRequests/
6 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.19.5
2 | scikit_learn==1.0.2
3 | scipy==1.5.0
4 | torch==1.11.0
5 | torchvision==0.12.0
6 | tqdm==4.49.0
7 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Extract SAR features with the pretrained backbone, then fit and evaluate the
# PCA + SVM classifier on them with 80% of the samples used for training.
# Abort on the first failure so training never runs on missing/stale features.
set -euo pipefail

DATA_PATH=/hy-tmp/data

# Feature extraction does not use --ratio; it is only needed for training.
python run.py --data_path "$DATA_PATH" --extract --dataset SAR

python run.py --data_path "$DATA_PATH" --train --ratio 0.8 --dataset SAR
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/dataset/util.py:
--------------------------------------------------------------------------------
1 | from torchvision.transforms import transforms as T
2 |
3 |
def get_default_aug(dataset):
    """Return the default image preprocessing pipeline for ``dataset``.

    Images are resized to a fixed square (600x600 for ``'Aerial'``, 256x256
    for every other dataset), converted to ``[0, 1]`` tensors and normalised
    with the ImageNet channel statistics that torchvision's pretrained VGG16
    weights were trained with.

    Args:
        dataset: dataset name as given on the command line.

    Returns:
        A ``torchvision.transforms.Compose`` to apply to PIL images.
    """
    size = 600 if dataset == 'Aerial' else 256
    # The pretrained backbone expects ImageNet-normalised inputs;
    # mean=0 / std=0.5 would merely rescale the pixels to [0, 2].
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    return T.Compose([
        T.Resize((size, size)),
        T.ToTensor(),
        T.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])
15 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/小样本分类 2k.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/dataset/dataLoader.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import ImageFolder
2 | import os
3 | from torch.utils.data import DataLoader
4 | from dataset.util import get_default_aug
5 |
6 |
class DL:
    """Wrap the configured image dataset in a ``DataLoader``.

    Images are read with ``ImageFolder`` from ``args.data_path`` joined with
    the dataset's sub-directory, preprocessed with ``get_default_aug`` and
    batched with ``args.batch_size``.

    Attributes:
        dl: ``DataLoader`` yielding ``(images, labels)`` batches.
    """

    # Location of each known dataset below ``args.data_path``; any other
    # dataset name uses ``args.data_path`` itself.
    _SUBDIRS = {
        'NWPU': 'NWPU-RESISC45',
        'UC': 'UCMerced_LandUse/Images',
        'SAR': 'SAR/',
    }

    def __init__(self, args):
        path = args.data_path
        aug = get_default_aug(args.dataset)
        subdir = self._SUBDIRS.get(args.dataset)
        if subdir is not None:
            path = os.path.join(path, subdir)

        data = ImageFolder(path, transform=aug)
        # This loader feeds a single feature-extraction pass, which must see
        # every image exactly once: drop_last=True would silently discard the
        # final partial batch (up to batch_size - 1 images). A fixed order also
        # keeps the saved feature file reproducible; the later train/test
        # split shuffles on its own.
        self.dl = DataLoader(data, batch_size=args.batch_size, shuffle=False, drop_last=False)
20 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from Trainer import Trainer
3 | import torch
4 |
parser = argparse.ArgumentParser("classifier")

parser.add_argument('--epochs', type=int, help="训练轮次", default=100)
parser.add_argument('--batch_size', type=int, help="批次大小", default=200)
parser.add_argument('--dataset', type=str, choices=['NWPU', 'UC', 'SAR'], default='NWPU')
parser.add_argument('--data_path', type=str, help="数据集所在路径", default='/users/zhhike/desktop/dataset/')
parser.add_argument('--model', type=str, choices=['vgg16'], default='vgg16')
parser.add_argument('--K', type=int, help="降为参数", default=784)
parser.add_argument('--ratio', type=float, help="训练比例", default=0.1)
parser.add_argument('--extract', action='store_true', help="是否进行特征提取")
parser.add_argument('--train', action='store_true', help="是否进行特征训练")


def main(argv=None):
    """Parse command-line options and run the requested pipeline stages.

    ``--extract`` computes and saves the dataset features; ``--train`` fits and
    evaluates the PCA + SVM classifier on the saved features. When both are
    given, extraction runs first.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    # Parsing happens here rather than at import time, so importing this
    # module has no side effects.
    args = parser.parse_args(argv)
    if not (args.extract or args.train):
        # Fail fast instead of loading the model and data only to do nothing.
        parser.error("nothing to do: pass --extract and/or --train")

    trainer = Trainer(args)

    # The network is only used for inference, so autograd is not needed.
    with torch.no_grad():
        if args.extract:
            trainer.feature_extract()
        if args.train:
            trainer.train()


if __name__ == "__main__":
    main()
27 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Self-Attention-Based Deep Feature Fusion for Remote Sensing Scene Classification
2 | Uses a pretrained VGG16 and SAFF (self-attention-based deep feature fusion) for small-sample remote sensing scene classification, following the [paper](https://paperswithcode.com/paper/self-attention-based-deep-feature-fusion-for).
3 |
4 | ## Introduction
5 | * Extract multi-level features for every image in the dataset with a pretrained VGG16
6 |

7 |
8 | * SAFF fuses the feature maps into a 1D feature vector per image; these vectors are then reduced with PCA and classified with an SVM
9 |
10 | 
11 |
12 | ## Environment setup
13 | ```bash
14 | conda create -n zh python=3.9
15 | conda activate zh
16 | python3 -m pip install --upgrade pip
17 | pip3 install -r requirements.txt
18 | ```
19 |
20 | ## Run
21 | Suppose your datasets are stored under */hy-tmp/data* and you want to train on the *UC* dataset
22 | (read from */hy-tmp/data/UCMerced_LandUse/Images*, one sub-folder per class).
23 |
24 | * Feature extraction
25 | ```bash
26 | python run.py \
27 |     --data_path /hy-tmp/data \
28 |     --extract \
29 |     --dataset UC
30 | ```
31 |
32 | * Train & evaluate (reads the features saved by the extraction step from *results/UC.mat*)
33 |
34 | ```bash
35 | python run.py \
36 |     --data_path /hy-tmp/data \
37 |     --train \
38 |     --dataset UC \
39 |     --ratio 0.8
40 | ```
41 |
42 | ## Experimental results
43 |
44 | | dataset | train_ratio | acc (%) |
45 | |:-------:|:-----------:|:-----:|
46 | | NWPU | 0.1 | 66.49 |
47 | | NWPU | 0.2 | 73.13 |
48 | | UC | 0.8 | 92.5 |
49 | | SAR | 0.8 | 89.8 |
50 |
--------------------------------------------------------------------------------
/Trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from dataset.dataLoader import DL
3 | from network import Model
4 | import torch.nn as nn
5 | from sklearn.svm import SVC
6 | import numpy as np
7 | from sklearn.metrics import accuracy_score
8 | from sklearn.decomposition import PCA
9 | from sklearn import manifold
10 | from sklearn.model_selection import train_test_split
11 | import scipy.io as io
12 | from tqdm import tqdm
13 |
14 |
class Trainer:
    """Two-stage pipeline: CNN + SAFF feature extraction, then PCA + SVM evaluation.

    * ``feature_extract`` passes every batch yielded by the dataset loader
      through ``Model`` and saves the stacked features ``X`` and labels ``y``
      to ``results/<dataset>.mat``.
    * ``train`` reloads that file, splits it into train/test parts
      (``args.ratio`` is the training fraction), reduces the features with
      PCA (``args.K`` components) and reports the test accuracy of an
      RBF-kernel SVM.

    Args:
        args: parsed command-line namespace (this class reads ``dataset``,
            ``ratio`` and ``K``; ``DL`` and ``Model`` read further fields).
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
        self._init_data()
        self._init_model()

    def _init_model(self):
        """Create the feature network and the PCA / SVM stages."""
        self.net = Model(self.args).to(self.device)
        # NOTE(review): this optimiser is never stepped, so the backbone stays
        # frozen. It is kept only so existing references to ``self.opt`` work.
        self.opt = torch.optim.Adam(self.net.parameters())
        self.svm = SVC(kernel='rbf')
        self.pca = PCA(n_components=self.args.K)

    def _init_data(self):
        """Build the image loader for ``args.dataset``."""
        self.data = DL(self.args)
        self.dl = self.data.dl

    def feature_extract(self):
        """Extract features for every loader batch and save them as a .mat file."""
        outputs = []
        labels = []
        print("进行特征提取...")
        # Inference only: put the network in eval mode and skip the autograd
        # graph even if the caller did not wrap this call in torch.no_grad().
        self.net.eval()
        with torch.no_grad():
            for inputs, targets in tqdm(self.dl, ncols=90):
                output = self.net(inputs.to(self.device))
                outputs.append(output.cpu().numpy())
                labels.append(targets.numpy())

        X = np.concatenate(outputs, axis=0)
        y = np.concatenate(labels, axis=0)
        io.savemat('results/%s.mat' % self.args.dataset, {'X': X, 'y': y})

    def train(self):
        """Fit PCA + SVM on a training split of the saved features and report accuracy.

        Returns:
            Accuracy on the held-out split (also printed).
        """
        print("数据集: ", self.args.dataset)
        print("train ratio: ", self.args.ratio)

        print("读取数据集...")
        data = io.loadmat('results/%s.mat' % self.args.dataset)
        # The label vector comes back 2-D; ravel() flattens it and, unlike
        # squeeze(), stays 1-D even for a single sample.
        X, y = data['X'], data['y'].ravel()

        print("划分数据集...")
        X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=self.args.ratio)

        print("pca降维...")
        # PCA cannot keep more components than min(n_samples, n_features) of
        # the data it is fitted on; shrink K rather than fail on small splits.
        n_components = min(self.args.K, *X_train.shape)
        if n_components != self.pca.n_components:
            print("PCA: reducing n_components from %d to %d" % (self.pca.n_components, n_components))
            self.pca.set_params(n_components=n_components)
        # Fit the projection on the training split only, so no information
        # from the test samples leaks into the features the SVM learns from.
        X_train = self.pca.fit_transform(X_train)
        X_test = self.pca.transform(X_test)

        self.svm.fit(X_train, y_train)
        pred = self.svm.predict(X_test)
        acc = accuracy_score(y_test, pred)
        print('val_acc: %.6f' % acc)
        return acc
65 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
29 |
30 |
31 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
1 | from torchvision.models import vgg16
2 | import torch.nn as nn
3 | import torch
4 |
5 |
def get_conv(start, end, model='vgg16'):
    """Return layers ``start:end`` of a pretrained backbone's convolutional stack.

    Args:
        start: index of the first layer to keep (inclusive).
        end: index one past the last layer to keep (exclusive).
        model: backbone name; only ``'vgg16'`` is supported.

    Returns:
        For ``'vgg16'``, the slice ``vgg16(pretrained=True).features[start:end]``
        (an ``nn.Sequential``); ``None`` for any other model name.
    """
    if model != 'vgg16':
        return None
    # Builds a full pretrained network; torchvision may download the weights
    # on first use.
    net = vgg16(pretrained=True)
    return net.features[start:end]
13 |
14 |
class BackBone(nn.Module):
    """One VGG-style stage: 3x3 convolution (same padding) -> ReLU -> 2x2 max-pool.

    Args:
        in_features: number of input channels.
        out_features: number of output channels.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        conv = nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1)
        activation = nn.ReLU(inplace=True)
        pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        self.net = nn.Sequential(conv, activation, pool)

    def forward(self, x):
        """Apply conv -> ReLU -> pool; halves the spatial size."""
        out = self.net(x)
        return out
26 |
27 |
class Feature_Extraction(nn.Module):
    """Multi-level CNN feature extractor that fuses three depths into one map.

    The backbone's convolutional stack (indices of torchvision's VGG16
    ``features``) is cut into three consecutive stages:
    ``layer1 = [0:19]``, ``layer2 = [19:26]``, ``layer3 = [26:31]``.
    The outputs of ``layer1`` and ``layer2`` are max-pooled down to the
    spatial size of ``layer3`` and all three are concatenated on the
    channel axis.

    Args:
        args: namespace whose ``model`` field names the backbone.

    Raises:
        ValueError: if ``args.model`` is not a supported backbone.
    """

    def __init__(self, args):
        super(Feature_Extraction, self).__init__()
        # Load the pretrained backbone once and slice it, instead of building
        # (and loading weights for) a full network per stage.
        features = get_conv(0, 31, args.model)
        if features is None:
            raise ValueError("unsupported backbone model: %r" % (args.model,))
        self.layer1 = features[0:19]
        self.layer2 = features[19:26]
        self.layer3 = features[26:31]

        # kernel 4 / stride 4 pools over every layer1 activation (kernel 2 /
        # stride 4 skipped 3/4 of them) and maps a spatial size s to
        # floor(s/4). That matches layer3 and maxpool2(layer2), so the concat
        # also works when s is not a multiple of 4 (e.g. s = 75 for 600 px
        # inputs).
        self.maxpool4 = nn.MaxPool2d(kernel_size=4, stride=4, padding=0)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

    def forward(self, imgs):
        """Return the channel-wise concatenation ``[layer3, pooled layer2, pooled layer1]``."""
        x1 = self.layer1(imgs)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        new_x1 = x3
        new_x2 = self.maxpool2(x2)
        new_x3 = self.maxpool4(x1)
        x = torch.cat([new_x1, new_x2, new_x3], dim=1)

        return x
48 |
49 |
class SAFF(nn.Module):
    """Attention-weighted sum pooling of a feature-map stack into one vector per sample.

    For an input ``x`` of shape ``(n, K, h, w)`` the output ``(n, K)`` is
    ``omega_k * sum_{i,j} x[k, i, j] * S[i, j]``, where:

    * ``S`` is the channel-summed activation map, divided by its ``a``-norm
      ``(sum S**a) ** (1/a)`` and raised to the power ``1/b``.
    * ``omega_k = log((K*sigma + sum_j q_j) / (sigma + q_k))``, where ``q_k``
      is the fraction of the ``h*w`` positions of channel ``k`` that are > 0.

    Args:
        a: exponent of the norm used to normalise the spatial map.
        b: the normalised spatial map is raised to ``1/b``.
        sigma: smoothing constant in the channel weights.
    """

    def __init__(self, a=0.5, b=2, sigma=0.0001):
        super(SAFF, self).__init__()
        self.a = a
        self.b = b
        self.sigma = sigma

    def forward(self, x):
        """
        :param x: (n, K, h, w) feature maps; assumed non-negative (post-ReLU),
            since ``S ** a`` is NaN for negative sums — TODO confirm for
            other backbones
        :return: (n, K) pooled and channel-weighted features
        """
        n, K, h, w = x.shape
        # Spatial weights: sum over channels -> (n, h, w), then normalise.
        S = x.sum(dim=1)
        z = torch.sum(S ** self.a, dim=[1, 2])
        z = (z ** (1 / self.a)).view(n, 1, 1)
        # An all-zero map would give 0 / 0 = NaN; with the clamp it yields 0.
        z = z.clamp_min(torch.finfo(z.dtype).tiny)
        S = (S / z) ** (1 / self.b)
        weighted = (x * S.unsqueeze(1)).sum(dim=[2, 3])  # (n, K)

        # Channel weights: fraction of non-zero activations of each channel,
        # normalised by the feature map's real spatial size (not a fixed 256**2).
        q = (x > 0).sum(dim=[2, 3]).to(x.dtype) / (h * w)
        q_sum = q.sum(dim=1, keepdim=True)
        omg = torch.log((K * self.sigma + q_sum) / (self.sigma + q))
        return omg * weighted
75 |
76 |
class Model(nn.Module):
    """Full feature pipeline: multi-level CNN extraction followed by SAFF pooling.

    ``forward`` maps a batch of images to one SAFF descriptor per image.

    Args:
        args: namespace forwarded to ``Feature_Extraction``.
    """

    def __init__(self, args):
        super().__init__()
        self.feature_extract = Feature_Extraction(args)
        self.saff = SAFF()

    def forward(self, img):
        """Extract fused feature maps from ``img`` and pool them with SAFF."""
        return self.saff(self.feature_extract(img))
88 |
--------------------------------------------------------------------------------