├── LICENSE ├── README.md ├── common.py ├── data ├── HStest.py ├── HStrain.py └── __init__.py ├── demo.sh ├── loss.py ├── main_CST.py ├── metrics.py ├── network ├── CST.py └── csa.py ├── test_demo.sh └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Tomchenshi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CST 2 | The PyTorch code of the hyperspectral image super-resolution method CST. 3 | 4 | ## Requirements 5 | * Python 3.6.13 6 | * PyTorch 1.8 7 | 8 | ## Preparation 9 | To obtain the training, validation, and testing sets, refer to SSPSR and download the `mcodes` (MATLAB scripts) for cropping the hyperspectral images. 10 | 11 | ## Training 12 | To train CST, run the following command.<br>
13 | ``` 14 | sh demo.sh 15 | ``` 16 | ## Testing 17 | To test CST, run the following command.<br>
18 | ``` 19 | sh test_demo.sh 20 | ``` 21 | ## References 22 | * [SSPSR](https://github.com/junjun-jiang/SSPSR) 23 | -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | 8 | 9 | def default_conv(in_channels, out_channels, kernel_size, stride=1, bias=True, dilation=1, groups=1): 10 | if dilation==1: 11 | return nn.Conv2d( 12 | in_channels, out_channels, kernel_size, 13 | padding=(kernel_size//2), bias=bias, groups=groups) 14 | elif dilation==2: 15 | return nn.Conv2d( 16 | in_channels, out_channels, kernel_size, 17 | padding=2, bias=bias, dilation=dilation, groups=groups) 18 | 19 | else: 20 | padding = int((kernel_size - 1) / 2) * dilation 21 | return nn.Conv2d( 22 | in_channels, out_channels, kernel_size, 23 | stride, padding=padding, bias=bias, dilation=dilation, groups=groups) 24 | 25 | 26 | class CALayer(nn.Module): 27 | def __init__(self, channel, reduction=16): 28 | super(CALayer, self).__init__() 29 | # global average pooling: feature --> point 30 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 31 | # feature channel downscale and upscale --> channel weight 32 | self.conv_du = nn.Sequential( 33 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 34 | nn.ReLU(inplace=True), 35 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 36 | nn.Sigmoid() 37 | ) 38 | 39 | def forward(self, x): 40 | y = self.avg_pool(x) 41 | y = self.conv_du(y) 42 | return x * y 43 | 44 | def mean_channels(F): 45 | assert(F.dim() == 4) 46 | spatial_sum = F.sum(3, keepdim=True).sum(2, keepdim=True) 47 | return spatial_sum / (F.size(2) * F.size(3)) 48 | 49 | def stdv_channels(F): 50 | assert(F.dim() == 4) 51 | F_mean = mean_channels(F) 52 | F_variance = (F - F_mean).pow(2).sum(3, keepdim=True).sum(2, keepdim=True) / (F.size(2) * F.size(3)) 53 | return F_variance.pow(0.5) 54 | 55 | class Upsampler(nn.Sequential): 56 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 57 | m = [] 58 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
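# Note: (scale & (scale - 1)) == 0 is the standard bit trick for testing whether scale is a power of two, in which case the upsampler below stacks log2(scale) PixelShuffle(2) stages (e.g. two stages for x4).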
59 | for _ in range(int(math.log(scale, 2))): 60 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 61 | m.append(nn.PixelShuffle(2)) 62 | if bn: 63 | m.append(nn.BatchNorm2d(n_feats)) 64 | if act == 'relu': 65 | m.append(nn.ReLU(True)) 66 | elif act == 'prelu': 67 | m.append(nn.PReLU(n_feats)) 68 | 69 | elif scale == 3: 70 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 71 | m.append(nn.PixelShuffle(3)) 72 | if bn: 73 | m.append(nn.BatchNorm2d(n_feats)) 74 | if act == 'relu': 75 | m.append(nn.ReLU(True)) 76 | elif act == 'prelu': 77 | m.append(nn.PReLU(n_feats)) 78 | else: 79 | raise NotImplementedError 80 | 81 | super(Upsampler, self).__init__(*m) 82 | 83 | 84 | class ResAttentionBlock(nn.Module): 85 | def __init__(self, conv, n_feats, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 86 | super(ResAttentionBlock, self).__init__() 87 | m = [] 88 | for i in range(2): 89 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 90 | if bn: 91 | m.append(nn.BatchNorm2d(n_feats)) 92 | if i == 0: 93 | m.append(act) 94 | 95 | m.append(CALayer(n_feats, 16)) 96 | 97 | self.body = nn.Sequential(*m) 98 | self.res_scale = res_scale 99 | 100 | def forward(self, x): 101 | res = self.body(x).mul(self.res_scale) 102 | res += x 103 | return res -------------------------------------------------------------------------------- /data/HStest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.utils.data as data 3 | import scipy.io as sio 4 | import torch 5 | 6 | 7 | class HSTestData(data.Dataset): 8 | def __init__(self, image_dir, use_3D=False): 9 | test_data = sio.loadmat(image_dir) 10 | self.use_3Dconv = use_3D 11 | self.ms = np.array(test_data['ms'][...], dtype=np.float32) 12 | self.lms = np.array(test_data['ms_bicubic'][...], dtype=np.float32) 13 | self.gt = np.array(test_data['gt'][...], dtype=np.float32) 14 | 15 | def __getitem__(self, index): 16 | gt = self.gt[index, :, :, :] 17 | ms = self.ms[index, :, :, :] 18 | lms = self.lms[index, :, :, :] 19 | if self.use_3Dconv: 20 | ms, lms, gt = ms[np.newaxis, :, :, :], lms[np.newaxis, :, :, :], gt[np.newaxis, :, :, :] 21 | ms = torch.from_numpy(ms.copy()).permute(0, 3, 1, 2) 22 | lms = torch.from_numpy(lms.copy()).permute(0, 3, 1, 2) 23 | gt = torch.from_numpy(gt.copy()).permute(0, 3, 1, 2) 24 | else: 25 | ms = torch.from_numpy(ms.copy()).permute(2, 0, 1) 26 | lms = torch.from_numpy(lms.copy()).permute(2, 0, 1) 27 | gt = torch.from_numpy(gt.copy()).permute(2, 0, 1) 28 | #ms = torch.from_numpy(ms.transpose((2, 0, 1))) 29 | #lms = torch.from_numpy(lms.transpose((2, 0, 1))) 30 | #gt = torch.from_numpy(gt.transpose((2, 0, 1))) 31 | return ms, lms, gt 32 | 33 | def __len__(self): 34 | return self.gt.shape[0] 35 | -------------------------------------------------------------------------------- /data/HStrain.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.utils.data as data 3 | import scipy.io as sio 4 | import torch 5 | import os 6 | import utils 7 | 8 | 9 | def is_mat_file(filename): 10 | return any(filename.endswith(extension) for extension in [".mat"]) 11 | 12 | 13 | class HSTrainingData(data.Dataset): 14 | def __init__(self, image_dir, augment=None, use_3D=False): 15 | self.image_files = [os.path.join(image_dir, x) for x in os.listdir(image_dir) if is_mat_file(x)] 16 | # self.image_files = [] 17 | # for i in self.image_folders: 18 | # images = os.listdir(i) 19 | # for j in images: 20 | # if 
is_mat_file(j): 21 | # full_path = os.path.join(i, j) 22 | # self.image_files.append(full_path) 23 | self.augment = augment 24 | self.use_3Dconv = use_3D 25 | if self.augment: 26 | self.factor = 8 27 | else: 28 | self.factor = 1 29 | 30 | def __getitem__(self, index): 31 | file_index = index 32 | aug_num = 0 33 | if self.augment: 34 | file_index = index // self.factor # 35 | aug_num = int(index % self.factor) # 0-7 36 | load_dir = self.image_files[file_index] 37 | data = sio.loadmat(load_dir) 38 | ms = np.array(data['ms'][...], dtype=np.float32) 39 | lms = np.array(data['ms_bicubic'][...], dtype=np.float32) 40 | gt = np.array(data['gt'][...], dtype=np.float32) 41 | ms, lms, gt = utils.data_augmentation(ms, mode=aug_num), utils.data_augmentation(lms, mode=aug_num), \ 42 | utils.data_augmentation(gt, mode=aug_num) 43 | if self.use_3Dconv: 44 | ms, lms, gt = ms[np.newaxis, :, :, :], lms[np.newaxis, :, :, :], gt[np.newaxis, :, :, :] 45 | ms = torch.from_numpy(ms.copy()).permute(0, 3, 1, 2) 46 | lms = torch.from_numpy(lms.copy()).permute(0, 3, 1, 2) 47 | gt = torch.from_numpy(gt.copy()).permute(0, 3, 1, 2) 48 | else: 49 | ms = torch.from_numpy(ms.copy()).permute(2, 0, 1) 50 | lms = torch.from_numpy(lms.copy()).permute(2, 0, 1) 51 | gt = torch.from_numpy(gt.copy()).permute(2, 0, 1) 52 | return ms, lms, gt 53 | 54 | def __len__(self): 55 | return len(self.image_files)*self.factor 56 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .HStrain import HSTrainingData 2 | from .HStest import HSTestData 3 | -------------------------------------------------------------------------------- /demo.sh: -------------------------------------------------------------------------------- 1 | python main_CST.py train --dataset "Chikusei" --n_scale 4 --gpus "0,1" 2 | 3 | python main_CST.py train --dataset "Houston" --n_scale 4 --gpus "0,1" 4 | 5 | python main_CST.py train --dataset "Pavia" --n_scale 4 --gpus "0,1" 6 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class HLoss(torch.nn.Module): 6 | def __init__(self, la1, la2, sam=True, gra=True): 7 | super(HLoss, self).__init__() 8 | self.lamd1 = la1 9 | self.lamd2 = la2 10 | self.sam = sam 11 | self.gra = gra 12 | 13 | self.fidelity = torch.nn.L1Loss() 14 | self.gra = torch.nn.L1Loss() 15 | 16 | def forward(self, y, gt): 17 | loss1 = self.fidelity(y, gt) 18 | loss2 = self.lamd1 * cal_sam(y, gt) 19 | loss3 = self.lamd2 * self.gra(cal_gradient(y), cal_gradient(gt)) 20 | loss = loss1 + loss2 + loss3 21 | return loss 22 | 23 | 24 | class HyLoss(torch.nn.Module): 25 | def __init__(self, la1=0.1): 26 | super(HyLoss, self).__init__() 27 | self.lamd1 = la1 28 | self.fidelity = torch.nn.L1Loss() 29 | 30 | def forward(self, y, gt): 31 | loss1 = self.fidelity(y, gt) 32 | loss2 = self.lamd1 * cal_sam(y, gt) 33 | loss = loss1 + loss2 34 | return loss 35 | 36 | class HybridLoss(torch.nn.Module): 37 | def __init__(self, lamd=1e-1, spatial_tv=False, spectral_tv=False): 38 | super(HybridLoss, self).__init__() 39 | self.lamd = lamd 40 | self.use_spatial_TV = spatial_tv 41 | self.use_spectral_TV = spectral_tv 42 | self.fidelity = torch.nn.L1Loss() 43 | self.spatial = TVLoss(weight=1e-3) 44 | self.spectral = TVLossSpectral(weight=1e-3) 45 | 46 | def forward(self, y, gt): 
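"""L1 fidelity between the prediction y and the ground truth gt, plus optional spatial / spectral total-variation terms controlled by the flags set in __init__."""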
47 | loss = self.fidelity(y, gt) 48 | spatial_TV = 0.0 49 | spectral_TV = 0.0 50 | if self.use_spatial_TV: 51 | spatial_TV = self.spatial(y) 52 | if self.use_spectral_TV: 53 | spectral_TV = self.spectral(y) 54 | total_loss = loss + spatial_TV + spectral_TV 55 | return total_loss 56 | 57 | 58 | # from https://github.com/jxgu1016/Total_Variation_Loss.pytorch with slight modifications 59 | class TVLoss(torch.nn.Module): 60 | def __init__(self, weight=1.0): 61 | super(TVLoss, self).__init__() 62 | self.TVLoss_weight = weight 63 | 64 | def forward(self, x): 65 | batch_size = x.size()[0] 66 | h_x = x.size()[2] 67 | w_x = x.size()[3] 68 | count_h = self._tensor_size(x[:, :, 1:, :]) 69 | count_w = self._tensor_size(x[:, :, :, 1:]) 70 | # h_tv = torch.abs(x[:, :, 1:, :] - x[:, :, :h_x - 1, :]).sum() 71 | # w_tv = torch.abs(x[:, :, :, 1:] - x[:, :, :, :w_x - 1]).sum() 72 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum() 73 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum() 74 | return self.TVLoss_weight * (h_tv / count_h + w_tv / count_w) / batch_size 75 | 76 | def _tensor_size(self, t): 77 | return t.size()[1] * t.size()[2] * t.size()[3] 78 | 79 | 80 | class TVLossSpectral(torch.nn.Module): 81 | def __init__(self, weight=1.0): 82 | super(TVLossSpectral, self).__init__() 83 | self.TVLoss_weight = weight 84 | 85 | def forward(self, x): 86 | batch_size = x.size()[0] 87 | c_x = x.size()[1] 88 | count_c = self._tensor_size(x[:, 1:, :, :]) 89 | # c_tv = torch.abs((x[:, 1:, :, :] - x[:, :c_x - 1, :, :])).sum() 90 | c_tv = torch.pow((x[:, 1:, :, :] - x[:, :c_x - 1, :, :]), 2).sum() 91 | return self.TVLoss_weight * 2 * (c_tv / count_c) / batch_size 92 | 93 | def _tensor_size(self, t): 94 | return t.size()[1] * t.size()[2] * t.size()[3] 95 | 96 | def cal_sam(Itrue, Ifake): 97 | esp = 1e-6 98 | InnerPro = torch.sum(Itrue*Ifake,1,keepdim=True) 99 | len1 = torch.norm(Itrue, p=2,dim=1,keepdim=True) 100 | len2 = torch.norm(Ifake, p=2,dim=1,keepdim=True) 101 | divisor = len1*len2 102 | mask = torch.eq(divisor,0) 103 | divisor = divisor + (mask.float())*esp 104 | cosA = torch.sum(InnerPro/divisor,1).clamp(-1+esp, 1-esp) 105 | sam = torch.acos(cosA) 106 | sam = torch.mean(sam) / np.pi 107 | return sam 108 | 109 | 110 | def cal_gradient_c(x): 111 | c_x = x.size(1) 112 | g = x[:, 1:, 1:, 1:] - x[:, :c_x - 1, 1:, 1:] 113 | return g 114 | 115 | 116 | def cal_gradient_x(x): 117 | c_x = x.size(2) 118 | g = x[:, 1:, 1:, 1:] - x[:, 1:, :c_x - 1, 1:] 119 | return g 120 | 121 | 122 | def cal_gradient_y(x): 123 | c_x = x.size(3) 124 | g = x[:, 1:, 1:, 1:] - x[:, 1:, 1:, :c_x - 1] 125 | return g 126 | 127 | 128 | def cal_gradient(inp): 129 | x = cal_gradient_x(inp) 130 | y = cal_gradient_y(inp) 131 | c = cal_gradient_c(inp) 132 | g = torch.sqrt(torch.pow(x, 2) + torch.pow(y, 2) + torch.pow(c, 2) + 1e-6) 133 | return g -------------------------------------------------------------------------------- /main_CST.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import random 5 | import time 6 | import torch 7 | import cv2 8 | import math 9 | import numpy as np 10 | import torch.backends.cudnn as cudnn 11 | from torch.optim import Adam 12 | from torch.utils.data import DataLoader 13 | from tensorboardX import SummaryWriter 14 | from torchnet import meter 15 | import json 16 | from tqdm import tqdm 17 | from data import HSTrainingData 18 | from data import HSTestData 19 | from network.CST import * 20 | from 
common import * 21 | from metrics import compare_mpsnr 22 | # loss 23 | from loss import HLoss 24 | # from loss import HyLapLoss 25 | from metrics import quality_assessment 26 | 27 | # global settings 28 | resume = False 29 | log_interval = 50 30 | model_name = '' 31 | test_data_dir = '' 32 | 33 | 34 | def main(): 35 | # parsers 36 | main_parser = argparse.ArgumentParser(description="parser for SR network") 37 | subparsers = main_parser.add_subparsers(title="subcommands", dest="subcommand") 38 | train_parser = subparsers.add_parser("train", help="parser for training arguments") 39 | train_parser.add_argument("--cuda", type=int, required=False,default=1, 40 | help="set it to 1 for running on GPU, 0 for CPU") 41 | train_parser.add_argument("--batch_size", type=int, default=32, help="batch size, default set to 64") 42 | train_parser.add_argument("--epochs", type=int, default=300, help="epochs, default set to 20") 43 | train_parser.add_argument("--n_feats", type=int, default=180, help="n_feats, default set to 256") 44 | train_parser.add_argument("--n_scale", type=int, default=4, help="n_scale, default set to 2") 45 | train_parser.add_argument("--dataset_name", type=str, default="Chikusei", help="dataset_name, default set to dataset_name") 46 | train_parser.add_argument("--model_title", type=str, default="CST", help="model_title, default set to model_title") 47 | train_parser.add_argument("--seed", type=int, default=3000, help="start seed for model") 48 | train_parser.add_argument('--la1', type=float, default=0.3, help="") 49 | train_parser.add_argument('--la2', type=float, default=0.1, help="") 50 | train_parser.add_argument("--learning_rate", type=float, default=1e-4, 51 | help="learning rate, default set to 1e-4") 52 | train_parser.add_argument("--weight_decay", type=float, default=0, help="weight decay, default set to 0") 53 | train_parser.add_argument("--gpus", type=str, default="1", help="gpu ids (default: 7)") 54 | 55 | test_parser = subparsers.add_parser("test", help="parser for testing arguments") 56 | test_parser.add_argument("--cuda", type=int, required=False,default=1, 57 | help="set it to 1 for running on GPU, 0 for CPU") 58 | test_parser.add_argument("--gpus", type=str, default="0,1", help="gpu ids (default: 7)") 59 | test_parser.add_argument("--dataset_name", type=str, default="Chikusei",help="dataset_name, default set to dataset_name") 60 | test_parser.add_argument("--model_title", type=str, default="CST",help="model_title, default set to model_title") 61 | test_parser.add_argument("--n_feats", type=int, default=180, help="n_feats, default set to 256") 62 | test_parser.add_argument("--n_scale", type=int, default=4, help="n_scale, default set to 2") 63 | 64 | 65 | args = main_parser.parse_args() 66 | print('===>GPU:',args.gpus) 67 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus 68 | if args.subcommand is None: 69 | print("ERROR: specify either train or test") 70 | sys.exit(1) 71 | if args.cuda and not torch.cuda.is_available(): 72 | print("ERROR: cuda is not available, try running on CPU") 73 | sys.exit(1) 74 | if args.subcommand == "train": 75 | train(args) 76 | else: 77 | test(args) 78 | pass 79 | 80 | 81 | def train(args): 82 | traintime = str(time.ctime()) 83 | device = torch.device("cuda" if args.cuda else "cpu") 84 | # args.seed = random.randint(1, 10000) 85 | print("Start seed: ", args.seed) 86 | torch.manual_seed(args.seed) 87 | if args.cuda: 88 | torch.cuda.manual_seed(args.seed) 89 | cudnn.benchmark = True 90 | 91 | print('===> Loading datasets') 92 | train_path = 
'./datasets/'+args.dataset_name+'_x'+str(args.n_scale)+'/trains/' 93 | result_path = './results/' + args.dataset_name + '_x' + str(args.n_scale)+'/' 94 | test_data_dir = './datasets/'+args.dataset_name+'_x'+str(args.n_scale)+'/'+args.dataset_name+'_test.mat' 95 | 96 | train_set = HSTrainingData(image_dir=train_path, augment=True) 97 | 98 | train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=8, shuffle=True) 99 | test_set = HSTestData(test_data_dir) 100 | test_loader = DataLoader(test_set, batch_size=1, shuffle=False) 101 | 102 | if args.dataset_name=='Cave': 103 | colors = 31 104 | elif args.dataset_name=='Pavia': 105 | colors = 102 106 | elif args.dataset_name=='Houston': 107 | colors = 48 108 | else: 109 | colors = 128 110 | 111 | print('===> Building model:{}'.format(args.model_title)) 112 | net = CST(inp_channels=colors, dim=args.n_feats, depths=[4,4,4,4], num_heads=[6,6,6,6],mlp_ratio=2, scale=args.n_scale) 113 | # print(net) 114 | model_title = args.dataset_name + "_" + args.model_title+'_x'+ str(args.n_scale) 115 | 116 | args.model_title = model_title 117 | 118 | if torch.cuda.device_count() > 1: 119 | print("===> Let's use", torch.cuda.device_count(), "GPUs.") 120 | net = torch.nn.DataParallel(net) 121 | start_epoch = 0 122 | 123 | if resume: 124 | model_name = './checkpoints/' + model_title + "_ckpt_epoch_" + str(300) + ".pth" 125 | if os.path.isfile(model_name): 126 | print("=> loading checkpoint '{}'".format(model_name)) 127 | checkpoint = torch.load(model_name) 128 | start_epoch = checkpoint["epoch"] 129 | net.load_state_dict(checkpoint["model"].state_dict()) 130 | else: 131 | print("=> no checkpoint found at '{}'".format(model_name)) 132 | net.to(device).train() 133 | print_network(net) 134 | # loss functions to choose 135 | h_loss = HLoss(args.la1,args.la2) 136 | 137 | 138 | print("===> Setting optimizer and logger") 139 | # add adam optimizer 140 | optimizer = Adam(net.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) 141 | epoch_meter = meter.AverageValueMeter() 142 | writer = SummaryWriter('runs/'+model_title+'_'+traintime) 143 | 144 | 145 | best_psnr = 0 146 | best_epoch = 0 147 | 148 | print('===> Start training') 149 | for e in range(start_epoch, args.epochs): 150 | psnr = [] 151 | adjust_learning_rate(args.learning_rate, optimizer, e+1) 152 | epoch_meter.reset() 153 | net.train() 154 | print("Start epoch {}, learning rate = {}".format(e + 1, optimizer.param_groups[0]["lr"])) 155 | for iteration, (x, lms, gt) in enumerate(tqdm(train_loader, leave=False)): 156 | x, lms, gt = x.to(device), lms.to(device), gt.to(device) 157 | psnr = [] 158 | optimizer.zero_grad() 159 | y = net(x, lms) 160 | loss = h_loss(y, gt) 161 | epoch_meter.add(loss.item()) 162 | loss.backward() 163 | optimizer.step() 164 | # tensorboard visualization 165 | if (iteration + log_interval) % log_interval == 0: 166 | print("===> {} \tEpoch[{}]({}/{}): Loss: {:.6f}".format(time.ctime(), e + 1, iteration + 1, len(train_loader)-1, loss.item())) 167 | n_iter = e * len(train_loader) + iteration + 1 168 | writer.add_scalar('scalar/train_loss', loss, n_iter) 169 | 170 | print("Running testset") 171 | net.eval() 172 | with torch.no_grad(): 173 | output = [] 174 | test_number = 0 175 | for i, (ms, lms, gt) in enumerate(test_loader): 176 | ms, lms, gt = ms.to(device), lms.to(device), gt.to(device) 177 | y = net(ms, lms) 178 | y, gt = y.squeeze().cpu().numpy().transpose(1, 2, 0), gt.squeeze().cpu().numpy().transpose(1, 2, 0) 179 | y = y[:gt.shape[0], :gt.shape[1], :] 180 | 
psnr_value = compare_mpsnr(gt, y, data_range=1.) 181 | psnr.append(psnr_value) 182 | output.append(y) 183 | test_number += 1 184 | 185 | avg_psnr = sum(psnr) / test_number 186 | if avg_psnr >best_psnr: 187 | best_psnr = avg_psnr 188 | best_epoch = e+1 189 | save_checkpoint(args, net, e + 1, traintime) 190 | writer.add_scalar('scalar/test_psnr', avg_psnr, e + 1) 191 | 192 | print("===> {}\tEpoch {} Training Complete: Avg. Loss: {:.6f} PSNR:{:.3f} best_psnr:{:.3f} best_epoch:{}".format( 193 | time.ctime(), e+1, epoch_meter.value()[0], avg_psnr, best_psnr, best_epoch)) 194 | # run validation set every epoch 195 | # eval_loss = validate(args, eval_loader, net, L1_loss) 196 | # tensorboard visualization 197 | writer.add_scalar('scalar/avg_epoch_loss', epoch_meter.value()[0], e + 1) 198 | # writer.add_scalar('scalar/avg_validation_loss', eval_loss, e + 1) 199 | # save model weights at checkpoints every 10 epochs 200 | if (e + 1) % 5 == 0: 201 | save_checkpoint(args, net, e+1, traintime) 202 | 203 | ## Save the testing results 204 | 205 | print('===> Start testing') 206 | model_name = './checkpoints/' + traintime +'/' + "_" + args.model_title + "_ckpt_epoch_" + str(best_epoch) + ".pth" 207 | with torch.no_grad(): 208 | test_number = 0 209 | epoch_meter = meter.AverageValueMeter() 210 | epoch_meter.reset() 211 | # loading model 212 | net = CST(inp_channels=colors, dim=args.n_feats, depths=[4, 4, 4, 4], 213 | num_heads=[6, 6, 6, 6], mlp_ratio=2, scale=args.n_scale) 214 | net.to(device).eval() 215 | state_dict = torch.load(model_name) 216 | net.load_state_dict(state_dict['model']) 217 | 218 | output = [] 219 | for i, (ms, lms, gt) in enumerate(test_loader): 220 | # compute output 221 | ms, lms, gt = ms.to(device), lms.to(device), gt.\ 222 | to(device) 223 | # y = model(ms) 224 | y = net(ms, lms) 225 | y, gt = y.squeeze().cpu().numpy().transpose(1, 2, 0), gt.squeeze().cpu().numpy().transpose(1, 2, 0) 226 | y = y[:gt.shape[0],:gt.shape[1],:] 227 | if i==0: 228 | indices = quality_assessment(gt, y, data_range=1., ratio=4) 229 | else: 230 | indices = sum_dict(indices, quality_assessment(gt, y, data_range=1., ratio=4)) 231 | output.append(y) 232 | test_number += 1 233 | for index in indices: 234 | indices[index] = indices[index] / test_number 235 | 236 | 237 | save_dir = result_path + model_title + '.npy' 238 | np.save(save_dir, output) 239 | print("Test finished, test results saved to .npy file at ", save_dir) 240 | print(indices) 241 | QIstr = model_title+'_'+str(time.ctime()) + ".txt" 242 | json.dump(indices, open(QIstr, 'w')) 243 | 244 | 245 | def sum_dict(a, b): 246 | temp = dict() 247 | for key in a.keys()| b.keys(): 248 | temp[key] = sum([d.get(key, 0) for d in (a, b)]) 249 | return temp 250 | 251 | 252 | def adjust_learning_rate(start_lr, optimizer, epoch): 253 | """Sets the learning rate to the initial LR decayed by 2 every 150 epochs""" 254 | lr = start_lr * (0.5 ** (epoch // 150)) 255 | for param_group in optimizer.param_groups: 256 | param_group['lr'] = lr 257 | 258 | 259 | def validate(args, loader, model, criterion): 260 | device = torch.device("cuda" if args.cuda else "cpu") 261 | # switch to evaluate mode 262 | model.eval() 263 | epoch_meter = meter.AverageValueMeter() 264 | epoch_meter.reset() 265 | with torch.no_grad(): 266 | for i, (ms, lms, gt) in enumerate(loader): 267 | ms, lms, gt = ms.to(device), lms.to(device), gt.to(device) 268 | y = model(ms, lms) 269 | loss = criterion(y, gt) 270 | epoch_meter.add(loss.item()) 271 | 272 | # back to training mode 273 | model.train() 274 | 
return epoch_meter.value()[0] 275 | 276 | 277 | def test(args): 278 | if args.dataset_name=='Cave': 279 | colors = 31 280 | elif args.dataset_name=='Pavia': 281 | colors = 102 282 | elif args.dataset_name=='Houston': 283 | colors = 48 284 | else: 285 | colors = 128 286 | test_data_dir = './datasets/' + args.dataset_name + '_x' + str(args.n_scale) + '/' + args.dataset_name + '_test.mat' 287 | result_path = './results/' + args.dataset_name + '_x' + str(args.n_scale) + '/' 288 | model_title = args.model_title+'_x' + str(args.n_scale) 289 | #model_name = './checkpoints/' +'/'+args.dataset_name +'_'+ model_title + "_ckpt_epoch_" + str() + ".pth" 290 | model_name = './model/' +args.dataset_name + '_'+ model_title + ".pth" 291 | device = torch.device("cuda" if args.cuda else "cpu") 292 | print('===> Loading testset') 293 | 294 | test_set = HSTestData(test_data_dir) 295 | test_loader = DataLoader(test_set, batch_size=1, shuffle=False) 296 | print('===> Start testing') 297 | 298 | with torch.no_grad(): 299 | test_number = 0 300 | epoch_meter = meter.AverageValueMeter() 301 | epoch_meter.reset() 302 | # loading model 303 | net = CST(inp_channels=colors, dim=args.n_feats, depths=[4, 4, 4, 4], 304 | num_heads=[6, 6, 6, 6], mlp_ratio=2, scale=args.n_scale) 305 | net.to(device).eval() 306 | state_dict = torch.load(model_name) 307 | net.load_state_dict(state_dict['model']) 308 | 309 | output = [] 310 | for i, (ms, lms, gt) in enumerate(test_loader): 311 | # compute output 312 | ms, lms, gt = ms.to(device), lms.to(device), gt.\ 313 | to(device) 314 | # y = model(ms) 315 | y = net(ms, lms) 316 | y, gt = y.squeeze().cpu().numpy().transpose(1, 2, 0), gt.squeeze().cpu().numpy().transpose(1, 2, 0) 317 | y = y[:gt.shape[0],:gt.shape[1],:] 318 | if i==0: 319 | indices = quality_assessment(gt, y, data_range=1., ratio=4) 320 | else: 321 | indices = sum_dict(indices, quality_assessment(gt, y, data_range=1., ratio=4)) 322 | output.append(y) 323 | test_number += 1 324 | for index in indices: 325 | indices[index] = indices[index] / test_number 326 | 327 | #save_dir = "./test.npy" 328 | save_dir = result_path + model_title + '.npy' 329 | np.save(save_dir, output) 330 | print("Test finished, test results saved to .npy file at ", save_dir) 331 | print(indices) 332 | QIstr = model_title+'_'+str(time.ctime()) + ".txt" 333 | json.dump(indices, open(QIstr, 'w')) 334 | 335 | def save_checkpoint(args, model, epoch, traintime): 336 | device = torch.device("cuda" if args.cuda else "cpu") 337 | model.eval().cpu() 338 | checkpoint_model_dir = './checkpoints/'+traintime+'/' 339 | if not os.path.exists(checkpoint_model_dir): 340 | os.makedirs(checkpoint_model_dir) 341 | ckpt_model_filename = args.model_title + "_ckpt_epoch_" + str(epoch) + ".pth" 342 | ckpt_model_path = os.path.join(checkpoint_model_dir, ckpt_model_filename) 343 | 344 | if torch.cuda.device_count() > 1: 345 | state = {"epoch": epoch, "model": model.module.state_dict()} 346 | else: 347 | state = {"epoch": epoch, "model": model.state_dict()} 348 | torch.save(state, ckpt_model_path) 349 | model.to(device).train() 350 | print("Checkpoint saved to {}".format(ckpt_model_path)) 351 | 352 | 353 | def print_network(net): 354 | num_params = 0 355 | for param in net.parameters(): 356 | num_params += param.numel() 357 | print('Total number of parameters: %d' % num_params) 358 | 359 | 360 | if __name__ == "__main__": 361 | main() 362 | -------------------------------------------------------------------------------- /metrics.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Author : zhwzhong 4 | @License : (C) Copyright 2013-2018, hit 5 | @Contact : zhwzhong.hit@gmail.com 6 | @Software: PyCharm 7 | @File : metrics.py 8 | @Time : 2019/12/4 17:35 9 | @Desc : 10 | """ 11 | import numpy as np 12 | from scipy.signal import convolve2d 13 | from skimage.measure import compare_psnr, compare_ssim 14 | 15 | def compare_ergas(x_true, x_pred, ratio): 16 | """ 17 | Calculate ERGAS. ERGAS offers a global indication of the quality of the fused image; the ideal value is 0. 18 | :param x_true: 19 | :param x_pred: 20 | :param ratio: upsampling factor 21 | :return: 22 | """ 23 | x_true, x_pred = img_2d_mat(x_true=x_true, x_pred=x_pred) 24 | sum_ergas = 0 25 | for i in range(x_true.shape[0]): 26 | vec_x = x_true[i] 27 | vec_y = x_pred[i] 28 | err = vec_x - vec_y 29 | r_mse = np.mean(np.power(err, 2)) 30 | tmp = r_mse / (np.mean(vec_x)**2) 31 | sum_ergas += tmp 32 | return (100 / ratio) * np.sqrt(sum_ergas / x_true.shape[0]) 33 | 34 | 35 | def compare_sam(x_true, x_pred): 36 | """ 37 | :param x_true: hyperspectral image of shape (H, W, C) 38 | :param x_pred: hyperspectral image of shape (H, W, C) 39 | :return: the spectral angle similarity (SAM, in degrees) between the original and the reconstructed hyperspectral data 40 | """ 41 | num = 0 42 | sum_sam = 0 43 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32) 44 | for x in range(x_true.shape[0]): 45 | for y in range(x_true.shape[1]): 46 | tmp_pred = x_pred[x, y].ravel() 47 | tmp_true = x_true[x, y].ravel() 48 | if np.linalg.norm(tmp_true) != 0 and np.linalg.norm(tmp_pred) != 0: 49 | sum_sam += np.arccos( 50 | np.minimum(1, np.inner(tmp_pred, tmp_true) / (np.linalg.norm(tmp_true) * np.linalg.norm(tmp_pred)))) 51 | 52 | num += 1 53 | sam_deg = (sum_sam / num) * 180 / np.pi 54 | return sam_deg 55 | 56 | 57 | def compare_corr(x_true, x_pred): 58 | """ 59 | Calculate the cross correlation between x_pred and x_true. 60 | Compute the correlation coefficient for each corresponding band, then take the mean. 61 | CC is a spatial measure.
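As implemented below, each band is flattened and mean-centred, and CC = mean over bands of sum(x_true_b * x_pred_b) / sqrt(sum(x_true_b ** 2) * sum(x_pred_b ** 2)), i.e. the average per-band Pearson correlation coefficient.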
62 | """ 63 | x_true, x_pred = img_2d_mat(x_true=x_true, x_pred=x_pred) 64 | x_true = x_true - np.mean(x_true, axis=1).reshape(-1, 1) 65 | x_pred = x_pred - np.mean(x_pred, axis=1).reshape(-1, 1) 66 | numerator = np.sum(x_true * x_pred, axis=1).reshape(-1, 1) 67 | denominator = np.sqrt(np.sum(x_true * x_true, axis=1) * np.sum(x_pred * x_pred, axis=1)).reshape(-1, 1) 68 | return (numerator / denominator).mean() 69 | 70 | 71 | def img_2d_mat(x_true, x_pred): 72 | """ 73 | # 将三维的多光谱图像转为2位矩阵 74 | :param x_true: (H, W, C) 75 | :param x_pred: (H, W, C) 76 | :return: a matrix which shape is (C, H * W) 77 | """ 78 | h, w, c = x_true.shape 79 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32) 80 | x_mat = np.zeros((c, h * w), dtype=np.float32) 81 | y_mat = np.zeros((c, h * w), dtype=np.float32) 82 | for i in range(c): 83 | x_mat[i] = x_true[:, :, i].reshape((1, -1)) 84 | y_mat[i] = x_pred[:, :, i].reshape((1, -1)) 85 | return x_mat, y_mat 86 | 87 | 88 | def compare_rmse(x_true, x_pred): 89 | """ 90 | Calculate Root mean squared error 91 | :param x_true: 92 | :param x_pred: 93 | :return: 94 | """ 95 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32) 96 | return np.linalg.norm(x_true - x_pred) / (np.sqrt(x_true.shape[0] * x_true.shape[1] * x_true.shape[2])) 97 | 98 | 99 | def compare_mpsnr(x_true, x_pred, data_range): 100 | """ 101 | :param x_true: Input image must have three dimension (H, W, C) 102 | :param x_pred: 103 | :return: 104 | """ 105 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32) 106 | channels = x_true.shape[2] 107 | total_psnr = [compare_psnr(im_true=x_true[:, :, k], im_test=x_pred[:, :, k], data_range=data_range) 108 | for k in range(channels)] 109 | 110 | return np.mean(total_psnr) 111 | 112 | def compare_mpsnr_test(x_true, x_pred, data_range): 113 | """ 114 | :param x_true: Input image must have three dimension (H, W, C) 115 | :param x_pred: 116 | :return: 117 | """ 118 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32) 119 | print(np.argwhere(np.isnan(x_true))) 120 | print(np.argwhere(np.isnan(x_pred))) 121 | channels = x_true.shape[2] 122 | total_psnr = [compare_psnr(im_true=x_true[:, :, k], im_test=x_pred[:, :, k], data_range=data_range) 123 | for k in range(channels)] 124 | 125 | return np.mean(total_psnr) 126 | 127 | 128 | def compare_mssim(x_true, x_pred, data_range, multidimension): 129 | """ 130 | 131 | :param x_true: 132 | :param x_pred: 133 | :param data_range: 134 | :param multidimension: 135 | :return: 136 | """ 137 | mssim = [compare_ssim(X=x_true[:, :, i], Y=x_pred[:, :, i], data_range=data_range, multidimension=multidimension) 138 | for i in range(x_true.shape[2])] 139 | 140 | return np.mean(mssim) 141 | 142 | 143 | def compare_sid(x_true, x_pred): 144 | """ 145 | SID is an information theoretic measure for spectral similarity and discriminability. 
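As implemented below it is computed band-wise: SID_b = |sum(x_pred_b * log10((x_pred_b + 1e-3) / (x_true_b + 1e-3))) + sum(x_true_b * log10((x_true_b + 1e-3) / (x_pred_b + 1e-3)))| / (H * W), and the mean over all bands is returned.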
146 | :param x_true: 147 | :param x_pred: 148 | :return: 149 | """ 150 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32) 151 | N = x_true.shape[2] 152 | err = np.zeros(N) 153 | for i in range(N): 154 | err[i] = abs(np.sum(x_pred[:, :, i] * np.log10((x_pred[:, :, i] + 1e-3) / (x_true[:, :, i] + 1e-3))) + 155 | np.sum(x_true[:, :, i] * np.log10((x_true[:, :, i] + 1e-3) / (x_pred[:, :, i] + 1e-3)))) 156 | return np.mean(err / (x_true.shape[1] * x_true.shape[0])) 157 | 158 | 159 | def compare_appsa(x_true, x_pred): 160 | """ 161 | 162 | :param x_true: 163 | :param x_pred: 164 | :return: 165 | """ 166 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32) 167 | nom = np.sum(x_true * x_pred, axis=2) 168 | denom = np.linalg.norm(x_true, axis=2) * np.linalg.norm(x_pred, axis=2) 169 | 170 | cos = np.where((nom / (denom + 1e-3)) > 1, 1, (nom / (denom + 1e-3))) 171 | appsa = np.arccos(cos) 172 | return np.sum(appsa) / (x_true.shape[1] * x_true.shape[0]) 173 | 174 | 175 | def compare_mare(x_true, x_pred): 176 | """ 177 | 178 | :param x_true: 179 | :param x_pred: 180 | :return: 181 | """ 182 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32) 183 | diff = x_true - x_pred 184 | abs_diff = np.abs(diff) 185 | relative_abs_diff = np.divide(abs_diff, x_true + 1) # added epsilon to avoid division by zero. 186 | return np.mean(relative_abs_diff) 187 | 188 | 189 | def img_qi(img1, img2, block_size=8): 190 | N = block_size ** 2 191 | sum2_filter = np.ones((block_size, block_size)) 192 | 193 | img1_sq = img1 * img1 194 | img2_sq = img2 * img2 195 | img12 = img1 * img2 196 | 197 | img1_sum = convolve2d(img1, np.rot90(sum2_filter), mode='valid') 198 | img2_sum = convolve2d(img2, np.rot90(sum2_filter), mode='valid') 199 | img1_sq_sum = convolve2d(img1_sq, np.rot90(sum2_filter), mode='valid') 200 | img2_sq_sum = convolve2d(img2_sq, np.rot90(sum2_filter), mode='valid') 201 | img12_sum = convolve2d(img12, np.rot90(sum2_filter), mode='valid') 202 | 203 | img12_sum_mul = img1_sum * img2_sum 204 | img12_sq_sum_mul = img1_sum * img1_sum + img2_sum * img2_sum 205 | numerator = 4 * (N * img12_sum - img12_sum_mul) * img12_sum_mul 206 | denominator1 = N * (img1_sq_sum + img2_sq_sum) - img12_sq_sum_mul 207 | denominator = denominator1 * img12_sq_sum_mul 208 | quality_map = np.ones(denominator.shape) 209 | index = (denominator1 == 0) & (img12_sq_sum_mul != 0) 210 | quality_map[index] = 2 * img12_sum_mul[index] / img12_sq_sum_mul[index] 211 | index = (denominator != 0) 212 | quality_map[index] = numerator[index] / denominator[index] 213 | return quality_map.mean() 214 | 215 | 216 | def compare_qave(x_true, x_pred, block_size=8): 217 | n_bands = x_true.shape[2] 218 | q_orig = np.zeros(n_bands) 219 | for idim in range(n_bands): 220 | q_orig[idim] = img_qi(x_true[:, :, idim], x_pred[:, :, idim], block_size) 221 | return q_orig.mean() 222 | 223 | 224 | def quality_assessment(x_true, x_pred, data_range, ratio, multi_dimension=False, block_size=8): 225 | """ 226 | 227 | :param multi_dimension: 228 | :param ratio: 229 | :param data_range: 230 | :param x_true: 231 | :param x_pred: 232 | :param block_size 233 | :return: 234 | """ 235 | result = {'MPSNR': compare_mpsnr(x_true=x_true, x_pred=x_pred, data_range=data_range), 236 | 'MSSIM': compare_mssim(x_true=x_true, x_pred=x_pred, data_range=data_range, 237 | multidimension=multi_dimension), 238 | 'ERGAS': compare_ergas(x_true=x_true, x_pred=x_pred, ratio=ratio), 239 | 'SAM': compare_sam(x_true=x_true, x_pred=x_pred), 240 | # 
'SID': compare_sid(x_true=x_true, x_pred=x_pred), 241 | 'CrossCorrelation': compare_corr(x_true=x_true, x_pred=x_pred), 242 | 'RMSE': compare_rmse(x_true=x_true, x_pred=x_pred), 243 | # 'APPSA': compare_appsa(x_true=x_true, x_pred=x_pred), 244 | # 'MARE': compare_mare(x_true=x_true, x_pred=x_pred), 245 | # "QAVE": compare_qave(x_true=x_true, x_pred=x_pred, block_size=block_size) 246 | } 247 | return result 248 | 249 | # from scipy import io as sio 250 | # im_out = np.array(sio.loadmat('/home/zhwzhong/PycharmProject/HyperSR/SOAT/HyperSR/SRindices/Chikuse_EDSRViDeCNN_Blocks=9_Feats=256_Loss_H_Real_1_1_X2X2_N5new_BS32_Epo60_epoch_60_Fri_Sep_20_21:38:44_2019.mat')['output']) 251 | # im_gt = np.array(sio.loadmat('/home/zhwzhong/PycharmProject/HyperSR/SOAT/HyperSR/SRindices/Chikusei_test.mat')['gt']) 252 | # 253 | # sum_rmse, sum_sam, sum_psnr, sum_ssim, sum_ergas = [], [], [], [], [] 254 | # for i in range(im_gt.shape[0]): 255 | # print(im_out[i].shape) 256 | # score = quality_assessment(x_pred=im_out[i], x_true=im_gt[i], data_range=1, ratio=4, multi_dimension=False, block_size=8) 257 | # sum_rmse.append(score['RMSE']) 258 | # sum_psnr.append(score['MPSNR']) 259 | # sum_ssim.append(score['MSSIM']) 260 | # sum_sam.append(score['SAM']) 261 | # sum_ergas.append(score['ERGAS']) 262 | # 263 | # print(np.mean(sum_rmse), np.mean(sum_psnr), np.mean(sum_ssim), np.mean(sum_sam)) 264 | -------------------------------------------------------------------------------- /network/CST.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | from common import * 5 | from einops import rearrange 6 | from network.csa import CSA 7 | from timm.models.layers import DropPath, trunc_normal_ 8 | import scipy.io as sio 9 | 10 | class CST(nn.Module): 11 | """CST 12 | Spatial-spectral Transformer for hyperspectral image super-resolution. 13 | Args: 14 | inp_channels (int, optional): Input channels of the HSI. Defaults to 31. 15 | dim (int, optional): Embedding dimension. Defaults to 90. 16 | scale (int, optional): Upsampling factor. Defaults to 4. 17 | depths (list, optional): Number of Transformer blocks at different layers of the network. Defaults to [6,6,6,6,6,6]. 18 | num_heads (list, optional): Number of attention heads in different layers. Defaults to [6,6,6,6,6,6]. 19 | mlp_ratio (int, optional): Ratio of the mlp hidden dim to the embedding dim. Defaults to 2. 20 | qkv_bias (bool, optional): Learnable bias to query, key, value. Defaults to True. 21 | qk_scale (_type_, optional): The qk scale in non-local spatial attention. Defaults to None. If set to None, head_dim ** -0.5 is used. 22 | bias (bool, optional): Defaults to False. 23 | drop_path_rate (float, optional): Stochastic depth (drop path) rate. Defaults to 0.1.
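Example (a minimal usage sketch mirroring the settings used in main_CST.py; the 31-band input, batch size 1 and 32x32 patch size are illustrative assumptions, and dim is chosen so that dim // 2 divides evenly across the attention heads):
    >>> net = CST(inp_channels=31, dim=180, depths=[4, 4, 4, 4], num_heads=[6, 6, 6, 6], mlp_ratio=2, scale=4)
    >>> ms = torch.randn(1, 31, 32, 32)      # low-resolution hyperspectral patch
    >>> lms = torch.randn(1, 31, 128, 128)   # its bicubic-upsampled counterpart
    >>> sr = net(ms, lms)                    # super-resolved output of shape (1, 31, 128, 128)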
24 | """ 25 | 26 | def __init__(self, 27 | inp_channels=31, 28 | dim=90, 29 | depths=[6, 6, 6, 6, 6, 6], 30 | num_heads=[6, 6, 6, 6, 6, 6], 31 | mlp_ratio=2, 32 | qkv_bias=True, qk_scale=None, 33 | bias=False, 34 | drop_path_rate=0.1, 35 | scale=4 36 | ): 37 | super(CST, self).__init__() 38 | 39 | self.conv_first = nn.Conv2d(inp_channels, dim, 3, 1, 1) # shallow featrure extraction 40 | self.num_layers = depths 41 | self.layers = nn.ModuleList() 42 | print("network depth:", len(self.num_layers)) 43 | 44 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 45 | for i_layer in range(len(self.num_layers)): 46 | layer = Cstage(dim=dim, 47 | depth=depths[i_layer], 48 | num_head=num_heads[i_layer], 49 | mlp_ratio=mlp_ratio, 50 | qkv_bias=qkv_bias, qk_scale=qk_scale, 51 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 52 | bias=bias) 53 | self.layers.append(layer) 54 | 55 | # self.conv_delasta = nn.Conv2d(dim, inp_channels, 3, 1, 1) # reconstruction from features 56 | self.skip_conv = default_conv(inp_channels, dim, 3) 57 | self.upsample = Upsampler(default_conv, scale, dim) 58 | self.tail = default_conv(dim, inp_channels, 3) 59 | self.conv = default_conv(dim,dim,3) 60 | 61 | def forward(self, inp_img, lms): 62 | f1 = self.conv_first(inp_img) 63 | # ff = f1.detach().cpu().numpy() 64 | # outputfile = "f1.mat" 65 | # sio.savemat(outputfile, {'features':ff}) 66 | # print("save successfully") 67 | 68 | x = f1 69 | for i in range(len(self.num_layers)): 70 | x = self.layers[i](x) 71 | x = self.conv(x + f1) 72 | # x = self.conv_delasta(x) + inp_img 73 | x = self.upsample(x) 74 | x = x + self.skip_conv(lms) 75 | x = self.tail(x) 76 | return x 77 | 78 | 79 | class Cstage(nn.Module): 80 | def __init__(self, 81 | dim=90, 82 | split_size=(2,16), 83 | depth=6, 84 | num_head=6, 85 | mlp_ratio=2, 86 | qkv_bias=True, qk_scale=None, 87 | drop_path=0.1, 88 | bias=False): 89 | super(Cstage, self).__init__() 90 | self.layers1 = nn.ModuleList() 91 | self.layers2 = ResAttentionBlock(default_conv, dim, 1, res_scale=0.1) 92 | self.depth = depth 93 | for i_layer in range(depth): 94 | self.layers1.append(CSMA(dim=dim, 95 | input_resolution=(32, 32), 96 | num_heads=num_head, 97 | drop_path=drop_path[i_layer], 98 | split_size=split_size, 99 | shift_size=[0,0] if (i_layer % 2 == 0) else [split_size[0]//2, split_size[1]//2], 100 | mlp_ratio=mlp_ratio, 101 | attn_drop=0, 102 | qkv_bias=qkv_bias, qk_scale=qk_scale, bias=bias)) 103 | self.conv = nn.Conv2d(dim, dim, 1) 104 | 105 | def forward(self, x): 106 | x1 = x 107 | for i in range(self.depth): 108 | x1 = self.layers1[i](x1) 109 | x2 = self.layers2(x) 110 | out = self.conv(x1) + x2 111 | out = x + out 112 | return out 113 | 114 | 115 | class CSE(nn.Module): 116 | """global spectral attention (CSE) 117 | Args: 118 | dim (int): Number of input channels. 
119 | num_heads (int): Number of attention heads 120 | bias (bool): If True, add a learnable bias to projection 121 | """ 122 | 123 | def __init__(self, dim, num_heads, bias, k=0.5, sr_ratio=2): 124 | super(CSE, self).__init__() 125 | self.num_heads = num_heads 126 | self.k = int(k * dim) 127 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) 128 | # self.qkv = nn.Conv2d(dim, k*3, kernel_size=1, bias=bias) 129 | self.sr_ratio = sr_ratio 130 | self.v = nn.Conv2d(dim, self.k, kernel_size=1, bias=bias) 131 | self.qk = BSConvU(dim, 2 * self.k, kernel_size=sr_ratio, stride=sr_ratio, padding=0) 132 | self.project_out = nn.Conv2d(self.k, dim, kernel_size=1, bias=bias) 133 | self.norm = nn.LayerNorm(dim) 134 | 135 | def forward(self, x): 136 | b, c, h, w = x.shape 137 | qk = self.qk(x) 138 | q, k = qk.chunk(2, dim=1) # b self.k h/s w/s 139 | v = self.v(x) # b k h w 140 | q = q.reshape(b, self.num_heads, self.k // self.num_heads, -1) 141 | k = k.reshape(b, self.num_heads, self.k // self.num_heads, -1) 142 | v = v.reshape(b, self.num_heads, self.k // self.num_heads, -1) # b k h w 143 | 144 | q = torch.nn.functional.normalize(q, dim=-1) 145 | k = torch.nn.functional.normalize(k, dim=-1) 146 | attn = (q @ k.transpose(-2, -1)) * self.temperature 147 | attn = attn.softmax(dim=-1) 148 | 149 | out = (attn @ v) 150 | 151 | out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 152 | out = self.project_out(out) 153 | return out 154 | 155 | def flops(self, patchresolution): 156 | flops = 0 157 | H, W, C = patchresolution 158 | flops += H * C * W * C 159 | flops += C * C * H * W 160 | return flops 161 | 162 | 163 | class FeedForward(nn.Module): 164 | def __init__(self, dim, ffn_expansion_factor=2.66, bias=False): 165 | super(FeedForward, self).__init__() 166 | 167 | hidden_features = int(dim*ffn_expansion_factor) 168 | 169 | self.bsconv = BSConvU(dim, hidden_features*2, kernel_size=3, stride=1, padding=1, bias=bias) 170 | 171 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) 172 | 173 | def forward(self, x): 174 | x1, x2 = self.bsconv(x).chunk(2, dim=1) 175 | x = F.gelu(x1) * x2 176 | x = self.project_out(x) 177 | return x 178 | 179 | class BSConvU(torch.nn.Module): 180 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, 181 | dilation=1, bias=True, padding_mode="zeros", with_ln=False, bn_kwargs=None): 182 | super().__init__() 183 | self.with_ln = with_ln 184 | # check arguments 185 | if bn_kwargs is None: 186 | bn_kwargs = {} 187 | 188 | # pointwise 189 | self.pw = torch.nn.Conv2d( 190 | in_channels=in_channels, 191 | out_channels=out_channels, 192 | kernel_size=1, 193 | stride=1, 194 | padding=0, 195 | dilation=1, 196 | groups=1, 197 | bias=False, 198 | ) 199 | 200 | # depthwise 201 | self.dw = torch.nn.Conv2d( 202 | in_channels=out_channels, 203 | out_channels=out_channels, 204 | kernel_size=kernel_size, 205 | stride=stride, 206 | padding=padding, 207 | dilation=dilation, 208 | groups=out_channels, 209 | bias=bias, 210 | padding_mode=padding_mode, 211 | ) 212 | 213 | def forward(self, fea): 214 | fea = self.pw(fea) 215 | fea = self.dw(fea) 216 | return fea 217 | 218 | class CSMA(nn.Module): 219 | def __init__(self, dim, input_resolution=[32,32], num_heads=6, drop_path=0.0, split_size=[7, 7], shift_size=[0,0], 220 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., act_layer=nn.GELU, bias=False): 221 | super(CSMA, self).__init__() 222 | self.dim = dim 223 | self.input_resolution = 
input_resolution 224 | self.num_heads = num_heads 225 | self.mlp_ratio = mlp_ratio 226 | 227 | self.norm1 = nn.LayerNorm(dim) 228 | self.norm2 = nn.LayerNorm(dim) 229 | 230 | 231 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 232 | self.ffn = FeedForward(dim) 233 | 234 | self.attns = CSA( 235 | dim, 236 | input_resolution=input_resolution, 237 | num_heads=num_heads, 238 | split_size=split_size, 239 | shift_size=shift_size, 240 | qkv_bias=qkv_bias, 241 | attn_drop=attn_drop, 242 | proj_drop=drop) 243 | self.spectral_attn = CSE(dim, num_heads, bias) 244 | 245 | def forward(self, x): 246 | B, C, H, W = x.shape 247 | x = x.flatten(2).transpose(1, 2) 248 | shortcut = x 249 | x = self.norm1(x) 250 | x = self.attns(x, (H,W)) 251 | 252 | x = x.view(B, H * W, C) 253 | x = x.transpose(1, 2).view(B, C, H, W) 254 | x = self.spectral_attn(x) # global spectral attention 255 | 256 | x = x.flatten(2).transpose(1, 2) 257 | # FFN 258 | x = shortcut + self.drop_path(x) 259 | x = x + self.drop_path(self.ffn(self.norm2(x).transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2)) 260 | 261 | x = x.transpose(1, 2).view(B, C, H, W) 262 | return x 263 | 264 | -------------------------------------------------------------------------------- /network/csa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.checkpoint as checkpoint 4 | 5 | from timm.models.layers import DropPath, trunc_normal_ 6 | from einops.layers.torch import Rearrange 7 | from einops import rearrange 8 | 9 | import math 10 | import numpy as np 11 | 12 | 13 | class CSA(nn.Module): 14 | """ Regular Cross Aggregation Transformer Block. 15 | Args: 16 | dim (int): Number of input channels. 17 | reso (int): Input resolution. 18 | num_heads (int): Number of attention heads. 19 | split_size (tuple(int)): Height and Width of the regular rectangle window (regular-Rwin). 20 | shift_size (tuple(int)): Shift size for regular-Rwin. 21 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 22 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 23 | qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set. 24 | drop (float): Dropout rate. Default: 0.0 25 | attn_drop (float): Attention dropout rate. Default: 0.0 26 | drop_path (float): Stochastic depth rate. Default: 0.0 27 | act_layer (nn.Module): Activation layer. Default: nn.GELU 28 | norm_layer (nn.Module): Normalization layer. 
Default: nn.LayerNorm 29 | """ 30 | 31 | def __init__(self, dim, 32 | input_resolution, 33 | num_heads, 34 | split_size=[2,4], 35 | shift_size=[1,2], 36 | qkv_bias=True, 37 | qk_scale=None, 38 | attn_drop=0., 39 | proj_drop=0.,): 40 | super().__init__() 41 | self.dim = dim 42 | self.num_heads = num_heads 43 | self.input_resolution = input_resolution 44 | self.num_heads = num_heads 45 | self.split_size = split_size 46 | self.shift_size = shift_size 47 | 48 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 49 | 50 | assert 0 <= self.shift_size[0] < self.split_size[0], "shift_size must in 0-split_size0" 51 | assert 0 <= self.shift_size[1] < self.split_size[1], "shift_size must in 0-split_size1" 52 | 53 | self.proj = nn.Linear(dim, dim) 54 | self.attn_drop = nn.Dropout(attn_drop) 55 | 56 | self.attns = nn.ModuleList([ 57 | Attention_regular( 58 | dim, resolution=self.input_resolution, idx=i, 59 | split_size=split_size, num_heads=num_heads // 2, dim_out=dim // 2, 60 | qk_scale=qk_scale, attn_drop=attn_drop, position_bias=True) 61 | for i in range(2)]) 62 | 63 | self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) # DW Conv 64 | 65 | if self.shift_size[0] > 0 or self.shift_size[1] > 0: 66 | attn_mask = self.calculate_mask(self.input_resolution[0], self.input_resolution[1]) 67 | self.register_buffer("attn_mask_0", attn_mask[0]) 68 | self.register_buffer("attn_mask_1", attn_mask[1]) 69 | else: 70 | attn_mask = None 71 | 72 | self.register_buffer("attn_mask_0", None) 73 | self.register_buffer("attn_mask_1", None) 74 | 75 | def calculate_mask(self, H, W): 76 | # The implementation builds on Swin Transformer code https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py 77 | # calculate attention mask for Rwin 78 | img_mask_0 = torch.zeros([1, H, W, 1]) # 1 H W 1 idx=0 79 | img_mask_1 = torch.zeros([1, H, W, 1]) # 1 H W 1 idx=1 80 | h_slices_0 = (slice(0, -self.split_size[0]), 81 | slice(-self.split_size[0], -self.shift_size[0]), 82 | slice(-self.shift_size[0], None)) 83 | w_slices_0 = (slice(0, -self.split_size[1]), 84 | slice(-self.split_size[1], -self.shift_size[1]), 85 | slice(-self.shift_size[1], None)) 86 | 87 | h_slices_1 = (slice(0, -self.split_size[1]), 88 | slice(-self.split_size[1], -self.shift_size[1]), 89 | slice(-self.shift_size[1], None)) 90 | w_slices_1 = (slice(0, -self.split_size[0]), 91 | slice(-self.split_size[0], -self.shift_size[0]), 92 | slice(-self.shift_size[0], None)) 93 | cnt = 0 94 | for h in h_slices_0: 95 | for w in w_slices_0: 96 | img_mask_0[:, h, w, :] = cnt 97 | cnt += 1 98 | cnt = 0 99 | for h in h_slices_1: 100 | for w in w_slices_1: 101 | img_mask_1[:, h, w, :] = cnt 102 | cnt += 1 103 | 104 | # calculate mask for H-Shift 105 | img_mask_0 = img_mask_0.view(1, H // self.split_size[0], self.split_size[0], W // self.split_size[1], 106 | self.split_size[1], 1) 107 | img_mask_0 = img_mask_0.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.split_size[0], self.split_size[1], 108 | 1) # nW, sw[0], sw[1], 1 109 | mask_windows_0 = img_mask_0.view(-1, self.split_size[0] * self.split_size[1]) 110 | attn_mask_0 = mask_windows_0.unsqueeze(1) - mask_windows_0.unsqueeze(2) 111 | attn_mask_0 = attn_mask_0.masked_fill(attn_mask_0 != 0, float(-100.0)).masked_fill(attn_mask_0 == 0, float(0.0)) 112 | 113 | # calculate mask for V-Shift 114 | img_mask_1 = img_mask_1.view(1, H // self.split_size[1], self.split_size[1], W // self.split_size[0], 115 | self.split_size[0], 1) 116 | img_mask_1 = img_mask_1.permute(0, 1, 3, 2, 4, 
5).contiguous().view(-1, self.split_size[0], self.split_size[1], 117 | 1) # nW, sw[1], sw[0], 1 118 | mask_windows_1 = img_mask_1.view(-1, self.split_size[1] * self.split_size[0]) 119 | attn_mask_1 = mask_windows_1.unsqueeze(1) - mask_windows_1.unsqueeze(2) 120 | attn_mask_1 = attn_mask_1.masked_fill(attn_mask_1 != 0, float(-100.0)).masked_fill(attn_mask_1 == 0, float(0.0)) 121 | 122 | return attn_mask_0, attn_mask_1 123 | 124 | def forward(self, x, x_size): 125 | """ 126 | Input: x: (B, H*W, C), x_size: (H, W) 127 | Output: x: (B, H*W, C) 128 | """ 129 | 130 | H, W = x_size 131 | B, L, C = x.shape 132 | assert L == H * W, "flatten img_tokens has wrong size" 133 | 134 | qkv = self.qkv(x).reshape(B, -1, 3, C).permute(2, 0, 1, 3) # 3, B, HW, C 135 | # v without partition 136 | v = qkv[2].transpose(-2, -1).contiguous().view(B, C, H, W) 137 | 138 | if self.shift_size[0] > 0 or self.shift_size[1] > 0: 139 | qkv = qkv.view(3, B, H, W, C) 140 | # H-Shift 141 | qkv_0 = torch.roll(qkv[:, :, :, :, :C // 2], shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(2, 3)) 142 | qkv_0 = qkv_0.view(3, B, L, C // 2) 143 | # V-Shift 144 | qkv_1 = torch.roll(qkv[:, :, :, :, C // 2:], shifts=(-self.shift_size[1], -self.shift_size[0]), dims=(2, 3)) 145 | qkv_1 = qkv_1.view(3, B, L, C // 2) 146 | 147 | if self.input_resolution[0] != H or self.input_resolution[1] != W: 148 | mask_tmp = self.calculate_mask(H, W) 149 | # H-Rwin 150 | x1_shift = self.attns[0](qkv_0, H, W, mask=mask_tmp[0].to(x.device)) 151 | # V-Rwin 152 | x2_shift = self.attns[1](qkv_1, H, W, mask=mask_tmp[1].to(x.device)) 153 | 154 | else: 155 | # H-Rwin 156 | x1_shift = self.attns[0](qkv_0, H, W, mask=self.attn_mask_0) 157 | # V-Rwin 158 | x2_shift = self.attns[1](qkv_1, H, W, mask=self.attn_mask_1) 159 | 160 | x1 = torch.roll(x1_shift, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2)) 161 | x2 = torch.roll(x2_shift, shifts=(self.shift_size[1], self.shift_size[0]), dims=(1, 2)) 162 | x1 = x1.view(B, L, C // 2).contiguous() 163 | x2 = x2.view(B, L, C // 2).contiguous() 164 | # Concat 165 | attened_x = torch.cat([x1, x2], dim=2) 166 | else: 167 | # V-Rwin 168 | x1 = self.attns[0](qkv[:, :, :, :C // 2], H, W).view(B, L, C // 2).contiguous() 169 | # H-Rwin 170 | x2 = self.attns[1](qkv[:, :, :, C // 2:], H, W).view(B, L, C // 2).contiguous() 171 | # Concat 172 | attened_x = torch.cat([x1, x2], dim=2) 173 | 174 | # Locality Complementary Module 175 | lcm = self.get_v(v) 176 | lcm = lcm.permute(0, 2, 3, 1).contiguous().view(B, L, C) 177 | 178 | attened_x = attened_x + lcm 179 | 180 | attened_x = self.proj(attened_x) 181 | x = x + attened_x 182 | 183 | return x 184 | 185 | 186 | class Attention_regular(nn.Module): 187 | """ Regular Rectangle-Window (regular-Rwin) self-attention with dynamic relative position bias. 188 | It supports both of shifted and non-shifted window. 189 | Args: 190 | dim (int): Number of input channels. 191 | resolution (int): Input resolution. 192 | idx (int): The identix of V-Rwin and H-Rwin, 0 is H-Rwin, 1 is Vs-Rwin. (different order from Attention_axial) 193 | split_size (tuple(int)): Height and Width of the regular rectangle window (regular-Rwin). 194 | dim_out (int | None): The dimension of the attention output. Default: None 195 | num_heads (int): Number of attention heads. Default: 6 196 | attn_drop (float): Dropout ratio of attention weight. Default: 0.0 197 | proj_drop (float): Dropout ratio of output. 
Default: 0.0 198 | qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set 199 | position_bias (bool): The dynamic relative position bias. Default: True 200 | """ 201 | def __init__(self, 202 | dim, 203 | resolution, 204 | idx, 205 | split_size=[2,4], 206 | dim_out=None, 207 | num_heads=6, 208 | attn_drop=0., 209 | qk_scale=None, 210 | position_bias=True): 211 | super().__init__() 212 | self.dim = dim 213 | self.dim_out = dim_out or dim 214 | self.resolution = resolution 215 | self.split_size = split_size 216 | self.num_heads = num_heads 217 | self.idx = idx 218 | self.position_bias = position_bias 219 | 220 | head_dim = dim // num_heads 221 | self.scale = qk_scale or head_dim ** -0.5 222 | if idx == -1: 223 | H_sp, W_sp = self.resolution, self.resolution 224 | elif idx == 0: 225 | H_sp, W_sp = self.split_size[0], self.split_size[1] 226 | elif idx == 1: 227 | W_sp, H_sp = self.split_size[0], self.split_size[1] 228 | else: 229 | print ("ERROR MODE", idx) 230 | exit(0) 231 | self.H_sp = H_sp 232 | self.W_sp = W_sp 233 | 234 | if self.position_bias: 235 | self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False) 236 | # generate mother-set 237 | position_bias_h = torch.arange(1 - self.H_sp, self.H_sp) 238 | position_bias_w = torch.arange(1 - self.W_sp, self.W_sp) 239 | biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) 240 | biases = biases.flatten(1).transpose(0, 1).contiguous().float() 241 | self.register_buffer('rpe_biases', biases) 242 | 243 | # get pair-wise relative position index for each token inside the window 244 | coords_h = torch.arange(self.H_sp) 245 | coords_w = torch.arange(self.W_sp) 246 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) 247 | coords_flatten = torch.flatten(coords, 1) 248 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] 249 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() 250 | relative_coords[:, :, 0] += self.H_sp - 1 251 | relative_coords[:, :, 1] += self.W_sp - 1 252 | relative_coords[:, :, 0] *= 2 * self.W_sp - 1 253 | relative_position_index = relative_coords.sum(-1) 254 | self.register_buffer('relative_position_index', relative_position_index) 255 | 256 | self.attn_drop = nn.Dropout(attn_drop) 257 | self.pool = nn.AdaptiveAvgPool2d((self.H_sp, self.W_sp)) 258 | self.pool2 = nn.AdaptiveMaxPool2d((self.H_sp, self.W_sp)) 259 | 260 | def im2win(self, x, H, W): 261 | B, N, C = x.shape 262 | x = x.transpose(-2,-1).contiguous().view(B, C, H, W) 263 | x = img2windows(x, self.H_sp, self.W_sp) 264 | x = x.reshape(-1, self.H_sp* self.W_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous() 265 | return x # -1, heads, H_s* W_s, C // self.num_heads 266 | 267 | def forward(self, qkv, H, W, mask=None): 268 | """ 269 | Input: qkv: (B, 3*L, C), H, W, mask: (B, N, N), N is the window size 270 | Output: x (B, H, W, C) 271 | """ 272 | q,k,v = qkv[0], qkv[1], qkv[2] #B L C 273 | 274 | B, L, C = q.shape 275 | assert L == H * W, "flatten img_tokens has wrong size" 276 | 277 | # partition the q,k,v, image to window 278 | 279 | q1 = q.transpose(-2, -1).view(B, C, H, W) # B, C, H_s, W_s 280 | q1 = self.pool(q1[:, :C//2, :, :]) 281 | q2 = q.transpose(-2, -1).view(B, C, H, W) 282 | q2 = self.pool2(q2[:, C//2:, :, :]) 283 | q = torch.cat([q1,q2],dim=1) 284 | q = q.reshape(B, self.num_heads, C//self.num_heads, self.H_sp, self.W_sp).flatten(3).transpose(-2, -1) 285 | q = q.repeat(H*W //(self.H_sp*self.W_sp),1,1,1) # -1, heads, H_s* W_s, C // 
286 |         k = self.im2win(k, H, W) # -1, heads, H_sp*W_sp, C // self.num_heads
287 |         v = self.im2win(v, H, W)
288 | 
289 |         q = q * self.scale
290 |         attn = (q @ k.transpose(-2, -1)) # B head N C @ B head C N --> B head N N
291 | 
292 |         # calculate the dynamic relative position bias
293 |         if self.position_bias:
294 |             pos = self.pos(self.rpe_biases)
295 |             # select position bias
296 |             relative_position_bias = pos[self.relative_position_index.view(-1)].view(
297 |                 self.H_sp * self.W_sp, self.H_sp * self.W_sp, -1)
298 |             relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
299 |             attn = attn + relative_position_bias.unsqueeze(0)
300 | 
301 |         N = attn.shape[3]
302 | 
303 |         # use mask for shifted windows
304 |         if mask is not None:
305 |             nW = mask.shape[0]
306 |             attn = attn.view(B, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
307 |             attn = attn.view(-1, self.num_heads, N, N)
308 |         attn = nn.functional.softmax(attn, dim=-1, dtype=attn.dtype)
309 |         attn = self.attn_drop(attn)
310 | 
311 |         x = (attn @ v)
312 |         x = x.transpose(1, 2).reshape(-1, self.H_sp * self.W_sp, C) # B head N N @ B head N C
313 | 
314 |         # merge the windows back to the image
315 |         x = windows2img(x, self.H_sp, self.W_sp, H, W) # B H' W' C
316 | 
317 |         return x
318 | 
319 | 
320 | class DynamicPosBias(nn.Module):
321 |     # The implementation builds on the CrossFormer code https://github.com/cheerss/CrossFormer/blob/main/models/crossformer.py
322 |     """ Dynamic Relative Position Bias.
323 |     Args:
324 |         dim (int): Number of input channels.
325 |         num_heads (int): Number of attention heads.
326 |         residual (bool): If True, use residual connections between the position-bias MLP stages.
327 |     """
328 |     def __init__(self, dim, num_heads, residual):
329 |         super().__init__()
330 |         self.residual = residual
331 |         self.num_heads = num_heads
332 |         self.pos_dim = dim // 4
333 |         self.pos_proj = nn.Linear(2, self.pos_dim)
334 |         self.pos1 = nn.Sequential(
335 |             nn.LayerNorm(self.pos_dim),
336 |             nn.ReLU(inplace=True),
337 |             nn.Linear(self.pos_dim, self.pos_dim),
338 |         )
339 |         self.pos2 = nn.Sequential(
340 |             nn.LayerNorm(self.pos_dim),
341 |             nn.ReLU(inplace=True),
342 |             nn.Linear(self.pos_dim, self.pos_dim)
343 |         )
344 |         self.pos3 = nn.Sequential(
345 |             nn.LayerNorm(self.pos_dim),
346 |             nn.ReLU(inplace=True),
347 |             nn.Linear(self.pos_dim, self.num_heads)
348 |         )
349 | 
350 |     def forward(self, biases):
351 |         if self.residual:
352 |             pos = self.pos_proj(biases) # 2Gh-1 * 2Gw-1, pos_dim
353 |             pos = pos + self.pos1(pos)
354 |             pos = pos + self.pos2(pos)
355 |             pos = self.pos3(pos)
356 |         else:
357 |             pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
358 |         return pos
359 | 
360 | 
361 | def img2windows(img, H_sp, W_sp):
362 |     """
363 |     Input: Image (B, C, H, W)
364 |     Output: Window Partition (B', N, C)
365 |     """
366 |     B, C, H, W = img.shape
367 |     img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp)
368 |     img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp * W_sp, C)
369 |     return img_perm
370 | 
371 | 
372 | def windows2img(img_splits_hw, H_sp, W_sp, H, W):
373 |     """
374 |     Input: Window Partition (B', N, C)
375 |     Output: Image (B, H, W, C)
376 |     """
377 |     B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp))
378 | 
379 |     img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1)
380 |     img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
381 |     return img
--------------------------------------------------------------------------------
/test_demo.sh:
--------------------------------------------------------------------------------
1 | python main_CST.py test --model_title "CST" --dataset "Chikusei" --n_scale 4 --gpus "0"
2 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import scipy.io as sio
2 | import numpy as np
3 | import torch
4 | import cv2
5 | from torch.utils.data import DataLoader
6 | 
7 | def data_augmentation(label, mode=0):
8 |     if mode == 0:
9 |         # original
10 |         return label
11 |     elif mode == 1:
12 |         # flip up and down
13 |         return np.flipud(label)
14 |     elif mode == 2:
15 |         # rotate counterclockwise 90 degrees
16 |         return np.rot90(label)
17 |     elif mode == 3:
18 |         # rotate 90 degrees and flip up and down
19 |         return np.flipud(np.rot90(label))
20 |     elif mode == 4:
21 |         # rotate 180 degrees
22 |         return np.rot90(label, k=2)
23 |     elif mode == 5:
24 |         # rotate 180 degrees and flip
25 |         return np.flipud(np.rot90(label, k=2))
26 |     elif mode == 6:
27 |         # rotate 270 degrees
28 |         return np.rot90(label, k=3)
29 |     elif mode == 7:
30 |         # rotate 270 degrees and flip
31 |         return np.flipud(np.rot90(label, k=3))
32 | 
33 | 
34 | # rescale every channel to between 0 and 1
35 | def channel_scale(img):
36 |     eps = 1e-5
37 |     max_list = np.max((np.max(img, axis=0)), axis=0)
38 |     min_list = np.min((np.min(img, axis=0)), axis=0)
39 |     output = (img - min_list) / (max_list - min_list + eps)
40 |     return output
41 | 
42 | 
43 | # upsample before feeding into the network
44 | def upsample(img, ratio):
45 |     [h, w, _] = img.shape
46 |     return cv2.resize(img, (ratio*w, ratio*h), interpolation=cv2.INTER_CUBIC)  # cv2.resize expects dsize as (width, height)
47 | 
48 | 
49 | def bicubic_downsample(img, ratio):
50 |     [h, w, _] = img.shape
51 |     new_h, new_w = int(ratio * h), int(ratio * w)
52 |     return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_CUBIC)  # dsize is (width, height)
53 | 
54 | 
55 | def wald_downsample(data, ratio):
56 |     [h, w, c] = data.shape
57 |     out = []
58 |     for i in range(c):
59 |         dst = cv2.GaussianBlur(data[:, :, i], (7, 7), 0)
60 |         dst = dst[0:h:ratio, 0:w:ratio, np.newaxis]
61 |         out.append(dst)
62 |     out = np.concatenate(out, axis=2)
63 |     return out
64 | 
65 | 
66 | def save_result(result_dir, out):
67 |     out = out.numpy().transpose((0, 2, 3, 1))
68 |     sio.savemat(result_dir, {'output': out})
69 | 
70 | 
71 | def sam_loss(y, ref):  # mean cosine similarity between y and ref along the spectral dimension
72 |     (b, ch, h, w) = y.size()
73 |     tmp1 = y.view(b, ch, h * w).transpose(1, 2)
74 |     tmp2 = ref.view(b, ch, h * w)
75 |     sam = torch.bmm(tmp1, tmp2)
76 |     idx = torch.arange(0, h * w, out=torch.LongTensor())
77 |     sam = sam[:, idx, idx].view(b, h, w)
78 |     norm1 = torch.norm(y, 2, 1)
79 |     norm2 = torch.norm(ref, 2, 1)
80 |     sam = torch.div(sam, (norm1 * norm2))
81 |     sam = torch.sum(sam) / (b * h * w)
82 |     return sam
83 | 
84 | 
85 | def extract_RGB(y):
86 |     # take 4-2-1 band (R-G-B) for WV-3
87 |     R = torch.unsqueeze(torch.mean(y[:, 4:8, :, :], 1), 1)
88 |     G = torch.unsqueeze(torch.mean(y[:, 2:4, :, :], 1), 1)
89 |     B = torch.unsqueeze(torch.mean(y[:, 0:2, :, :], 1), 1)
90 |     y_RGB = torch.cat((R, G, B), 1)
91 |     return y_RGB
92 | 
93 | 
94 | def extract_edge(data):
95 |     N = data.shape[0]
96 |     out = np.zeros_like(data)
97 |     for i in range(N):
98 |         if len(data.shape) == 3:
99 |             out[i, :, :] = data[i, :, :] - cv2.boxFilter(data[i, :, :], -1, (5, 5))
100 |         else:
101 |             out[i, :, :, :] = data[i, :, :, :] - cv2.boxFilter(data[i, :, :, :], -1, (5, 5))
102 |     return out
103 | 
104 | 
105 | def normalize_batch(batch):
106 |     # normalize using the ImageNet mean and std
107 |     mean = torch.Tensor([0.485, 0.456, 0.406]).view(-1, 1, 1).cuda()
108 |     std = torch.Tensor([0.229, 0.224, 0.225]).view(-1, 1, 1).cuda()
109 |     return (batch - mean) / std
110 | 
111 | 
112 | def add_channel(rgb):
113 |     # expand a 3-band RGB (e.g., VGG input) back to 8 bands by repeating B, G, R to match the WV-3 band order used in extract_RGB
114 |     R = torch.unsqueeze(rgb[:, 0, :, :], 1)
115 |     G = torch.unsqueeze(rgb[:, 1, :, :], 1)
116 |     B = torch.unsqueeze(rgb[:, 2, :, :], 1)
117 |     all_channel = torch.cat((B, B, G, G, R, R, R, R), 1)
118 |     return all_channel
119 | 
120 | 
121 | # from LapSRN
122 | class L1_Charbonnier_loss(torch.nn.Module):
123 |     """L1 Charbonnier loss."""
124 |     def __init__(self):
125 |         super(L1_Charbonnier_loss, self).__init__()
126 |         self.eps = 1e-6
127 | 
128 |     def forward(self, X, Y):
129 |         diff = torch.add(X, -Y)
130 |         error = torch.sqrt(diff * diff + self.eps)
131 |         loss = torch.sum(error)
132 |         return loss
133 | 
134 | 
135 | 
--------------------------------------------------------------------------------
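A minimal sanity-check sketch (not a file in this repository): it re-declares the window-partition helpers and the Charbonnier loss shown above so it runs standalone, then verifies that img2windows and windows2img are inverse operations and evaluates the loss on dummy tensors. The batch size, band count, patch size, and 4x8 split size are illustrative assumptions only.
```
import torch


def img2windows(img, H_sp, W_sp):
    # (B, C, H, W) -> (B * nW, H_sp * W_sp, C) rectangle-window partition
    B, C, H, W = img.shape
    img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp)
    return img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp * W_sp, C)


def windows2img(img_splits_hw, H_sp, W_sp, H, W):
    # (B * nW, H_sp * W_sp, C) -> (B, H, W, C) window merge (inverse of img2windows)
    B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp))
    img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1)
    return img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)


class L1_Charbonnier_loss(torch.nn.Module):
    # differentiable L1 variant: sum of sqrt((X - Y)^2 + eps)
    def __init__(self):
        super().__init__()
        self.eps = 1e-6

    def forward(self, X, Y):
        diff = X - Y
        return torch.sum(torch.sqrt(diff * diff + self.eps))


if __name__ == "__main__":
    # Assumed sizes: batch 2, 8 bands, 16x16 patch, 4x8 rectangle windows.
    B, C, H, W = 2, 8, 16, 16
    H_sp, W_sp = 4, 8
    x = torch.randn(B, C, H, W)

    win = img2windows(x, H_sp, W_sp)
    assert win.shape == (B * (H // H_sp) * (W // W_sp), H_sp * W_sp, C)

    # Merging the windows recovers the image in (B, H, W, C) layout.
    back = windows2img(win, H_sp, W_sp, H, W)
    assert torch.allclose(back, x.permute(0, 2, 3, 1))

    # Charbonnier loss between a fake reconstruction and a reference patch.
    criterion = L1_Charbonnier_loss()
    print(criterion(torch.rand(B, C, H, W), torch.rand(B, C, H, W)).item())
```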