├── LICENSE
├── MSNet.png
├── README.md
├── SeT
│   ├── __init__.py
│   ├── detect.py
│   ├── loss.py
│   ├── mask.py
│   └── train.py
├── dataset
│   ├── __init__.py
│   └── coast
│       ├── coast.mat
│       └── coast.py
├── main.py
├── metric.py
├── model
│   ├── MSNet
│   │   ├── MSNet.py
│   │   └── __init__.py
│   └── __init__.py
├── requirements.txt
├── select_bands.py
└── utils.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 enter-i-username

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/MSNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/enter-i-username/MSNet/6433192e765e62050421c94daa7f2a45b254aabd/MSNet.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MSNet

This is the official implementation of the paper "MSNet: Self-Supervised Multiscale Network with Enhanced Separation Training for Hyperspectral Anomaly Detection", published in IEEE Transactions on Geoscience and Remote Sensing (TGRS). Paper link: https://ieeexplore.ieee.org/document/10551851 (IEEE Xplore).

## Abstract

Hyperspectral anomaly detection (HAD) has attracted increasing attention due to its economical and efficient applications. The main challenge lies in the data-starved problem of hyperspectral images (HSIs) and the costliness of manual annotation, making it heavily reliant on the model's adaptability and robustness to unseen scenes under limited samples. Self-supervised learning offers a solution to this urgency via mining meaningful representations from the data itself. One promising paradigm is leveraging untrained neural networks to reconstruct the background component for revealing anomalous information. Its capability stems from the network architecture and the training process rather than learning from expensive and strongly domain-dependent data, which is naturally applicable to HAD. In this paper, to handle the urgent requirement for self-supervised learning in HAD, we propose a multi-scale network (termed MSNet) that detects anomalies with enhanced separation training. The network architecture consists of several multi-scale convolutional encoder-decoder (CED) layers, considering the spatial characteristics of the anomalies. To suppress the anomalies during background reconstruction, we adopt a new separation training strategy by introducing a soft separator for better practicality on larger datasets. Extensive experiments conducted on 5 commonly used datasets and the HAD100 dataset demonstrate the superiority of our method over its counterparts.

## Workflow

![MSNet](MSNet.png)

### 1. Band Selection
We use the fast and efficient band selection algorithm [OPBS](https://ieeexplore.ieee.org/document/8320544) to remove redundant information among bands and to reduce training time.

### 2. Network Training
We train the multi-scale network with the enhanced separation training loss together with the multi-scale reconstruction loss.

### 3. Anomaly Detection
The detection map is the per-pixel reconstruction error between the input and the output of the trained network (the squared error summed over bands).

## Getting Started

### Installing Dependencies
To get started, please install the following packages in a Python 3.8 environment:
- matplotlib (version 3.5.2)
- numpy (version 1.24.3)
- scikit_learn (version 1.2.2)
- scipy (version 1.10)
- torch (version 1.13.1)
- tqdm (version 4.65.0)

These versions are pinned in `requirements.txt`, so all of them can be installed with:
```
pip install -r requirements.txt
```

### Starting an Experiment

We provide a demo program that runs a simple experiment. Run it from the repository root, since the modules import each other as top-level packages (e.g. `import utils`):
```
python main.py
```

The demo evaluates the network on the Coast dataset in `dataset/coast` and saves the ROC curve and the AUC history as PDF files under `results/MSNet/`. Other datasets can be added to the `dataset` directory in the same format as Coast: a `.mat` file storing the hyperspectral cube (rows × cols × bands) under the key `data` and the binary ground-truth map under the key `map`, together with a module that exposes `name` and `get_data()`.
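To illustrate the Anomaly Detection step and how the soft separator reuses its output, here is a minimal NumPy sketch that mirrors what `SeT/detect.py` and `SeT/mask.py` do with PyTorch tensors. The detection map is the squared reconstruction error summed over bands. Min-max scaling that map to [0, 1] gives soft anomaly weights, and the background weights are one minus those. The function names `detection_map` and `soft_masks` are illustrative and not part of the repository. The `eps` guard against a constant map is an added assumption; the repository has no such guard.

```python
import numpy as np


def detection_map(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Squared reconstruction error summed over bands.

    x, y: input and reconstruction, both (rows, cols, bands).
    Returns an array of shape (rows, cols).
    """
    return ((x - y) ** 2).sum(axis=2)


def soft_masks(dm: np.ndarray, eps: float = 1e-12):
    """Min-max scale a detection map into soft anomaly weights in [0, 1].

    Returns (anomaly_weights, background_weights), each (rows, cols, 1)
    so that they broadcast over the band axis. eps is an assumption that
    avoids division by zero for a constant map.
    """
    anm = (dm - dm.min()) / (dm.max() - dm.min() + eps)
    anm = anm[..., np.newaxis]
    return anm, 1.0 - anm


# Toy 2x2 scene with 3 bands where only pixel (0, 1) is poorly reconstructed.
x = np.zeros((2, 2, 3))
y = np.zeros((2, 2, 3))
y[0, 1] = 1.0  # error of 1 in each of the 3 bands
dm = detection_map(x, y)
anm, bg = soft_masks(dm)
print(dm)
print(anm[..., 0])
```

The poorly reconstructed pixel receives an anomaly weight close to 1. The remaining pixels keep a weight of 0 and therefore count fully toward the background reconstruction term.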

## Citation

If you find the code helpful in your research, please kindly cite our paper:

```bibtex
@article{liu2024msnet,
  title={MSNet: Self-Supervised Multi-Scale Network with Enhanced Separation Training for Hyperspectral Anomaly Detection},
  author={Liu, Haijun and Su, Xi and Shen, Xiangfei and Zhou, Xichuan},
  journal={IEEE Transactions on Geoscience and Remote Sensing},
  year={2024},
  publisher={IEEE}
}
```

--------------------------------------------------------------------------------
/SeT/__init__.py:
--------------------------------------------------------------------------------
from .detect import detect
from .mask import Mask
from .loss import TotalLoss
from .train import separation_training

--------------------------------------------------------------------------------
/SeT/detect.py:
--------------------------------------------------------------------------------

def detect(x, decoder_outputs):
    """Detection map: squared reconstruction error summed over bands.

    x: input cube of shape (rows, cols, bands).
    decoder_outputs: tuple returned by the model; its first element is the
    reconstruction of the finest-scale (down_rate=1) branch.
    Returns a map of shape (rows, cols).
    """
    y = decoder_outputs[0]
    dm = (x - y).detach()
    dm = dm ** 2
    dm = dm.sum(2)  # sum over the band axis

    return dm

--------------------------------------------------------------------------------
/SeT/loss.py:
--------------------------------------------------------------------------------
from torch import nn
import torch.nn.functional as F
import torch


class LoG(nn.Module):

    def __init__(self, device):
        super(LoG, self).__init__()

        # 5x5 Laplacian-of-Gaussian template; its coefficients sum to zero,
        # so flat regions produce no response
        self.tmpl = torch.tensor(
            [[-2, -4, -4, -4, -2],
             [-4, 0, 8, 0, -4],
             [-4, 8, 24, 8, -4],
             [-4, 0, 8, 0, -4],
             [-2, -4, -4, -4, -2]]
        ).to(device).float()

        ws = self.tmpl.shape[-1]  # window size of the square template
        # conv3d weight (out, in, depth=1, ws, ws): same 2D kernel for every band
        self.tmpl = self.tmpl.reshape(1, 1, 1, ws, ws)
        self.pad = ws // 2

    def forward(self, x):

        # (rows, cols, bands) -> (1, bands, rows, cols)
        x = x.permute(2, 0, 1).unsqueeze(0)

        # Reflection padding
        x = F.pad(x, (self.pad, self.pad, self.pad, self.pad), mode='reflect')
        # Add a channel axis: (1, 1, bands, rows + 2 * pad, cols + 2 * pad)
        x = x.unsqueeze(0)

        # Calculate LoG for each band
        x = F.conv3d(x, self.tmpl)

        # (1, 1, bands, rows, cols) -> (rows, cols, bands)
        x = x.squeeze(0).squeeze(0)
        x = x.permute(1, 2, 0)

        return x


class SeTLoss(nn.Module):

    def __init__(self, lmda, device):
        super(SeTLoss, self).__init__()

        self.lmda = lmda
        self.eps = 1e-6

        self.log = LoG(device)
        self.norm = lambda x: (x ** 2).sum()

    def forward(self, **kwargs):
        # Soft anomaly weights in [0, 1], shape (rows, cols, 1)
        anm_mask = kwargs['mask']
        x = kwargs['x']
        decoder_outputs = kwargs['y']
        y = decoder_outputs[0]

        num_anm = anm_mask.count()

        # Background weights are the complement of the anomaly weights
        bg_mask = anm_mask.not_op()
        num_bg = bg_mask.count()

        # Calculate LoG on the estimated image y
        log_y = self.log(y)

        # Anomaly suppression loss
        as_loss = self.norm(anm_mask.dot_prod(log_y)) / (num_anm + self.eps)

        # Background reconstruction loss
        br_loss = self.norm(bg_mask.dot_prod(x - y)) / (num_bg + self.eps)

        # Separation training loss
        set_loss = br_loss + self.lmda * as_loss

        return set_loss


class MSRLoss(nn.Module):

    def __init__(self):
        super(MSRLoss, self).__init__()

        self.norm = lambda x: (x ** 2).sum()

    def forward(self, **kwargs):
        x = kwargs['x']
        decoder_outputs = kwargs['y']

        scale = 1
        layers = []

        # The k-th decoder output and the input are both resized by 1 / 2**k
        # before comparison
        for _do in decoder_outputs:
            _rows = _do.shape[0] // scale
            _cols = _do.shape[1] // scale
            _x_down = F.interpolate(
                x.permute(2, 0, 1).unsqueeze(0),
                size=(_rows, _cols), mode='bilinear'
            )
            _do_down = F.interpolate(
                _do.permute(2, 0, 1).unsqueeze(0),
                size=(_rows, _cols), mode='bilinear'
            )
            _layer = self.norm(_x_down - _do_down) / (_rows * _cols)
            layers.append(_layer)
            scale *= 2

        msr_loss = sum(layers) / len(layers)
        return msr_loss


class TotalLoss(nn.Module):

    def __init__(self, lmda, device):
        super(TotalLoss, self).__init__()

        self.set_loss = SeTLoss(lmda, device)
        self.msr_loss = MSRLoss()

    def forward(self, **kwargs):
        bands = kwargs['x'].shape[-1]
        total_loss = self.set_loss(**kwargs) + self.msr_loss(**kwargs)
        # Normalize by the number of bands
        total_loss = total_loss / bands
        return total_loss

--------------------------------------------------------------------------------
/SeT/mask.py:
--------------------------------------------------------------------------------
import torch
from typing import Union
import utils  # top-level module: run from the repository root


class Mask:
    """Soft anomaly mask of shape (rows, cols, 1), broadcast over bands."""

    def __init__(self,
                 init: Union[tuple, torch.Tensor],
                 device):

        if isinstance(init, tuple):
            # All-zero mask: every pixel starts as background
            self.mask = torch.zeros(init).to(device)
            self.mask = self.mask.unsqueeze(-1)
        elif isinstance(init, torch.Tensor):
            self.mask = init
        else:
            raise TypeError('init must be a shape tuple or a torch.Tensor')

    def dot_prod(self, x):
        # Element-wise weighting of x by the mask
        return self.mask * x

    def count(self):
        # Soft count: sum of the weights
        return self.mask.sum()

    def not_op(self):
        # Complementary mask (background weights)
        return Mask(1 - self.mask, self.mask.device)

    def update(self,
               dm: torch.Tensor):
        # Rescale the detection map to [0, 1] and use it as the new mask
        self.mask = utils.MinMaxNorm().fit(dm).transform(dm).unsqueeze(-1)

--------------------------------------------------------------------------------
/SeT/train.py:
--------------------------------------------------------------------------------
import torch
from .detect import detect
from tqdm import tqdm
from metric import roc_auc
import numpy as np
from typing import Tuple


def train_model(
        x,
        model,
        criterion,
        cri_kwargs,
        epochs,
        optimizer,
        verbose):

    epoch_iter = range(epochs)
    if verbose:
        epoch_iter = tqdm(epoch_iter)
    for _ in epoch_iter:

        # Clear gradient information
        optimizer.zero_grad()

        # Forward propagation
        y = model(x)

        # Calculate loss
        loss = criterion(x=x, y=y, **cri_kwargs)

        # Backward propagation
        loss.backward()

        # Update network parameters
        optimizer.step()

        if verbose:
            epoch_iter.set_postfix({'loss': '{0:.4f}'.format(loss.item())})


def separation_training(
        x: torch.Tensor,
        gt: np.ndarray,
        model,
        loss,
        mask,
        optimizer,
        epochs,
        output_iter,
        max_iter,
        verbose) -> Tuple[np.ndarray, list]:
    """
    The main process of the separation training algorithm.

    Each of the max_iter iterations trains the model for `epochs` epochs with
    the current soft mask, then recomputes the detection map and uses it to
    update the mask. Returns the detection map of iteration `output_iter`
    and the AUC score of every iteration.
    """

    history = []
    output_dm = np.zeros(gt.shape, dtype=float)

    for i in range(1, max_iter + 1):
        if verbose:
            print('Iter {0}'.format(i))

        # Feed the model with x
        model_input = x

        # Train the model for some epochs
        train_model(
            model_input,
            model,
            loss,
            {'mask': mask},
            epochs,
            optimizer,
            verbose
        )

        # Update the mask using detection map obtained in this iteration
        with torch.no_grad():
            dm = detect(x, model(model_input))
        mask.update(dm.detach())

        # Evaluation (the ground truth is used only for monitoring)
        np_dm = dm.cpu().detach().numpy()
        fpr, tpr, auc = roc_auc(np_dm, gt)
        if verbose:
            print('Current AUC score: {0:.4f}'.format(auc))

        # Record history
        history.append(auc)

        # Record the output detection map of the algorithm
        if i == output_iter:
            output_dm = np_dm

    return output_dm, history

--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
from .coast import coast

--------------------------------------------------------------------------------
/dataset/coast/coast.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/enter-i-username/MSNet/6433192e765e62050421c94daa7f2a45b254aabd/dataset/coast/coast.mat
--------------------------------------------------------------------------------
/dataset/coast/coast.py:
--------------------------------------------------------------------------------
import scipy.io as sio
import os


name = 'coast'
path = os.path.dirname(__file__)
file_name = os.path.join(path, name + '.mat')


def get_data():
    mat = sio.loadmat(file_name)
    # 'data': hyperspectral cube (rows, cols, bands)
    data = mat['data'].astype(float)
    # 'map': ground-truth anomaly map (rows, cols)
    gt = mat['map'].astype(bool)

    return data, gt

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from dataset import coast
from model import MSNet
import matplotlib.pyplot as plt
from torch.optim import Adam
import select_bands
import torch
import utils
import metric
import os
from SeT import (
    TotalLoss,
    Mask,
    separation_training
)

# Settings
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
lmda = 1e-3       # weight of the anomaly suppression term
num_bs = 64       # number of bands kept by OPBS
num_layers = 3    # number of multi-scale encoder-decoder layers
lr = 1e-3
epochs = 150      # training epochs per separation-training iteration
output_iter = 5   # iteration whose detection map is reported
max_iter = 10     # total separation-training iterations
data_norm = True
Net = MSNet
net_kwargs = dict()
net_kwargs['num_layers'] = num_layers

# Load data
dataset = coast
data, gt = dataset.get_data()
rows, cols, bands = data.shape
net_kwargs['shape'] = (rows, cols, num_bs)
print('Detecting on %s...' % dataset.name)

# Preprocessing
band_idx = select_bands.OPBS(data, num_bs)
data_bs = data[:, :, band_idx]
if data_norm:
    data_bs = utils.ZScoreNorm().fit(data_bs).transform(data_bs)

# Load model
model = Net(**net_kwargs).to(device).float()

# Loss
loss = TotalLoss(lmda, device)

# Mask
mask = Mask((rows, cols), device)

# Optimizer
optimizer = Adam(model.parameters(), lr=lr)

# Separation Training
x_bs = torch.from_numpy(data_bs).to(device).float()
pr_dm, history = separation_training(
    x=x_bs,
    gt=gt,
    model=model,
    loss=loss,
    mask=mask,
    optimizer=optimizer,
    epochs=epochs,
    output_iter=output_iter,
    max_iter=max_iter,
    verbose=True
)

# Save the detection result
result_path = os.path.join('results', model.name)
os.makedirs(result_path, exist_ok=True)

# RX baseline, computed on the full (not band-selected) cube
rx_dm = utils.rx(data)
fpr, tpr, rx_auc = metric.roc_auc(rx_dm, gt)
plt.plot(fpr, tpr, label='RX: %.4f' % rx_auc)

fpr, tpr, pr_auc = metric.roc_auc(pr_dm, gt)
plt.plot(fpr, tpr, label='%s+SeT: %.4f' % (model.name, pr_auc),
         c='black', alpha=0.7)

plt.grid(alpha=0.3)
plt.legend()
plt.savefig(
    os.path.join(result_path, '%s_roc.pdf' % dataset.name)
)
plt.clf()
plt.close()

# AUC after each iteration, plotted against the cumulative number of epochs
iters = [(_ + 1) * epochs for _ in range(max_iter)]
plt.xticks(iters)
plt.plot(iters, history)
plt.scatter([output_iter * epochs], [history[output_iter - 1]],
            marker='o', edgecolors='black', facecolors='white', label='Stop',
            zorder=10)
plt.grid(alpha=0.3)
plt.xlabel('Epoch')
plt.ylabel('AUC')
plt.legend()
plt.savefig(
    os.path.join(result_path, '%s_auc_history.pdf' % dataset.name)
)
plt.clf()
plt.close()

print('Complete.')
print('Results are saved in %s.' % result_path)
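`main.py` scores both the RX map and the MSNet map with `metric.roc_auc`, which flattens the maps and passes them to scikit-learn's `roc_curve` and `auc`. If you want to cross-check those numbers without scikit-learn, the sketch below computes the same area through its pairwise interpretation. It is the probability that a randomly chosen anomalous pixel scores higher than a randomly chosen background pixel, with ties counted as one half. For that tie convention the value should agree with the trapezoidal ROC area up to floating-point error. `pairwise_auc` is a hypothetical helper, not part of the repository. It needs O(anomalies × background) memory, so it is suited to small scenes or subsampled pixels.

```python
import numpy as np


def pairwise_auc(dm: np.ndarray, gt: np.ndarray) -> float:
    """ROC AUC as P(anomaly score > background score), ties counted as 1/2.

    dm: detection map (rows, cols); gt: boolean ground truth (rows, cols).
    """
    scores = np.asarray(dm, dtype=float).ravel()
    labels = np.asarray(gt, dtype=bool).ravel()
    pos, neg = scores[labels], scores[~labels]
    if pos.size == 0 or neg.size == 0:
        raise ValueError('gt must contain both anomaly and background pixels')
    # Compare every anomaly score with every background score
    wins = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return float((wins + 0.5 * ties) / (pos.size * neg.size))


gt = np.array([[False, True], [False, False]])
dm = np.array([[0.1, 0.7], [0.7, 0.2]])
print(pairwise_auc(dm, gt))
```

In this toy map, the single anomaly outranks two background pixels and ties with the third, so the score is (2 + 0.5) / 3.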
--------------------------------------------------------------------------------
/metric.py:
--------------------------------------------------------------------------------
from sklearn import metrics
import numpy as np


def roc_auc(dm: np.ndarray,
            gt: np.ndarray):
    """ROC curve and area under it for a detection map.

    dm: detection map (rows, cols); gt: binary ground truth (rows, cols).
    Returns (fpr, tpr, auc).
    """
    rows, cols = gt.shape

    gt = gt.reshape(rows * cols)
    dm = dm.reshape(rows * cols)

    fpr, tpr, _ = metrics.roc_curve(gt, dm)
    auc = metrics.auc(fpr, tpr)

    return fpr, tpr, auc

--------------------------------------------------------------------------------
/model/MSNet/MSNet.py:
--------------------------------------------------------------------------------
from torch import nn
from torch.nn.functional import interpolate
import torch


class Bottleneck(nn.Module):
    """Residual block: conv3x3 -> ReLU -> conv3x3, plus the identity."""

    def __init__(self,
                 in_channels,
                 hidden_channels):
        super(Bottleneck, self).__init__()

        self.layers = nn.Sequential(*[
            nn.Conv2d(in_channels, hidden_channels, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, in_channels, 3, 1, 1),
        ])

    def forward(self, x):
        x = self.layers(x) + x
        return x


class Down(nn.Module):
    """ReLU followed by average pooling with the given rate."""

    def __init__(self,
                 down_rate):
        super(Down, self).__init__()

        self.layers = nn.Sequential(*[
            nn.ReLU(),
            nn.AvgPool2d(down_rate)
        ])

    def forward(self, x):
        return self.layers(x)


class Up(nn.Module):
    """Bilinear upsampling back to (rows, cols), then a 1x1 convolution."""

    def __init__(self,
                 shape,
                 need_relu):
        super(Up, self).__init__()

        self.need_relu = need_relu
        self.rows, self.cols, bands = shape

        self.conv = nn.Conv2d(bands, bands, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = interpolate(x, size=(self.rows, self.cols), mode='bilinear')
        x = self.conv(x)
        if self.need_relu:
            x = self.relu(x)
        return x


class Encoder(nn.Module):

    def __init__(self,
                 shape,
                 down_rate):
        super(Encoder, self).__init__()

        rows, cols, bands = shape

        self.bottleneck = Bottleneck(bands, 16)
        self.down_block = Down(down_rate)
        self.up_block = Up(shape, need_relu=False)

    def forward(self, x):
        output = self.bottleneck(x)
        output = self.down_block(output)
        # Full-resolution version of the pooled features
        sm = self.up_block(output)
        return output, sm


class Decoder(nn.Module):

    def __init__(self,
                 shape):
        super(Decoder, self).__init__()

        rows, cols, bands = shape

        self.up_block = Up(shape, need_relu=True)
        self.conv = nn.Conv2d(bands, bands, 1)

    def forward(self, encoder_output, sm):
        output = self.up_block(encoder_output)
        output = output + sm
        output = self.conv(output)
        return output


class MSNet(nn.Module):

    def __init__(self,
                 **kwargs):
        super(MSNet, self).__init__()

        self.name = 'MSNet'

        self.num_layers = kwargs['num_layers']
        rows, cols, bands = kwargs['shape']

        # Encoder l pools with rate 2 ** l (1, 2, 4, ...)
        self.encoders = nn.ModuleList([
            Encoder(shape=(rows, cols, bands), down_rate=2 ** _l)
            for _l in range(self.num_layers)
        ])

        self.decoders = nn.ModuleList([
            Decoder(shape=(rows, cols, bands))
            for _l in range(self.num_layers)
        ])

    def forward(self, x):
        # (rows, cols, bands) -> (1, bands, rows, cols)
        x = x.permute(2, 0, 1).unsqueeze(0)

        decoding_list = []

        # Encoding: each encoder sees the input plus the upsampled outputs
        # of all previous encoders
        encoding_sum = x
        for _l in range(self.num_layers):
            output, sm = self.encoders[_l](encoding_sum)
            decoding_list.append(output)
            encoding_sum = sm + encoding_sum

        # Decoding: from the coarsest to the finest scale, each decoder
        # receives the running sum of the coarser decoder outputs
        decoding_sum = torch.zeros_like(x)
        for _cd in range(self.num_layers - 1, -1, -1):
            encoder_output = decoding_list[_cd]
            decoder_output = self.decoders[_cd](encoder_output, decoding_sum)
            decoding_sum = decoder_output + decoding_sum
            decoding_list[_cd] = decoder_output

        # Back to (rows, cols, bands) for every scale; index 0 is the finest
        return tuple(_x.squeeze(0).permute(1, 2, 0) for _x in decoding_list)

--------------------------------------------------------------------------------
/model/MSNet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/enter-i-username/MSNet/6433192e765e62050421c94daa7f2a45b254aabd/model/MSNet/__init__.py
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
from .MSNet.MSNet import MSNet

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
matplotlib==3.5.2
numpy==1.24.3
scikit_learn==1.2.2
scipy==1.10
torch==1.13.1
tqdm==4.65.0

--------------------------------------------------------------------------------
/select_bands.py:
--------------------------------------------------------------------------------
import numpy as np


def OPBS(
        x: np.ndarray,
        num_bs: int
):
    """
    Select `num_bs` bands with orthogonal projection-based band selection.

    Starting from the band with the largest variance, each step adds the band
    whose component orthogonal to all previously selected bands has the
    largest energy. Returns the selected band indices in ascending order.

    Ref:
    W. Zhang, X. Li, Y. Dou, and L. Zhao, “A geometry-based band
    selection approach for hyperspectral image analysis,” IEEE Transactions
    on Geoscience and Remote Sensing, vol. 56, no. 8, pp. 4318–4333, 2018.
13 | """ 14 | rows, cols, bands = x.shape 15 | eps = 1e-9 16 | 17 | x_2d = np.reshape(x, (rows * cols, bands)) 18 | y_2d = x_2d.copy() 19 | h = np.zeros(bands) 20 | band_idx = [] 21 | 22 | idx = np.argmax(np.var(x_2d, axis=0)) 23 | band_idx.append(idx) 24 | h[idx] = np.sum(x_2d[:, band_idx[-1]] ** 2) 25 | 26 | i = 1 27 | while i < num_bs: 28 | id_i_1 = band_idx[i - 1] 29 | 30 | _elem, _idx = -np.inf, 0 31 | for t in range(bands): 32 | if t not in band_idx: 33 | y_2d[:, t] = y_2d[:, t] - y_2d[:, id_i_1] * (np.dot(y_2d[:, id_i_1], y_2d[:, t]) / (h[id_i_1] + eps)) 34 | h[t] = np.dot(y_2d[:, t], y_2d[:, t]) 35 | 36 | if h[t] > _elem: 37 | _elem = h[t] 38 | _idx = t 39 | 40 | band_idx.append(_idx) 41 | i += 1 42 | 43 | band_idx = sorted(band_idx) 44 | return band_idx 45 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def rx(x: np.ndarray): 5 | rows, cols, bands = x.shape 6 | x_2d = x.reshape((rows * cols, bands)) 7 | 8 | inv_Sig = np.linalg.inv(np.cov(x_2d.T)) 9 | mu = np.mean(x_2d, axis=0, keepdims=True) 10 | get_M_dis = lambda _x: (_x - mu) @ inv_Sig @ (_x - mu).T 11 | dm = np.array([get_M_dis(_x) for _x in x_2d]) 12 | 13 | dm = dm.reshape((rows, cols)) 14 | 15 | return dm 16 | 17 | 18 | class MinMaxNorm: 19 | 20 | def __init__(self, feature_range=(0, 1)): 21 | self.feature_range = feature_range 22 | self.min = None 23 | self.max = None 24 | 25 | def fit(self, x): 26 | self.min = x.min() 27 | self.max = x.max() 28 | return self 29 | 30 | def transform(self, x): 31 | x_std = (x - self.min) / (self.max - self.min) 32 | x_norm = x_std * (self.feature_range[1] - self.feature_range[0]) + self.feature_range[0] 33 | return x_norm 34 | 35 | def inverse_transform(self, x_norm): 36 | x_std = (x_norm - self.feature_range[0]) / (self.feature_range[1] - self.feature_range[0]) 37 | x = x_std * (self.max - self.min) + 
self.min 38 | return x 39 | 40 | 41 | class ZScoreNorm: 42 | 43 | def __init__(self): 44 | self.means = None 45 | self.stds = None 46 | 47 | def fit(self, x): 48 | self.means = np.mean(x, axis=(0, 1)) 49 | self.stds = np.std(x, axis=(0, 1)) 50 | return self 51 | 52 | def transform(self, x): 53 | x_norm = np.zeros_like(x) 54 | for i in range(x.shape[2]): 55 | x_norm[:, :, i] = (x[:, :, i] - self.means[i]) / self.stds[i] 56 | return x_norm 57 | 58 | def inverse_transform(self, x_norm): 59 | x = np.zeros_like(x_norm) 60 | for i in range(x_norm.shape[2]): 61 | x[:, :, i] = x_norm[:, :, i] * self.stds[i] + self.means[i] 62 | return x 63 | --------------------------------------------------------------------------------
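`utils.rx` evaluates the Mahalanobis distance with a Python-level loop over pixels, which gets slow on larger scenes. The sketch below is a vectorized variant that computes the same quadratic form for every pixel at once with `np.einsum`. Like `utils.rx`, it assumes the sample covariance is invertible. `rx_vectorized` is a hypothetical name, not a function in this repository.

```python
import numpy as np


def rx_vectorized(x: np.ndarray) -> np.ndarray:
    """Global RX detector without a per-pixel Python loop.

    x: hyperspectral cube (rows, cols, bands); returns (rows, cols).
    Uses the same sample covariance (np.cov, ddof=1) as utils.rx.
    """
    rows, cols, bands = x.shape
    x_2d = x.reshape(rows * cols, bands)
    centered = x_2d - x_2d.mean(axis=0, keepdims=True)
    inv_sig = np.linalg.inv(np.cov(x_2d, rowvar=False))
    # d_i = c_i^T S^{-1} c_i for every centered pixel c_i
    dm = np.einsum('ij,jk,ik->i', centered, inv_sig, centered)
    return dm.reshape(rows, cols)


rng = np.random.default_rng(0)
cube = rng.normal(size=(4, 5, 3))
dm = rx_vectorized(cube)

# Cross-check one pixel against the explicit per-pixel formula
c = cube[2, 3] - cube.reshape(-1, 3).mean(axis=0)
s_inv = np.linalg.inv(np.cov(cube.reshape(-1, 3).T))
assert np.isclose(dm[2, 3], c @ s_inv @ c)
```

If the covariance is singular or ill-conditioned, for example when many bands are highly correlated, `np.linalg.pinv` is a common substitute for `np.linalg.inv`. Note that this changes the results relative to `utils.rx`.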