├── IEEE2018_rand.xlsx ├── README.md ├── args_parser.py ├── data ├── IEEE2018_SRF.xls └── __init__.py ├── data_loader.py ├── loss.py ├── main.py ├── metrics.py ├── models └── MCIFNet.py ├── test.py ├── train.py ├── utils.py └── validate.py /IEEE2018_rand.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chunyuzhu/MCIFNet/52982e7483d6cec5f3a6e09eeb740c1737a84525/IEEE2018_rand.xlsx -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | MCIFNet for hyperspectral and multispectral image fusion. 2 | This is the code of Chunyu Zhu et al., “Mamba Collaborative Implicit Neural Representation for Hyperspectral and Multispectral Remote Sensing Image Fusion,” IEEE Transactions on Geoscience and Remote Sensing (TGRS), 2025. 3 | If you use this code, please cite the following articles: 4 | 5 | 1. Mamba Collaborative Implicit Neural Representation for Hyperspectral and Multispectral Remote Sensing Image Fusion. IEEE Transactions on Geoscience and Remote Sensing[J] 6 | 2. A spatial-frequency dual-domain implicit guidance method for hyperspectral and multispectral remote sensing image fusion based on Kolmogorov–Arnold Network. Information Fusion[J] 7 | 3. QIS-GAN: A Lightweight Adversarial Network with Quadtree Implicit Sampling for Multispectral and Hyperspectral Image Fusion. IEEE Transactions on Geoscience and Remote Sensing[J] 8 | 4. An Implicit Transformer-based Fusion Method for Hyperspectral and Multispectral Remote Sensing Image. International Journal of Applied Earth Observation and Geoinformation[J] 9 | 5. An Adaptive Multi-perceptual Implicit Sampling for Hyperspectral and Multispectral Remote Sensing Image Fusion. International Journal of Applied Earth Observation and Geoinformation[J] 10 | 6. Hyperspectral and Multispectral Remote Sensing Image Fusion using SwinGAN with Joint Adaptive Spatial-Spectral Gradient Loss Function. International Journal of Digital Earth[J] 11 | 7. MGDIN: Detail Injection Network for HSI and MSI Fusion Based on Multiscale and Global Contextual Features. International Journal of Remote Sensing[J]
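To reproduce the IEEE2018 experiment, a typical invocation is sketched below (assumptions: IEEE2018.mat has been placed under ./data next to the provided IEEE2018_SRF.xls, and a ./checkpoints directory exists; the .mat cube itself is not shipped with this repository):

```
python main.py -arch MCIFNet -dataset IEEE2018 -root ./data --scale_ratio 4
```

The best checkpoint is written to ./checkpoints/IEEE2018_MCIFNet_4.pkl, i.e. the --model_path template from args_parser.py with 'dataset' and 'arch' substituted by main.py.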
12 | -------------------------------------------------------------------------------- /args_parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def args_parser(): 4 | parser = argparse.ArgumentParser() 5 | 6 | parser.add_argument('-arch', type=str, default='MCIFNet', 7 | choices=[ 8 | # these models are used for the comparison experiments 9 | 'SSRNET', 10 | 'TSFN', 11 | 'HyperKite', 12 | '_3DT_Net', 13 | 'PSRTnet', 14 | 'DCT', 15 | 'Fusformer', 16 | 17 | # the proposed method 18 | 'MCIFNet', 19 | ]) 20 | 21 | parser.add_argument('-root', type=str, default='./data') 22 | parser.add_argument('-dataset', type=str, default='IEEE2018', 23 | choices=['PaviaU', 'Chikusei','Pavia','IEEE2018','Botswana','realdata']) 24 | parser.add_argument('--scale_ratio', type=float, default=4) 25 | parser.add_argument('--n_bands', type=int, default=0) 26 | parser.add_argument('--n_select_bands', type=int, default=4) 27 | 28 | parser.add_argument('--model_path', type=str, 29 | default='./checkpoints/dataset_arch_4.pkl', 30 | help='path for the trained model checkpoint') 31 | parser.add_argument('--train_dir', type=str, default='./data/dataset/train', 32 | help='directory for resized images') 33 | parser.add_argument('--val_dir', type=str, default='./data/dataset/val', 34 | help='directory for resized images') 35 | 36 | # learning settings 37 | parser.add_argument('--n_epochs', type=int, default=10000, 38 | help='end epoch for training') 39 | # rsicd: 3e-4, ucm: 1e-4, 40 | parser.add_argument('--lr', type=float, default=1e-4) 41 | parser.add_argument('--image_size', type=int, default=128) 42 | 43 | args = parser.parse_args() 44 | return args 45 | -------------------------------------------------------------------------------- /data/IEEE2018_SRF.xls: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chunyuzhu/MCIFNet/52982e7483d6cec5f3a6e09eeb740c1737a84525/data/IEEE2018_SRF.xls -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Jul 24 14:01:27 2023 4 | 5 | @author: zxc 6 | """ 7 | import os 8 | os.environ['KMP_DUPLICATE_LIB_OK']='True' 9 | from os import listdir 10 | from torch.nn import functional as F 11 | import cv2 12 | import torch 13 | import numpy as np 14 | import os 15 | import random 16 | import scipy.io as scio 17 | import h5py 18 | import xlrd 19 | 20 | def matlab_style_gauss2D(shape=(5,5),sigma=2): 21 | m,n = [(ss-1.)/2.
for ss in shape] 22 | y,x = np.ogrid[-m:m+1,-n:n+1] 23 | h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) ) 24 | h[ h < np.finfo(h.dtype).eps*h.max() ] = 0 25 | sumh = h.sum() 26 | if sumh != 0: 27 | h /= sumh 28 | return h 29 | 30 | def get_spectral_response(xls_path): 31 | # xls_path = os.path.join(self.args.sp_root_path, data_name + '.xls') 32 | if not os.path.exists(xls_path): 33 | raise Exception("spectral response path does not exist") 34 | data = xlrd.open_workbook(xls_path) 35 | table = data.sheets()[0] 36 | num_cols = table.ncols 37 | cols_list = [np.array(table.col_values(i)).reshape(-1,1) for i in range(0,num_cols)] 38 | sp_data = np.concatenate(cols_list, axis=1) 39 | sp_data = sp_data / (sp_data.sum(axis=0)) #normalize the sepctral response 40 | return sp_data 41 | 42 | def downsamplePSF(img,sigma,stride): 43 | def matlab_style_gauss2D(shape=(5,5),sigma=2): 44 | m,n = [(ss-1.)/2. for ss in shape] 45 | y,x = np.ogrid[-m:m+1,-n:n+1] 46 | h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) ) 47 | h[ h < np.finfo(h.dtype).eps*h.max() ] = 0 48 | sumh = h.sum() 49 | if sumh != 0: 50 | h /= sumh 51 | return h 52 | # generate filter same with fspecial('gaussian') function 53 | h = matlab_style_gauss2D((stride,stride),sigma) 54 | if img.ndim == 3: 55 | img_w,img_h,img_c = img.shape 56 | elif img.ndim == 2: 57 | img_c = 1 58 | img_w,img_h = img.shape 59 | img = img.reshape((img_w,img_h,1)) 60 | from scipy import signal 61 | out_img = np.zeros((img_w//(stride), img_h//(stride), img_c)) 62 | for i in range(img_c): 63 | out = signal.convolve2d(img[:,:,i],h,'valid') 64 | out_img[:,:,i] = out[::stride,::stride] 65 | return out_img 66 | def generate_low_HSI( img, scale_factor): 67 | (h, w, c) = img.shape 68 | img_lr = downsamplePSF(img, sigma=2, stride=scale_factor) 69 | return img_lr 70 | 71 | def generate_MSI(img, sp_matrix): 72 | w,h,c = img.shape 73 | # msi_channels = sp_matrix.shape[1] 74 | if sp_matrix.shape[0] == c: 75 | img_msi = np.dot(img.reshape(w*h,c), sp_matrix).reshape(w,h,sp_matrix.shape[1]) 76 | else: 77 | raise Exception("The shape of sp matrix doesnot match the image") 78 | return img_msi 79 | 80 | # root = 'E:/AMIGNet_fuben/data' 81 | # scale_ratio = 4 82 | # size =128 83 | def build_datasets(root, dataset, size, n_select_bands, scale_ratio): 84 | # Imageh preprocessing, normalization for the pretrained resnet 85 | if dataset == 'PaviaU': 86 | img = scio.loadmat(root + '/' + 'PaviaU.mat')['paviaU']*1.0 87 | sp_matrix = get_spectral_response(root + '/' + 'PaviaU'+'_SRF.xls') 88 | elif dataset == 'Pavia': 89 | img = scio.loadmat(root + '/' + 'Pavia.mat')['pavia']*1.0 90 | sp_matrix = get_spectral_response(root + '/' + 'Pavia'+'_SRF.xls') 91 | elif dataset == 'Botswana': 92 | img = scio.loadmat(root + '/' + 'Botswana.mat')['Botswana']*1.0 93 | sp_matrix = get_spectral_response(root + '/' + 'Botswana'+'_SRF.xls') 94 | elif dataset == 'Xiongan': 95 | # img = scio.loadmat(root + '/' + 'Xiongan.mat')['xiongan']*1.0 96 | mat = h5py.File(root + '/' + 'Xiongan.mat') 97 | img = np.transpose(mat['Xiongan']) 98 | sp_matrix = get_spectral_response(root + '/' + 'Xiongan'+'_SRF.xls') 99 | elif dataset == 'IEEE2018': 100 | img = scio.loadmat(root + '/' + 'IEEE2018.mat')['IEEE2018']*1.0 101 | sp_matrix = get_spectral_response(root + '/' + 'IEEE2018'+'_SRF.xls') 102 | elif dataset == 'Chikusei': 103 | mat = h5py.File(root + '/' + 'Chikusei.mat') 104 | img = np.transpose(mat['Chikusei']) 105 | sp_matrix = get_spectral_response(root + '/' + 'Chikusei'+'_SRF.xls') 106 | 107 | 108 | print (img.shape) 109 | 
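# Added note: the reference cube is min-max normalised to [0, 255] below, before the LR-HSI / HR-MSI training and test pairs are simulated.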
img_max = np.max(img) 110 | img_min = np.min(img) 111 | img = 255*((img - img_min) / (img_max - img_min + 0.0)) 112 | 113 | # trim the edges so that both spatial sizes are divisible by scale_ratio 114 | # w_edge = img.shape[0]//scale_ratio*scale_ratio-img.shape[0] 115 | # h_edge = img.shape[1]//scale_ratio*scale_ratio-img.shape[1] 116 | # w_edge = -1 if w_edge==0 else w_edge 117 | # h_edge = -1 if h_edge==0 else h_edge 118 | # img = img[:w_edge, :h_edge, :] 119 | 120 | w_edge = img.shape[0]//scale_ratio*scale_ratio 121 | h_edge = img.shape[1]//scale_ratio*scale_ratio 122 | 123 | img = img[:w_edge, :h_edge, :] 124 | 125 | # cropping area 126 | width, height, n_bands = img.shape 127 | w_str = width -192 128 | h_str = 0 129 | w_end = w_str + 192 130 | h_end = h_str + 192 131 | img_copy = img.copy() 132 | 133 | # test sample 134 | # gap_bands = n_bands / (n_select_bands-1.0) 135 | test_ref = img_copy[w_str:w_end, h_str:h_end, :].copy() 136 | test_hr = generate_MSI(test_ref, sp_matrix) 137 | test_lr = generate_low_HSI(test_ref, scale_ratio) 138 | 139 | 140 | img[w_str:w_end,h_str:h_end,:] = 0 141 | train_ref = img 142 | train_hr = generate_MSI(train_ref,sp_matrix) 143 | train_lr = generate_low_HSI(train_ref, scale_ratio) 144 | 145 | 146 | 147 | train_ref = torch.from_numpy(train_ref).permute(2,0,1).unsqueeze(dim=0) 148 | train_lr = torch.from_numpy(train_lr).permute(2,0,1).unsqueeze(dim=0) 149 | train_hr = torch.from_numpy(train_hr).permute(2,0,1).unsqueeze(dim=0) 150 | test_ref = torch.from_numpy(test_ref).permute(2,0,1).unsqueeze(dim=0) 151 | test_lr = torch.from_numpy(test_lr).permute(2,0,1).unsqueeze(dim=0) 152 | test_hr = torch.from_numpy(test_hr).permute(2,0,1).unsqueeze(dim=0) 153 | 154 | return [train_ref, train_lr, train_hr], [test_ref, test_lr, test_hr] 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Jul 26 20:58:02 2023 4 | 5 | @author: zxc 6 | """ 7 | 8 | 9 | import torch 10 | from torch import nn 11 | import torch.nn.functional as F 12 | # from torchvision.models.vgg import vgg16 13 | import numpy as np 14 | import cv2 15 | # from scipy.fftpack import fft2 16 | def set_grad(network,requires_grad): 17 | for param in network.parameters(): 18 | param.requires_grad = requires_grad 19 | 20 | def gaussian_pyramid(image_tensor, n_levels=3): 21 | # create an empty list that holds every level of the pyramid 22 | pyramid = [] 23 | 24 | # the input image itself is the first level 25 | pyramid.append(image_tensor) 26 | 27 | # for every remaining level 28 | for _ in range(n_levels - 1): 29 | # compute the pyramid image of the current level 30 | image_tensor = F.avg_pool2d(image_tensor, kernel_size=2, stride=2) 31 | 32 | pyramid.append(image_tensor) 33 | 34 | return pyramid 35 | 36 | class GeneratorLoss(nn.Module): 37 | def __init__(self): 38 | super(GeneratorLoss, self).__init__() 39 | 40 | self.tv_loss = TVLoss() 41 | self.l1loss = nn.L1Loss() 42 | 43 | def forward(self, img_out_labels, out_1, out_2, out_3, target_images): 44 | # Adversarial Loss 45 | adversarial_loss = torch.mean(1 - img_out_labels)#+0.3*torch.mean(1 - edge_out_labels)+0.3*torch.mean(1 - spec_out_labels) 46 | 47 | target_gaussian = gaussian_pyramid(target_images) 48 | 49 | 50 | image_loss = self.l1loss(out_1, target_gaussian[0]) #+ self.l1loss(out_2, target_gaussian[1]) + self.l1loss(out_3, target_gaussian[2]) 51 | 52 | # tv_loss = self.tv_loss(out_images) 53 | # return image_loss + adversarial_loss + 2e-8 * tv_loss 54 | # return image_loss + 2e-8 * tv_loss 55 | return image_loss + adversarial_loss #+ 2e-8 * tv_loss 56 | 57 | class TVLoss(nn.Module): 58 | def __init__(self, tv_loss_weight=1): 59 | super(TVLoss, self).__init__() 60 | self.tv_loss_weight = tv_loss_weight 61 | 62 | def forward(self, x): 63 | batch_size = x.size()[0] 64 | h_x = x.size()[2] 65 | w_x = x.size()[3] 66 | count_h = self.tensor_size(x[:, :, 1:, :]) 67 | count_w = self.tensor_size(x[:, :, :, 1:]) 68 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum() 69 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum() 70 | return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size 71 | 72 | @staticmethod 73 | def tensor_size(t): 74 | return t.size()[1] * t.size()[2] * t.size()[3] 75 | 76 | 77 | if __name__ == "__main__": 78 | g_loss = GeneratorLoss() 79 | A = torch.rand(1,144,128,128) 80 | B = torch.rand(1,144,128,128) 81 | print(g_loss(torch.rand(1), A, A, A, B)) 82 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.backends.cudnn as cudnn 4 | import torch.optim 5 | from torch import nn 6 | 7 | from models.MCIFNet import MCIFNet 8 | 9 | from data_loader import build_datasets 10 | 11 | import pandas as pd 12 | from utils import * 13 | # from data_loader import build_datasets 14 | from validate import validate 15 | from train import train 16 | import pdb 17 | import args_parser 18 | from torch.nn import functional as F 19 | 20 | 21 | args = args_parser.args_parser() 22 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 23 | 24 | print (args) 25 | 26 | df = pd.read_excel(args.dataset+'_rand.xlsx', sheet_name='Sheet1') 27 | h_rand = df.iloc[:,0].tolist() 28 | w_rand = df.iloc[:,1].tolist() 29 | def main(): 30 | # Custom dataloader 31 | train_list, test_list = build_datasets(args.root, 32 | args.dataset, 33 | args.image_size, 34 | args.n_select_bands, 35 | args.scale_ratio) 36 | if args.dataset == 'PaviaU': 37 | args.n_bands = 103 38 | 39 | elif args.dataset == 'Pavia': 40 | args.n_bands = 102 41 | elif args.dataset == 'Chikusei': 42 | args.n_bands = 128 43 | elif args.dataset == 'IEEE2018': 44 | args.n_bands = 48 45 | elif args.dataset == 'Botswana': 46 | args.n_bands = 145 47 | # Build the models 48 | if args.arch == 'SSRNET' or args.arch == 'SpatRNET' or args.arch == 'SpecRNET': 49 | model = SSRNET(args.arch, 50 | args.scale_ratio, 51 | args.n_select_bands, 52 | args.n_bands).cuda() 53 | elif args.arch == 'DCT': 54 | model = DCT(n_colors=args.n_bands, upscale_factor=args.scale_ratio, n_feats=180).cuda() 55 | elif args.arch == 'PSRTnet': 56 | model = PSRTnet(args.scale_ratio, 57 | args.n_select_bands, 58 | args.n_bands, 59 | args.image_size).cuda() 60 | elif args.arch == 'HyperKite': 61 | model = HyperKite(args.scale_ratio, 62 | args.n_select_bands, 63 | args.n_bands).cuda() 64 | elif args.arch == 'TSFN': 65 | model = TSFN(args.scale_ratio, 66 | args.n_select_bands, 67 | args.n_bands).cuda() 68 | elif args.arch == 'MoGDCNx4': 69 | model = MoGDCNx4(scale_ratio=args.scale_ratio, 70 | n_select_bands=args.n_select_bands, 71 | n_bands=args.n_bands, 72 | img_size=args.image_size).cuda() 73 | elif args.arch == 'MoGDCN': 74 | model = MoGDCN(scale_ratio=args.scale_ratio, 75 | n_select_bands=args.n_select_bands, 76 | n_bands=args.n_bands, 77 |
img_size=args.image_size).cuda() 78 | elif args.arch == 'MoGDCNx16': 79 | model = MoGDCNx16(scale_ratio=args.scale_ratio, 80 | n_select_bands=args.n_select_bands, 81 | n_bands=args.n_bands, 82 | img_size=args.image_size).cuda() 83 | elif args.arch == '_3DT_Net': 84 | model = _3DT_Net(args.scale_ratio, 8, 85 | args.n_bands,args.n_select_bands 86 | ).cuda() 87 | elif args.arch == 'MCIFNet': 88 | model = MCIFNet(img_size=64, 89 | patch_size=1, 90 | in_chans_MSI=args.n_select_bands, 91 | in_chans_HSI=args.n_bands, 92 | embed_dim=96, 93 | depths=(1,), 94 | mlp_dim=[256, 128], 95 | drop_rate=0., 96 | d_state = 16, 97 | mlp_ratio=2., 98 | drop_path_rate=0.1, 99 | norm_layer=nn.LayerNorm, 100 | patch_norm=True, 101 | use_checkpoint=False, 102 | upscale=2, 103 | img_range=1., 104 | upsampler='', 105 | resi_connection='1conv').cuda() 106 | 107 | # Loss and optimizer 108 | # criterion = nn.MSELoss().cuda() 109 | criterion = nn.L1Loss().cuda() 110 | #criterion = MAE_SAM_LOGloss().cuda() 111 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 112 | 113 | # Load the trained model parameters 114 | model_path = args.model_path.replace('dataset', args.dataset) \ 115 | .replace('arch', args.arch) 116 | if os.path.exists(model_path): 117 | model.load_state_dict(torch.load(model_path), strict=False) 118 | print ('Load the chekpoint of {}'.format(model_path)) 119 | recent_psnr = validate(test_list, 120 | args.arch, 121 | model, 122 | 0, 123 | args.n_epochs) 124 | print ('psnr: ', recent_psnr) 125 | 126 | best_psnr = 0 127 | best_psnr = validate(test_list, 128 | args.arch, 129 | model, 130 | 0, 131 | args.n_epochs) 132 | print ('psnr: ', best_psnr) 133 | 134 | # Epochs 135 | print ('Start Training: ') 136 | for epoch in range(args.n_epochs): 137 | # One epoch's training 138 | print ('Train_Epoch_{}: '.format(epoch)) 139 | h_str = h_rand[epoch] 140 | w_str = w_rand[epoch] 141 | 142 | 143 | train(train_list, 144 | args.image_size, 145 | args.scale_ratio, 146 | args.n_bands, 147 | args.arch, 148 | model, 149 | optimizer, 150 | criterion, 151 | epoch, 152 | args.n_epochs, 153 | h_str , 154 | w_str) 155 | 156 | # One epoch's validation 157 | print ('Val_Epoch_{}: '.format(epoch)) 158 | recent_psnr = validate(test_list, 159 | args.arch, 160 | model, 161 | epoch, 162 | args.n_epochs) 163 | print ('psnr: ', recent_psnr) 164 | print ('best_psnr: ', best_psnr) 165 | # # save model 166 | is_best = recent_psnr > best_psnr 167 | best_psnr = max(recent_psnr, best_psnr) 168 | if is_best: 169 | torch.save(model.state_dict(), model_path) 170 | print ('Saved!') 171 | print ('') 172 | 173 | print ('best_psnr: ', best_psnr) 174 | 175 | if __name__ == '__main__': 176 | main() 177 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def calc_ergas(img_tgt, img_fus): 6 | img_tgt = np.squeeze(img_tgt) 7 | img_fus = np.squeeze(img_fus) 8 | img_tgt = img_tgt.reshape(img_tgt.shape[0], -1) 9 | img_fus = img_fus.reshape(img_fus.shape[0], -1) 10 | 11 | rmse = np.mean((img_tgt-img_fus)**2, axis=1) 12 | rmse = rmse**0.5 13 | mean = np.mean(img_tgt, axis=1) 14 | 15 | ergas = np.mean((rmse/mean)**2) 16 | ergas = 100/4*ergas**0.5 17 | 18 | return ergas 19 | 20 | def calc_psnr(img_tgt, img_fus): 21 | mse = np.mean((img_tgt-img_fus)**2) 22 | img_max = np.max(img_tgt) 23 | psnr = 10*np.log10(img_max**2/mse) 24 | 25 | return psnr 26 | 27 | def calc_rmse(img_tgt, 
img_fus): 28 | rmse = np.sqrt(np.mean((img_tgt-img_fus)**2)) 29 | 30 | return rmse 31 | 32 | def calc_sam(img_tgt, img_fus): 33 | img_tgt = np.squeeze(img_tgt) 34 | img_fus = np.squeeze(img_fus) 35 | img_tgt = img_tgt.reshape(img_tgt.shape[0], -1) 36 | img_fus = img_fus.reshape(img_fus.shape[0], -1) 37 | img_tgt = img_tgt / np.max(img_tgt) 38 | img_fus = img_fus / np.max(img_fus) 39 | 40 | A = np.sqrt(np.sum(img_tgt**2, axis=0)) 41 | B = np.sqrt(np.sum(img_fus**2, axis=0)) 42 | AB = np.sum(img_tgt*img_fus, axis=0) 43 | 44 | sam = AB/(A*B) 45 | sam = np.arccos(sam) 46 | sam = np.mean(sam)*180/3.1415926535 47 | 48 | return sam 49 | -------------------------------------------------------------------------------- /models/MCIFNet.py: -------------------------------------------------------------------------------- 1 | # Code Implementation of the MambaIR Model 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.utils.checkpoint as checkpoint 6 | import torch.nn.functional as F 7 | from functools import partial 8 | from typing import Optional, Callable 9 | # from basicsr.utils.registry import ARCH_REGISTRY 10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 11 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref 12 | from einops import rearrange, repeat 13 | 14 | 15 | 16 | 17 | 18 | def make_coord(shape, ranges=None, flatten=True): 19 | """ Make coordinates at grid centers. 20 | """ 21 | coord_seqs = [] 22 | for i, n in enumerate(shape): 23 | if ranges is None: 24 | v0, v1 = -1, 1 25 | else: 26 | v0, v1 = ranges[i] 27 | r = (v1 - v0) / (2 * n) 28 | seq = v0 + r + (2 * r) * torch.arange(n).float() 29 | coord_seqs.append(seq) 30 | ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1) 31 | if flatten: 32 | ret = ret.view(-1, ret.shape[-1]) 33 | return ret 34 | 35 | class MLP(nn.Module): 36 | def __init__(self, in_dim, out_dim, hidden_list): 37 | super().__init__() 38 | layers = [] 39 | lastv = in_dim 40 | for hidden in hidden_list: 41 | layers.append(nn.Linear(lastv, hidden)) 42 | layers.append(nn.ReLU()) 43 | lastv = hidden 44 | layers.append(nn.Linear(lastv, out_dim)) 45 | self.layers = nn.Sequential(*layers) 46 | 47 | def forward(self, x): 48 | x = self.layers(x) 49 | return x 50 | 51 | class ChannelAttention(nn.Module): 52 | """Channel attention used in RCAN. 53 | Args: 54 | num_feat (int): Channel number of intermediate features. 55 | squeeze_factor (int): Channel squeeze factor. Default: 16. 
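Note (added summary): the block applies global average pooling, a 1x1 conv from num_feat to num_feat // squeeze_factor, ReLU, a 1x1 conv back to num_feat and a Sigmoid, then rescales the input channel-wise by the result.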
56 | """ 57 | 58 | def __init__(self, num_feat, squeeze_factor=16): 59 | super(ChannelAttention, self).__init__() 60 | self.attention = nn.Sequential( 61 | nn.AdaptiveAvgPool2d(1), 62 | nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0), 63 | nn.ReLU(inplace=True), 64 | nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), 65 | nn.Sigmoid()) 66 | 67 | def forward(self, x): 68 | y = self.attention(x) 69 | return x * y 70 | 71 | 72 | class CAB(nn.Module): 73 | def __init__(self, num_feat, is_light_sr= False, compress_ratio=3,squeeze_factor=30): 74 | super(CAB, self).__init__() 75 | if is_light_sr: # a larger compression ratio is used for light-SR 76 | compress_ratio = 6 77 | self.cab = nn.Sequential( 78 | nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1), 79 | nn.GELU(), 80 | nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1), 81 | ChannelAttention(num_feat, squeeze_factor) 82 | ) 83 | 84 | def forward(self, x): 85 | return self.cab(x) 86 | 87 | 88 | class Mlp(nn.Module): 89 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 90 | super().__init__() 91 | out_features = out_features or in_features 92 | hidden_features = hidden_features or in_features 93 | self.fc1 = nn.Linear(in_features, hidden_features) 94 | self.act = act_layer() 95 | self.fc2 = nn.Linear(hidden_features, out_features) 96 | self.drop = nn.Dropout(drop) 97 | 98 | def forward(self, x): 99 | x = self.fc1(x) 100 | x = self.act(x) 101 | x = self.drop(x) 102 | x = self.fc2(x) 103 | x = self.drop(x) 104 | return x 105 | 106 | 107 | class DynamicPosBias(nn.Module): 108 | def __init__(self, dim, num_heads): 109 | super().__init__() 110 | self.num_heads = num_heads 111 | self.pos_dim = dim // 4 112 | self.pos_proj = nn.Linear(2, self.pos_dim) 113 | self.pos1 = nn.Sequential( 114 | nn.LayerNorm(self.pos_dim), 115 | nn.ReLU(inplace=True), 116 | nn.Linear(self.pos_dim, self.pos_dim), 117 | ) 118 | self.pos2 = nn.Sequential( 119 | nn.LayerNorm(self.pos_dim), 120 | nn.ReLU(inplace=True), 121 | nn.Linear(self.pos_dim, self.pos_dim) 122 | ) 123 | self.pos3 = nn.Sequential( 124 | nn.LayerNorm(self.pos_dim), 125 | nn.ReLU(inplace=True), 126 | nn.Linear(self.pos_dim, self.num_heads) 127 | ) 128 | 129 | def forward(self, biases): 130 | pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) 131 | return pos 132 | 133 | def flops(self, N): 134 | flops = N * 2 * self.pos_dim 135 | flops += N * self.pos_dim * self.pos_dim 136 | flops += N * self.pos_dim * self.pos_dim 137 | flops += N * self.pos_dim * self.num_heads 138 | return flops 139 | 140 | 141 | class Attention(nn.Module): 142 | r""" Multi-head self attention module with dynamic position bias. 143 | 144 | Args: 145 | dim (int): Number of input channels. 146 | num_heads (int): Number of attention heads. 147 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 148 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 149 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 150 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 151 | """ 152 | 153 | def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., 154 | position_bias=True): 155 | 156 | super().__init__() 157 | self.dim = dim 158 | self.num_heads = num_heads 159 | head_dim = dim // num_heads 160 | self.scale = qk_scale or head_dim ** -0.5 161 | self.position_bias = position_bias 162 | if self.position_bias: 163 | self.pos = DynamicPosBias(self.dim // 4, self.num_heads) 164 | 165 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 166 | self.attn_drop = nn.Dropout(attn_drop) 167 | self.proj = nn.Linear(dim, dim) 168 | self.proj_drop = nn.Dropout(proj_drop) 169 | 170 | self.softmax = nn.Softmax(dim=-1) 171 | 172 | def forward(self, x, H, W, mask=None): 173 | """ 174 | Args: 175 | x: input features with shape of (num_groups*B, N, C) 176 | mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None 177 | H: height of each group 178 | W: width of each group 179 | """ 180 | group_size = (H, W) 181 | B_, N, C = x.shape 182 | assert H * W == N 183 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() 184 | q, k, v = qkv[0], qkv[1], qkv[2] 185 | 186 | q = q * self.scale 187 | attn = (q @ k.transpose(-2, -1)) # (B_, self.num_heads, N, N), N = H*W 188 | 189 | if self.position_bias: 190 | # generate mother-set 191 | position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) 192 | position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) 193 | biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 194 | biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 195 | 196 | # get pair-wise relative position index for each token inside the window 197 | coords_h = torch.arange(group_size[0], device=attn.device) 198 | coords_w = torch.arange(group_size[1], device=attn.device) 199 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw 200 | coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw 201 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw 202 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 203 | relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 204 | relative_coords[:, :, 1] += group_size[1] - 1 205 | relative_coords[:, :, 0] *= 2 * group_size[1] - 1 206 | relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw 207 | 208 | pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads 209 | # select position bias 210 | relative_position_bias = pos[relative_position_index.view(-1)].view( 211 | group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH 212 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw 213 | attn = attn + relative_position_bias.unsqueeze(0) 214 | 215 | if mask is not None: 216 | nP = mask.shape[0] 217 | attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( 218 | 0) # (B, nP, nHead, N, N) 219 | attn = attn.view(-1, self.num_heads, N, N) 220 | attn = self.softmax(attn) 221 | else: 222 | attn = self.softmax(attn) 223 | 224 | attn = self.attn_drop(attn) 225 | 226 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 227 | x = self.proj(x) 228 | x = self.proj_drop(x) 229 | return x 230 | 231 | 232 | class SS2D(nn.Module): 233 | def __init__( 234 | self, 235 | d_model, 236 | d_state=16, 237 | d_conv=3, 238 | expand=2., 239 | 
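# Added notes: d_state is the SSM hidden-state size; expand sets d_inner = int(expand * d_model);
# dt_rank='auto' resolves to ceil(d_model / 16), and dt_min/dt_max bound softplus(dt_bias) at initialisation (see dt_init below).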
dt_rank="auto", 240 | dt_min=0.001, 241 | dt_max=0.1, 242 | dt_init="random", 243 | dt_scale=1.0, 244 | dt_init_floor=1e-4, 245 | dropout=0., 246 | conv_bias=True, 247 | bias=False, 248 | device=None, 249 | dtype=None, 250 | **kwargs, 251 | ): 252 | factory_kwargs = {"device": device, "dtype": dtype} 253 | super().__init__() 254 | self.d_model = d_model 255 | self.d_state = d_state 256 | self.d_conv = d_conv 257 | self.expand = expand 258 | self.d_inner = int(self.expand * self.d_model) 259 | self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank 260 | 261 | self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) 262 | self.conv2d = nn.Conv2d( 263 | in_channels=self.d_inner, 264 | out_channels=self.d_inner, 265 | groups=self.d_inner, 266 | bias=conv_bias, 267 | kernel_size=d_conv, 268 | padding=(d_conv - 1) // 2, 269 | **factory_kwargs, 270 | ) 271 | self.act = nn.SiLU() 272 | 273 | self.x_proj = ( 274 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 275 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 276 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 277 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 278 | ) 279 | self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner) 280 | del self.x_proj 281 | 282 | self.dt_projs = ( 283 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, 284 | **factory_kwargs), 285 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, 286 | **factory_kwargs), 287 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, 288 | **factory_kwargs), 289 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, 290 | **factory_kwargs), 291 | ) 292 | self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank) 293 | self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner) 294 | del self.dt_projs 295 | 296 | self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K=4, D, N) 297 | self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K=4, D, N) 298 | 299 | self.selective_scan = selective_scan_fn 300 | 301 | self.out_norm = nn.LayerNorm(self.d_inner) 302 | self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) 303 | self.dropout = nn.Dropout(dropout) if dropout > 0. 
else None 304 | 305 | @staticmethod 306 | def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, 307 | **factory_kwargs): 308 | dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) 309 | 310 | # Initialize special dt projection to preserve variance at initialization 311 | dt_init_std = dt_rank ** -0.5 * dt_scale 312 | if dt_init == "constant": 313 | nn.init.constant_(dt_proj.weight, dt_init_std) 314 | elif dt_init == "random": 315 | nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) 316 | else: 317 | raise NotImplementedError 318 | 319 | # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max 320 | dt = torch.exp( 321 | torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) 322 | + math.log(dt_min) 323 | ).clamp(min=dt_init_floor) 324 | # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 325 | inv_dt = dt + torch.log(-torch.expm1(-dt)) 326 | with torch.no_grad(): 327 | dt_proj.bias.copy_(inv_dt) 328 | # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit 329 | dt_proj.bias._no_reinit = True 330 | 331 | return dt_proj 332 | @staticmethod 333 | def A_log_init(d_state, d_inner, copies=1, device=None, merge=True): 334 | # S4D real initialization 335 | A = repeat( 336 | torch.arange(1, d_state + 1, dtype=torch.float32, device=device), 337 | "n -> d n", 338 | d=d_inner, 339 | ).contiguous() 340 | A_log = torch.log(A) # Keep A_log in fp32 341 | if copies > 1: 342 | A_log = repeat(A_log, "d n -> r d n", r=copies) 343 | if merge: 344 | A_log = A_log.flatten(0, 1) 345 | A_log = nn.Parameter(A_log) 346 | A_log._no_weight_decay = True 347 | return A_log 348 | 349 | @staticmethod 350 | def D_init(d_inner, copies=1, device=None, merge=True): 351 | # D "skip" parameter 352 | D = torch.ones(d_inner, device=device) 353 | if copies > 1: 354 | D = repeat(D, "n1 -> r n1", r=copies) 355 | if merge: 356 | D = D.flatten(0, 1) 357 | D = nn.Parameter(D) # Keep in fp32 358 | D._no_weight_decay = True 359 | return D 360 | 361 | def forward_core(self, x: torch.Tensor): 362 | B, C, H, W = x.shape 363 | L = H * W 364 | K = 4 365 | x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) 366 | xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (1, 4, 192, 3136) 367 | 368 | x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight) 369 | dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) 370 | dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight) 371 | xs = xs.float().view(B, -1, L) 372 | dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) 373 | Bs = Bs.float().view(B, K, -1, L) 374 | Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l) 375 | Ds = self.Ds.float().view(-1) 376 | As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) 377 | dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) 378 | out_y = self.selective_scan( 379 | xs, dts, 380 | As, Bs, Cs, Ds, z=None, 381 | delta_bias=dt_projs_bias, 382 | delta_softplus=True, 383 | return_last_state=False, 384 | ).view(B, K, -1, L) 385 | assert out_y.dtype == torch.float 386 | 387 | inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) 388 | wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 389 | invwh_y = 
torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 390 | 391 | return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y 392 | 393 | def forward(self, x: torch.Tensor, **kwargs): 394 | B, H, W, C = x.shape 395 | 396 | xz = self.in_proj(x) 397 | x, z = xz.chunk(2, dim=-1) 398 | 399 | x = x.permute(0, 3, 1, 2).contiguous() 400 | x = self.act(self.conv2d(x)) 401 | y1, y2, y3, y4 = self.forward_core(x) 402 | assert y1.dtype == torch.float32 403 | y = y1 + y2 + y3 + y4 404 | y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) 405 | y = self.out_norm(y) 406 | y = y * F.silu(z) 407 | out = self.out_proj(y) 408 | if self.dropout is not None: 409 | out = self.dropout(out) 410 | return out 411 | 412 | 413 | class VSSBlock(nn.Module): 414 | def __init__( 415 | self, 416 | hidden_dim: int = 0, 417 | drop_path: float = 0, 418 | norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), 419 | attn_drop_rate: float = 0, 420 | d_state: int = 16, 421 | expand: float = 2., 422 | is_light_sr: bool = False, 423 | **kwargs, 424 | ): 425 | super().__init__() 426 | self.ln_1 = norm_layer(hidden_dim) 427 | self.self_attention = SS2D(d_model=hidden_dim, d_state=d_state,expand=expand,dropout=attn_drop_rate, **kwargs) 428 | self.drop_path = DropPath(drop_path) 429 | self.skip_scale= nn.Parameter(torch.ones(hidden_dim)) 430 | self.conv_blk = CAB(hidden_dim,is_light_sr) 431 | self.ln_2 = nn.LayerNorm(hidden_dim) 432 | self.skip_scale2 = nn.Parameter(torch.ones(hidden_dim)) 433 | 434 | 435 | 436 | def forward(self, input, x_size): 437 | # x [B,HW,C] 438 | B, L, C = input.shape 439 | input = input.view(B, *x_size, C).contiguous() # [B,H,W,C] 440 | x = self.ln_1(input) 441 | x = input*self.skip_scale + self.drop_path(self.self_attention(x)) 442 | x = x*self.skip_scale2 + self.conv_blk(self.ln_2(x).permute(0, 3, 1, 2).contiguous()).permute(0, 2, 3, 1).contiguous() 443 | x = x.view(B, -1, C).contiguous() 444 | return x 445 | 446 | 447 | class BasicLayer(nn.Module): 448 | """ The Basic MambaIR Layer in one Residual State Space Group 449 | Args: 450 | dim (int): Number of input channels. 451 | input_resolution (tuple[int]): Input resolution. 452 | depth (int): Number of blocks. 453 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 454 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 455 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 456 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 457 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
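Note (added summary): each of the `depth` blocks is a VSSBlock, i.e. LayerNorm + SS2D (four-direction 2D selective scan) under a learnable skip scale, followed by LayerNorm + a channel attention block (CAB) under a second skip scale.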
458 | """ 459 | 460 | def __init__(self, 461 | dim, 462 | input_resolution, 463 | depth, 464 | drop_path=0., 465 | d_state=16, 466 | mlp_ratio=2., 467 | norm_layer=nn.LayerNorm, 468 | downsample=None, 469 | use_checkpoint=False,is_light_sr=False): 470 | 471 | super().__init__() 472 | self.dim = dim 473 | self.input_resolution = input_resolution 474 | self.depth = depth 475 | self.mlp_ratio=mlp_ratio 476 | self.use_checkpoint = use_checkpoint 477 | 478 | # build blocks 479 | self.blocks = nn.ModuleList() 480 | for i in range(depth): 481 | self.blocks.append(VSSBlock( 482 | hidden_dim=dim, 483 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 484 | norm_layer=nn.LayerNorm, 485 | attn_drop_rate=0, 486 | d_state=d_state, 487 | expand=self.mlp_ratio, 488 | input_resolution=input_resolution,is_light_sr=is_light_sr)) 489 | 490 | # patch merging layer 491 | if downsample is not None: 492 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 493 | else: 494 | self.downsample = None 495 | 496 | def forward(self, x, x_size): 497 | for blk in self.blocks: 498 | if self.use_checkpoint: 499 | x = checkpoint.checkpoint(blk, x) 500 | else: 501 | x = blk(x, x_size) 502 | if self.downsample is not None: 503 | x = self.downsample(x) 504 | return x 505 | 506 | def extra_repr(self) -> str: 507 | return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}' 508 | 509 | def flops(self): 510 | flops = 0 511 | for blk in self.blocks: 512 | flops += blk.flops() 513 | if self.downsample is not None: 514 | flops += self.downsample.flops() 515 | return flops 516 | 517 | 518 | # @ARCH_REGISTRY.register() 519 | 520 | 521 | 522 | class ResidualGroup(nn.Module): 523 | """Residual State Space Group (RSSG). 524 | 525 | Args: 526 | dim (int): Number of input channels. 527 | input_resolution (tuple[int]): Input resolution. 528 | depth (int): Number of blocks. 529 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 530 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 531 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 532 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 533 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 534 | img_size: Input image size. 535 | patch_size: Patch size. 536 | resi_connection: The convolutional block before residual connection. 
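Note (added summary): in forward, the tokens pass through the BasicLayer, are unembedded to a 2D feature map, filtered by the trailing conv ('1conv' or '3conv'), re-embedded, and the group input is added back as a residual.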
537 | """ 538 | 539 | def __init__(self, 540 | dim, 541 | input_resolution, 542 | depth, 543 | d_state=16, 544 | mlp_ratio=4., 545 | drop_path=0., 546 | norm_layer=nn.LayerNorm, 547 | downsample=None, 548 | use_checkpoint=False, 549 | img_size=None, 550 | patch_size=None, 551 | resi_connection='1conv', 552 | is_light_sr = False): 553 | super(ResidualGroup, self).__init__() 554 | 555 | self.dim = dim 556 | self.input_resolution = input_resolution # [64, 64] 557 | 558 | self.residual_group = BasicLayer( 559 | dim=dim, 560 | input_resolution=input_resolution, 561 | depth=depth, 562 | d_state = d_state, 563 | mlp_ratio=mlp_ratio, 564 | drop_path=drop_path, 565 | norm_layer=norm_layer, 566 | downsample=downsample, 567 | use_checkpoint=use_checkpoint, 568 | is_light_sr = is_light_sr) 569 | 570 | # build the last conv layer in each residual state space group 571 | if resi_connection == '1conv': 572 | self.conv = nn.Conv2d(dim, dim, 3, 1, 1) 573 | elif resi_connection == '3conv': 574 | # to save parameters and memory 575 | self.conv = nn.Sequential( 576 | nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), 577 | nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True), 578 | nn.Conv2d(dim // 4, dim, 3, 1, 1)) 579 | 580 | self.patch_embed = PatchEmbed( 581 | img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None) 582 | 583 | self.patch_unembed = PatchUnEmbed( 584 | img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None) 585 | 586 | def forward(self, x, x_size): 587 | return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x 588 | 589 | def flops(self): 590 | flops = 0 591 | flops += self.residual_group.flops() 592 | h, w = self.input_resolution 593 | flops += h * w * self.dim * self.dim * 9 594 | flops += self.patch_embed.flops() 595 | flops += self.patch_unembed.flops() 596 | 597 | return flops 598 | 599 | 600 | class PatchEmbed(nn.Module): 601 | r""" transfer 2D feature map into 1D token sequence 602 | 603 | Args: 604 | img_size (int): Image size. Default: None. 605 | patch_size (int): Patch token size. Default: None. 606 | in_chans (int): Number of input image channels. Default: 3. 607 | embed_dim (int): Number of linear projection output channels. Default: 96. 608 | norm_layer (nn.Module, optional): Normalization layer. 
Default: None 609 | """ 610 | 611 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 612 | super().__init__() 613 | img_size = to_2tuple(img_size) 614 | patch_size = to_2tuple(patch_size) 615 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 616 | self.img_size = img_size 617 | self.patch_size = patch_size 618 | self.patches_resolution = patches_resolution 619 | self.num_patches = patches_resolution[0] * patches_resolution[1] 620 | 621 | self.in_chans = in_chans 622 | self.embed_dim = embed_dim 623 | 624 | if norm_layer is not None: 625 | self.norm = norm_layer(embed_dim) 626 | else: 627 | self.norm = None 628 | 629 | def forward(self, x): 630 | x = x.flatten(2).transpose(1, 2) # b Ph*Pw c 631 | if self.norm is not None: 632 | x = self.norm(x) 633 | return x 634 | 635 | def flops(self): 636 | flops = 0 637 | h, w = self.img_size 638 | if self.norm is not None: 639 | flops += h * w * self.embed_dim 640 | return flops 641 | 642 | 643 | class PatchUnEmbed(nn.Module): 644 | r""" return 2D feature map from 1D token sequence 645 | 646 | Args: 647 | img_size (int): Image size. Default: None. 648 | patch_size (int): Patch token size. Default: None. 649 | in_chans (int): Number of input image channels. Default: 3. 650 | embed_dim (int): Number of linear projection output channels. Default: 96. 651 | norm_layer (nn.Module, optional): Normalization layer. Default: None 652 | """ 653 | 654 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 655 | super().__init__() 656 | img_size = to_2tuple(img_size) 657 | patch_size = to_2tuple(patch_size) 658 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 659 | self.img_size = img_size 660 | self.patch_size = patch_size 661 | self.patches_resolution = patches_resolution 662 | self.num_patches = patches_resolution[0] * patches_resolution[1] 663 | 664 | self.in_chans = in_chans 665 | self.embed_dim = embed_dim 666 | 667 | def forward(self, x, x_size): 668 | x = x.transpose(1, 2).view(x.shape[0], self.embed_dim, x_size[0], x_size[1]) # b Ph*Pw c 669 | return x 670 | 671 | def flops(self): 672 | flops = 0 673 | return flops 674 | 675 | 676 | 677 | class UpsampleOneStep(nn.Sequential): 678 | """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) 679 | Used in lightweight SR to save parameters. 680 | 681 | Args: 682 | scale (int): Scale factor. Supported scales: 2^n and 3. 683 | num_feat (int): Channel number of intermediate features. 684 | 685 | """ 686 | 687 | def __init__(self, scale, num_feat, num_out_ch): 688 | self.num_feat = num_feat 689 | m = [] 690 | m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1)) 691 | m.append(nn.PixelShuffle(scale)) 692 | super(UpsampleOneStep, self).__init__(*m) 693 | 694 | 695 | 696 | class Upsample(nn.Sequential): 697 | """Upsample module. 698 | 699 | Args: 700 | scale (int): Scale factor. Supported scales: 2^n and 3. 701 | num_feat (int): Channel number of intermediate features. 702 | """ 703 | 704 | def __init__(self, scale, num_feat): 705 | m = [] 706 | if (scale & (scale - 1)) == 0: # scale = 2^n 707 | for _ in range(int(math.log(scale, 2))): 708 | m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) 709 | m.append(nn.PixelShuffle(2)) 710 | elif scale == 3: 711 | m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) 712 | m.append(nn.PixelShuffle(3)) 713 | else: 714 | raise ValueError(f'scale {scale} is not supported. 
Supported scales: 2^n and 3.') 715 | super(Upsample, self).__init__(*m) 716 | 717 | 718 | class MCIFNet(nn.Module): 719 | r""" MambaIR Model 720 | A PyTorch impl of : `A Simple Baseline for Image Restoration with State Space Model `. 721 | 722 | Args: 723 | img_size (int | tuple(int)): Input image size. Default 64 724 | patch_size (int | tuple(int)): Patch size. Default: 1 725 | in_chans (int): Number of input image channels. Default: 3 726 | embed_dim (int): Patch embedding dimension. Default: 96 727 | d_state (int): num of hidden state in the state space model. Default: 16 728 | depths (tuple(int)): Depth of each RSSG 729 | drop_rate (float): Dropout rate. Default: 0 730 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 731 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 732 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 733 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False 734 | upscale: Upscale factor. 2/3/4 for image SR, 1 for denoising 735 | img_range: Image range. 1. or 255. 736 | upsampler: The reconstruction reconstruction module. 'pixelshuffle'/None 737 | resi_connection: The convolutional block before residual connection. '1conv'/'3conv' 738 | """ 739 | def __init__(self, 740 | img_size=64, 741 | patch_size=1, 742 | in_chans_MSI=4, 743 | in_chans_HSI=103, 744 | embed_dim=96, 745 | depths=(1,), 746 | mlp_dim=[256, 128], 747 | drop_rate=0., 748 | d_state = 16, 749 | mlp_ratio=2., 750 | drop_path_rate=0.1, 751 | norm_layer=nn.LayerNorm, 752 | patch_norm=True, 753 | use_checkpoint=False, 754 | upscale=2, 755 | img_range=1., 756 | upsampler='', 757 | resi_connection='1conv', 758 | **kwargs): 759 | super(MCIFNet, self).__init__() 760 | num_in_ch_MSI = in_chans_MSI 761 | num_in_ch_HSI = in_chans_HSI 762 | num_out_ch = in_chans_HSI 763 | num_feat = 64 764 | self.img_range = img_range 765 | self.mlp_dim = mlp_dim 766 | 767 | 768 | self.softmax = nn.Softmax() 769 | self.upscale = upscale 770 | self.upsampler = upsampler 771 | self.mlp_ratio=mlp_ratio 772 | # ------------------------- 1, shallow feature extraction ------------------------- # 773 | self.conv_first_MSI = nn.Conv2d(num_in_ch_MSI, embed_dim, 3, 1, 1) 774 | self.conv_first_HSI = nn.Conv2d(num_in_ch_HSI, embed_dim, 3, 1, 1) 775 | # ------------------------- 2, deep feature extraction ------------------------- # 776 | self.num_layers = len(depths) 777 | self.embed_dim = embed_dim 778 | self.patch_norm = patch_norm 779 | self.num_features = embed_dim 780 | 781 | 782 | # transfer 2D feature map into 1D token sequence, pay attention to whether using normalization 783 | self.patch_embed = PatchEmbed( 784 | img_size=img_size, 785 | patch_size=patch_size, 786 | in_chans=embed_dim, 787 | embed_dim=embed_dim, 788 | norm_layer=norm_layer if self.patch_norm else None) 789 | num_patches = self.patch_embed.num_patches 790 | patches_resolution = self.patch_embed.patches_resolution 791 | self.patches_resolution = patches_resolution 792 | 793 | # return 2D feature map from 1D token sequence 794 | self.patch_unembed = PatchUnEmbed( 795 | img_size=img_size, 796 | patch_size=patch_size, 797 | in_chans=embed_dim, 798 | embed_dim=embed_dim, 799 | norm_layer=norm_layer if self.patch_norm else None) 800 | 801 | self.pos_drop = nn.Dropout(p=drop_rate) 802 | self.is_light_sr = True if self.upsampler=='pixelshuffledirect' else False 803 | # stochastic depth 804 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic 
depth decay rule 805 | 806 | # build Residual State Space Group (RSSG) 807 | self.layers_MSI = nn.ModuleList() 808 | for i_layer in range(self.num_layers): # 6-layer 809 | layer = ResidualGroup( 810 | dim=embed_dim, 811 | input_resolution=(patches_resolution[0], patches_resolution[1]), 812 | depth=depths[i_layer], 813 | d_state = d_state, 814 | mlp_ratio=self.mlp_ratio, 815 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results 816 | norm_layer=norm_layer, 817 | downsample=None, 818 | use_checkpoint=use_checkpoint, 819 | img_size=img_size, 820 | patch_size=patch_size, 821 | resi_connection=resi_connection, 822 | is_light_sr = self.is_light_sr 823 | ) 824 | self.layers_MSI.append(layer) 825 | self.layers_HSI = nn.ModuleList() 826 | for i_layer in range(self.num_layers): # 6-layer 827 | layer = ResidualGroup( 828 | dim=embed_dim, 829 | input_resolution=(patches_resolution[0], patches_resolution[1]), 830 | depth=depths[i_layer], 831 | d_state = d_state, 832 | mlp_ratio=self.mlp_ratio, 833 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results 834 | norm_layer=norm_layer, 835 | downsample=None, 836 | use_checkpoint=use_checkpoint, 837 | img_size=img_size, 838 | patch_size=patch_size, 839 | resi_connection=resi_connection, 840 | is_light_sr = self.is_light_sr 841 | ) 842 | self.layers_HSI.append(layer) 843 | self.norm = norm_layer(self.num_features) 844 | 845 | # build the last conv layer in the end of all residual groups 846 | if resi_connection == '1conv': 847 | self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) 848 | elif resi_connection == '3conv': 849 | # to save parameters and memory 850 | self.conv_after_body = nn.Sequential( 851 | nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), 852 | nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True), 853 | nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) 854 | 855 | # -------------------------3. 
high-quality image reconstruction ------------------------ # 856 | if self.upsampler == 'pixelshuffle': 857 | # for classical SR 858 | self.conv_before_upsample = nn.Sequential( 859 | nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)) 860 | self.upsample = Upsample(upscale, num_feat) 861 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 862 | elif self.upsampler == 'pixelshuffledirect': 863 | # for lightweight SR (to save parameters) 864 | self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch) 865 | 866 | else: 867 | # for image denoising 868 | self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) 869 | 870 | 871 | imnet_in_dim = self.embed_dim + self.embed_dim + 2 872 | self.imnet = MLP(imnet_in_dim, out_dim=in_chans_HSI+1, hidden_list=self.mlp_dim) 873 | 874 | 875 | def forward_features_MSI(self, x): 876 | x_size = (x.shape[2], x.shape[3]) 877 | x = self.patch_embed(x) # N,L,C 878 | 879 | x = self.pos_drop(x) 880 | 881 | for layer in self.layers_MSI: 882 | x = layer(x, x_size) 883 | 884 | # x = self.norm(x) # b seq_len c 885 | x = self.patch_unembed(x, x_size) 886 | 887 | return x 888 | 889 | def forward_features_HSI(self, x): 890 | x_size = (x.shape[2], x.shape[3]) 891 | x = self.patch_embed(x) # N,L,C 892 | 893 | x = self.pos_drop(x) 894 | 895 | for layer in self.layers_HSI: 896 | x = layer(x, x_size) 897 | 898 | # x = self.norm(x) # b seq_len c 899 | x = self.patch_unembed(x, x_size) 900 | 901 | return x 902 | def query(self, feat, coord, hr_guide): 903 | 904 | # feat: [B, C, h, w] 905 | # coord: [B, N, 2], N <= H * W 906 | 907 | b, c, h, w = feat.shape # lr 7x128x8x8 908 | _, _, H, W = hr_guide.shape # hr 7x128x64x64 909 | coord = coord.expand(b, H * W, 2) 910 | B, N, _ = coord.shape 911 | 912 | # LR centers' coords 913 | feat_coord = make_coord((h, w), flatten=False).to(feat.device).permute(2, 0, 1).unsqueeze(0).expand(b, 2, h, w) 914 | 915 | q_guide_hr = F.grid_sample(hr_guide, coord.flip(-1).unsqueeze(1), mode='nearest', align_corners=False)[:, :, 0, 916 | :].permute(0, 2, 1) # [B, N, C] 917 | 918 | rx = 1 / h 919 | ry = 1 / w 920 | 921 | preds = [] 922 | 923 | for vx in [-1, 1]: 924 | for vy in [-1, 1]: 925 | coord_ = coord.clone() 926 | 927 | coord_[:, :, 0] += (vx) * rx 928 | coord_[:, :, 1] += (vy) * ry 929 | 930 | # feat: [B, c, h, w], coord_: [B, N, 2] --> [B, 1, N, 2], out: [B, c, 1, N] --> [B, c, N] --> [B, N, c] 931 | q_feat = F.grid_sample(feat, coord_.flip(-1).unsqueeze(1), mode='nearest', align_corners=False)[:, :, 0, 932 | :].permute(0, 2, 1) # [B, N, c] 933 | q_coord = F.grid_sample(feat_coord, coord_.flip(-1).unsqueeze(1), mode='nearest', align_corners=False)[ 934 | :, :, 0, :].permute(0, 2, 1) # [B, N, 2] 935 | 936 | rel_coord = coord - q_coord 937 | rel_coord[:, :, 0] *= h 938 | rel_coord[:, :, 1] *= w 939 | 940 | inp = torch.cat([q_feat, q_guide_hr, rel_coord], dim=-1) 941 | 942 | pred = self.imnet(inp.view(B * N, -1)).view(B, N, -1) # [B, N, 2] 943 | preds.append(pred) 944 | 945 | preds = torch.stack(preds, dim=-1) # [B, N, 2, kk] 946 | weight = F.softmax(preds[:, :, -1, :], dim=-1) 947 | ret = (preds[:, :, 0:-1, :] * weight.unsqueeze(-2)).sum(-1, keepdim=True).squeeze(-1) 948 | ret = ret.permute(0, 2, 1).view(b, -1, H, W) 949 | 950 | return ret 951 | def forward(self, HSI, MSI): 952 | 953 | HSI_first = self.conv_first_HSI(HSI) 954 | 955 | 956 | HSI_encoder = self.forward_features_HSI(HSI_first) + HSI_first 957 | 958 | 959 | MSI_first = self.conv_first_MSI(MSI) 960 | MSI_encoder = self.forward_features_MSI(MSI_first) + 
MSI_first 961 | 962 | 963 | 964 | 965 | B, c, H, W = MSI.shape 966 | coord = make_coord([H, W]).cuda() 967 | # inrF = self.query(spe, coord, spa).reshape(B, -1, H*W) # BxCxHxW 968 | INR_F = self.query(HSI_encoder, coord, MSI_encoder) # BxCxHxW 969 | 970 | 971 | return INR_F,0,0,0,0,0 972 | 973 | 974 | 975 | if __name__ == '__main__': 976 | scale_ratio = 4 977 | n_select_bands = 4 978 | n_bands = 103 979 | MSI = torch.randn(1, n_select_bands, 128, 128).cuda() 980 | HSI = torch.randn(1, n_bands, 8, 8).cuda() 981 | # Create an instance of the Vim model 982 | model = MCIFNet(img_size=64, 983 | patch_size=1, 984 | in_chans_MSI=4, 985 | in_chans_HSI=103, 986 | embed_dim=96, 987 | depths=(1,), 988 | mlp_dim=[256, 128], 989 | drop_rate=0., 990 | d_state = 16, 991 | mlp_ratio=2., 992 | drop_path_rate=0.1, 993 | norm_layer=nn.LayerNorm, 994 | patch_norm=True, 995 | use_checkpoint=False, 996 | upscale=2, 997 | img_range=1., 998 | upsampler='', 999 | resi_connection='1conv').cuda() 1000 | 1001 | # Perform a forward pass through the model 1002 | out = model(HSI,MSI) 1003 | 1004 | # Print the shape and output of the forward pass 1005 | print(out[0].shape) 1006 | # print(flop_count_table(FlopCountAnalysis(model, (rgb, ms)))) 1007 | # print(out) 1008 | 1009 | 1010 | 1011 | 1012 | 1013 | 1014 | 1015 | 1016 | 1017 | 1018 | 1019 | 1020 | 1021 | 1022 | 1023 | 1024 | 1025 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.backends.cudnn as cudnn 4 | import torch.optim 5 | from torch import nn 6 | from models.TSFN import TSFN 7 | # from models.MambaINR import MambaINR 8 | from models.PSRTNet import PSRTnet 9 | from models._3DT_Net import _3DT_Net 10 | from models.SSRNET import SSRNET 11 | from models.HyperKite import HyperKite 12 | from models.MoGDCNx4 import MoGDCNx4 13 | from models.MoGDCN import MoGDCN 14 | from models.MoGDCNx16 import MoGDCNx16 15 | from models.DCT import DCT 16 | from utils import * 17 | from metrics import calc_psnr, calc_rmse, calc_ergas, calc_sam 18 | from data_loader import build_datasets 19 | from models.ASSMamba import ASSMamba 20 | from models.ASSMamba_no_RSSG import ASSMamba_no_RSSG 21 | from models.ASSMamba_no_CAB import ASSMamba_no_CAB 22 | from models.ASSMamba_no_GINS import ASSMamba_no_GINS 23 | from models.ASSMamba_no_VSSM import ASSMamba_no_VSSM 24 | from models.ASSMamba_no_SEFU import ASSMamba_no_SEFU 25 | from models.P3Net import P3Net 26 | from validate import validate 27 | from train import train 28 | import pdb 29 | import args_parser 30 | from torch.nn import functional as F 31 | import cv2 32 | from time import * 33 | import os 34 | import scipy.io as io 35 | from thop import profile 36 | torch.cuda.is_available() 37 | 38 | args = args_parser.args_parser() 39 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 40 | 41 | print (args) 42 | 43 | # torch.cuda.is_available() 44 | def main(): 45 | if args.dataset == 'PaviaU': 46 | args.n_bands = 103 47 | elif args.dataset == 'Pavia': 48 | args.n_bands = 102 49 | elif args.dataset == 'Chikusei': 50 | args.n_bands = 128 51 | elif args.dataset == 'IEEE2018': 52 | args.n_bands = 48 53 | elif args.dataset == 'Botswana': 54 | args.n_bands = 145 55 | 56 | # Custom dataloader 57 | train_list, test_list = build_datasets(args.root, 58 | args.dataset, 59 | args.image_size, 60 | args.n_select_bands, 61 | args.scale_ratio) 62 | 63 | # Build the models 64 | if args.dataset == 
'PaviaU': 65 | args.n_bands = 103 66 | 67 | elif args.dataset == 'Pavia': 68 | args.n_bands = 102 69 | elif args.dataset == 'Chikusei': 70 | args.n_bands = 128 71 | elif args.dataset == 'IEEE2018': 72 | args.n_bands = 48 73 | elif args.dataset == 'Botswana': 74 | args.n_bands = 145 75 | # Build the models 76 | if args.arch == 'SSRNET' or args.arch == 'SpatRNET' or args.arch == 'SpecRNET': 77 | model = SSRNET(args.arch, 78 | args.scale_ratio, 79 | args.n_select_bands, 80 | args.n_bands).cuda() 81 | 82 | elif args.arch == 'PSRTnet': 83 | model = PSRTnet(args.scale_ratio, 84 | args.n_select_bands, 85 | args.n_bands, 86 | args.image_size).cuda() 87 | elif args.arch == 'DCT': 88 | model = DCT(n_colors=args.n_bands, upscale_factor=args.scale_ratio, n_feats=180) 89 | elif args.arch == 'HyperKite': 90 | model = HyperKite(args.scale_ratio, 91 | args.n_select_bands, 92 | args.n_bands).cuda() 93 | elif args.arch == 'TSFN': 94 | model = TSFN(args.scale_ratio, 95 | args.n_select_bands, 96 | args.n_bands).cuda() 97 | elif args.arch == 'MoGDCNx4': 98 | model = MoGDCNx4(scale_ratio=args.scale_ratio, 99 | n_select_bands=args.n_select_bands, 100 | n_bands=args.n_bands, 101 | img_size=args.image_size).cuda() 102 | elif args.arch == 'MoGDCN': 103 | model = MoGDCN(scale_ratio=args.scale_ratio, 104 | n_select_bands=args.n_select_bands, 105 | n_bands=args.n_bands, 106 | img_size=args.image_size).cuda() 107 | elif args.arch == 'MoGDCNx16': 108 | model = MoGDCNx16(scale_ratio=args.scale_ratio, 109 | n_select_bands=args.n_select_bands, 110 | n_bands=args.n_bands, 111 | img_size=args.image_size).cuda() 112 | elif args.arch == '_3DT_Net': 113 | model = _3DT_Net(args.scale_ratio, 8, 114 | args.n_bands,args.n_select_bands 115 | ).cuda() 116 | elif args.arch == 'ASSMamba_no_GINS': 117 | model = ASSMamba_no_GINS(img_size=64, 118 | patch_size=1, 119 | in_chans_MSI=args.n_select_bands, 120 | in_chans_HSI=args.n_bands, 121 | embed_dim=96, 122 | depths=(1,), 123 | mlp_dim=[256, 128], 124 | drop_rate=0., 125 | d_state = 16, 126 | mlp_ratio=2., 127 | drop_path_rate=0.1, 128 | norm_layer=nn.LayerNorm, 129 | patch_norm=True, 130 | use_checkpoint=False, 131 | upscale=2, 132 | img_range=1., 133 | upsampler='', 134 | resi_connection='1conv').cuda() 135 | elif args.arch == 'ASSMamba': 136 | model = ASSMamba(img_size=64, 137 | patch_size=1, 138 | in_chans_MSI=args.n_select_bands, 139 | in_chans_HSI=args.n_bands, 140 | embed_dim=96, 141 | depths=(1,), 142 | mlp_dim=[256, 128], 143 | drop_rate=0., 144 | d_state = 16, 145 | mlp_ratio=2., 146 | drop_path_rate=0.1, 147 | norm_layer=nn.LayerNorm, 148 | patch_norm=True, 149 | use_checkpoint=False, 150 | upscale=2, 151 | img_range=1., 152 | upsampler='', 153 | resi_connection='1conv').cuda() 154 | elif args.arch == 'ASSMamba_no_VSSM': 155 | model = ASSMamba_no_VSSM(img_size=64, 156 | patch_size=1, 157 | in_chans_MSI=args.n_select_bands, 158 | in_chans_HSI=args.n_bands, 159 | embed_dim=96, 160 | depths=(1,), 161 | mlp_dim=[256, 128], 162 | drop_rate=0., 163 | d_state = 16, 164 | mlp_ratio=2., 165 | drop_path_rate=0.1, 166 | norm_layer=nn.LayerNorm, 167 | patch_norm=True, 168 | use_checkpoint=False, 169 | upscale=2, 170 | img_range=1., 171 | upsampler='', 172 | resi_connection='1conv').cuda() 173 | elif args.arch == 'ASSMamba_no_SEFU': 174 | model = ASSMamba_no_SEFU(img_size=64, 175 | patch_size=1, 176 | in_chans_MSI=args.n_select_bands, 177 | in_chans_HSI=args.n_bands, 178 | embed_dim=96, 179 | depths=(1,), 180 | mlp_dim=[256, 128], 181 | drop_rate=0., 182 | d_state = 16, 183 | mlp_ratio=2., 
184 | drop_path_rate=0.1, 185 | norm_layer=nn.LayerNorm, 186 | patch_norm=True, 187 | use_checkpoint=False, 188 | upscale=2, 189 | img_range=1., 190 | upsampler='', 191 | resi_connection='1conv').cuda() 192 | elif args.arch == 'ASSMamba_no_RSSG': 193 | model = ASSMamba_no_RSSG(img_size=64, 194 | patch_size=1, 195 | in_chans_MSI=args.n_select_bands, 196 | in_chans_HSI=args.n_bands, 197 | embed_dim=96, 198 | depths=(1,), 199 | mlp_dim=[256, 128], 200 | drop_rate=0., 201 | d_state = 16, 202 | mlp_ratio=2., 203 | drop_path_rate=0.1, 204 | norm_layer=nn.LayerNorm, 205 | patch_norm=True, 206 | use_checkpoint=False, 207 | upscale=2, 208 | img_range=1., 209 | upsampler='', 210 | resi_connection='1conv').cuda() 211 | elif args.arch == 'ASSMamba_no_CAB': 212 | model = ASSMamba_no_CAB(img_size=64, 213 | patch_size=1, 214 | in_chans_MSI=args.n_select_bands, 215 | in_chans_HSI=args.n_bands, 216 | embed_dim=96, 217 | depths=(1,), 218 | mlp_dim=[256, 128], 219 | drop_rate=0., 220 | d_state = 16, 221 | mlp_ratio=2., 222 | drop_path_rate=0.1, 223 | norm_layer=nn.LayerNorm, 224 | patch_norm=True, 225 | use_checkpoint=False, 226 | upscale=2, 227 | img_range=1., 228 | upsampler='', 229 | resi_connection='1conv').cuda() 230 | elif args.arch == 'P3Net': 231 | model = P3Net(scale_ratio=args.scale_ratio, 232 | img_size=64, 233 | patch_size=1, 234 | in_chans_MSI=args.n_select_bands, 235 | in_chans_HSI=args.n_bands, 236 | embed_dim=64, 237 | depths=(1,), 238 | mlp_dim=[256, 128], 239 | drop_rate=0., 240 | d_state = 16, 241 | mlp_ratio=2., 242 | drop_path_rate=0.1, 243 | norm_layer=nn.LayerNorm, 244 | patch_norm=True, 245 | use_checkpoint=False, 246 | upscale=2, 247 | img_range=1., 248 | upsampler='', 249 | resi_connection='1conv').cuda() 250 | # Load the trained model parameters 251 | model_path = args.model_path.replace('dataset', args.dataset) \ 252 | .replace('arch', args.arch) 253 | if os.path.exists(model_path): 254 | model.load_state_dict(torch.load(model_path), strict=False) 255 | print ('Load the chekpoint of {}'.format(model_path)) 256 | 257 | 258 | test_ref, test_lr, test_hr = test_list 259 | model.eval() 260 | 261 | # Set mini-batch dataset 262 | ref = test_ref.float().detach() 263 | lr = test_lr.float().detach() 264 | hr = test_hr.float().detach() 265 | 266 | begin_time = time() 267 | if args.arch == 'SSRNET': 268 | out, _, _, _, _, _ = model(lr.cuda(), hr.cuda()) 269 | elif args.arch == 'SpatRNET': 270 | _, out, _, _, _, _ = model(lr.cuda(), hr.cuda()) 271 | elif args.arch == 'SpecRNET': 272 | _, _, out, _, _, _ = model(lr.cuda(), hr.cuda()) 273 | elif args.arch == 'SwinCGAN': 274 | out, _, _, _, _, _ = model(lr.cuda(), hr.cuda(), args.scale_ratio) 275 | else: 276 | out, _, _, _, _, _ = model(lr.cuda(), hr.cuda()) 277 | end_time = time() 278 | run_time = (end_time-begin_time)*1000 279 | 280 | print () 281 | print () 282 | print ('Dataset: {}'.format(args.dataset)) 283 | print ('Arch: {}'.format(args.arch)) 284 | print ('ModelSize(M): {}'.format(np.around(os.path.getsize(model_path)//1024/1024.0, decimals=2))) 285 | print ('Time(Ms): {}'.format(np.around(run_time, decimals=2))) 286 | flops, params = profile(model, inputs=(lr.cuda(),hr.cuda())) 287 | flops = flops/1000000000 288 | print ('flops:',flops) 289 | print ('params:',params/1000000) 290 | 291 | ref = ref.detach().cpu().numpy() 292 | out = out.detach().cpu().numpy() 293 | 294 | slr = F.interpolate(lr, scale_factor=args.scale_ratio, mode='bilinear') 295 | slr = slr.detach().cpu().numpy() 296 | slr = 
np.squeeze(slr).transpose(1,2,0).astype(np.float64) 297 | 298 | sref = np.squeeze(ref).transpose(1,2,0).astype(np.float64) 299 | sout = np.squeeze(out).transpose(1,2,0).astype(np.float64) 300 | 301 | io.savemat('./实验结果的mat格式/'+args.dataset+'/'+ str(args.scale_ratio)+'倍'+'/'+args.arch+'.mat',{'Out':sout}) 302 | io.savemat('./实验结果的mat格式/'+args.dataset+'/'+ str(args.scale_ratio)+'倍'+'/'+'REF.mat',{'REF':sref}) 303 | io.savemat('./实验结果的mat格式/'+args.dataset+'/'+ str(args.scale_ratio)+'倍'+'/'+'Upsample.mat',{'Out':slr}) 304 | 305 | t_lr = np.squeeze(lr).detach().cpu().numpy().transpose(1,2,0).astype(np.float64) 306 | t_hr = np.squeeze(hr).detach().cpu().numpy().transpose(1,2,0).astype(np.float64) 307 | 308 | io.savemat('./为传统方法准备数据/'+args.dataset+'/'+ str(args.scale_ratio)+'倍'+'/'+'lr'+'.mat',{'HSI':t_lr}) 309 | io.savemat('./为传统方法准备数据/'+args.dataset+'/'+ str(args.scale_ratio)+'倍'+'/'+'hr'+'.mat',{'MSI':t_hr}) 310 | 311 | 312 | psnr = calc_psnr(ref, out) 313 | rmse = calc_rmse(ref, out) 314 | ergas = calc_ergas(ref, out) 315 | sam = calc_sam(ref, out) 316 | print ('RMSE: {:.4f};'.format(rmse)) 317 | print ('PSNR: {:.4f};'.format(psnr)) 318 | print ('ERGAS: {:.4f};'.format(ergas)) 319 | print ('SAM: {:.4f}.'.format(sam)) 320 | 321 | 322 | if __name__ == '__main__': 323 | main() 324 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from utils import to_var, batch_ids2words 4 | import random 5 | import torch.nn.functional as F 6 | import cv2 7 | 8 | 9 | def spatial_edge(x): 10 | edge1 = x[:, :, 0:x.size(2)-1, :] - x[:, :, 1:x.size(2), :] 11 | edge2 = x[:, :, :, 0:x.size(3)-1] - x[:, :, :, 1:x.size(3)] 12 | 13 | return edge1, edge2 14 | 15 | def spectral_edge(x): 16 | edge = x[:, 0:x.size(1)-1, :, :] - x[:, 1:x.size(1), :, :] 17 | 18 | return edge 19 | 20 | 21 | def train(train_list, 22 | image_size, 23 | scale_ratio, 24 | n_bands, 25 | arch, 26 | model, 27 | optimizer, 28 | criterion, 29 | epoch, 30 | n_epochs, 31 | h_str , 32 | w_str): 33 | # train_ref, train_lr, train_hr = train_list 34 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 35 | train_ref, train_lr, train_hr = train_list 36 | # h, w = train_ref.size(2), train_ref.size(3) 37 | # HH = [] 38 | # WW = [] 39 | # for i in range (10001): 40 | # h_str = random.randint(0, h-image_size-1) 41 | # HH.append(h_str) 42 | # w_str = random.randint(0, w-image_size-1) 43 | # WW.append(w_str) 44 | # h_str = random.randint(0, h-image_size-1) 45 | # w_str = random.randint(0, w-image_size-1) 46 | 47 | train_lr = train_ref[:, :, h_str:h_str+image_size, w_str:w_str+image_size] 48 | train_ref = train_ref[:, :, h_str:h_str+image_size, w_str:w_str+image_size] 49 | train_lr = F.interpolate(train_ref, scale_factor=1/(scale_ratio*1.0)) 50 | train_hr = train_hr[:, :, h_str:h_str+image_size, w_str:w_str+image_size] 51 | 52 | model.train() 53 | 54 | # Set mini-batch dataset 55 | image_lr = to_var(train_lr).detach() 56 | image_hr = to_var(train_hr).detach() 57 | image_ref = to_var(train_ref).detach() 58 | 59 | # Forward, Backward and Optimize 60 | optimizer.zero_grad() 61 | 62 | out, out_spat, out_spec, edge_spat1, edge_spat2, edge_spec = model(image_lr, image_hr) 63 | ref_edge_spat1, ref_edge_spat2 = spatial_edge(image_ref) 64 | ref_edge_spec = spectral_edge(image_ref) 65 | 66 | if 'RNET' in arch: 67 | loss_fus = criterion(out, image_ref) 68 | loss_spat = criterion(out_spat, image_ref) 69 | 
loss_spec = criterion(out_spec, image_ref) 70 | loss_spec_edge = criterion(edge_spec, ref_edge_spec) 71 | loss_spat_edge = 0.5*criterion(edge_spat1, ref_edge_spat1) + 0.5*criterion(edge_spat2, ref_edge_spat2) 72 | if arch == 'SpatRNET': 73 | loss = loss_spat + loss_spat_edge 74 | elif arch == 'SpecRNET': 75 | loss = loss_spec + loss_spec_edge 76 | elif arch == 'SSRNET': 77 | loss = loss_fus + loss_spat_edge + loss_spec_edge 78 | else: 79 | loss = criterion(out, image_ref) 80 | 81 | loss.backward() 82 | torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=0.5, norm_type=2)  # clip the gradient L2 norm to a fixed threshold 83 | optimizer.step() 84 | 85 | # Print log info 86 | print('Epoch [%d/%d], Loss: %.4f' 87 | %(epoch, 88 | n_epochs, 89 | loss, 90 | ) 91 | ) 92 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | import torch 5 | # import h5py 6 | from cv2 import imread, resize 7 | from tqdm import tqdm 8 | from collections import Counter 9 | from random import seed, choice, sample 10 | import torch 11 | import torch.nn as nn 12 | from torch.autograd import Variable 13 | from torch.nn.utils.rnn import pack_padded_sequence 14 | 15 | 16 | 17 | def to_var(x, volatile=False): 18 | if torch.cuda.is_available(): 19 | x = x.cuda().float() 20 | return Variable(x, volatile=volatile) 21 | 22 | class AverageMeter(object): 23 | """ 24 | Keeps track of most recent, average, sum, and count of a metric. 25 | """ 26 | 27 | def __init__(self): 28 | self.reset() 29 | 30 | def reset(self): 31 | self.val = 0 32 | self.avg = 0 33 | self.sum = 0 34 | self.count = 0 35 | 36 | def update(self, val, n=1): 37 | self.val = val 38 | self.sum += val * n 39 | self.count += n 40 | self.avg = self.sum / self.count 41 | 42 | 43 | def adjust_learning_rate(optimizer, shrink_factor): 44 | """ 45 | Shrinks learning rate by a specified factor. 46 | 47 | :param optimizer: optimizer whose learning rate must be shrunk. 48 | :param shrink_factor: factor in interval (0, 1) to multiply learning rate with. 49 | """ 50 | 51 | print("\nDECAYING learning rate.") 52 | for param_group in optimizer.param_groups: 53 | param_group['lr'] = param_group['lr'] * shrink_factor 54 | print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],)) 55 | 56 | 57 | def accuracy(scores, targets, k): 58 | """ 59 | Computes top-k accuracy from predicted and true labels. 
60 | 61 | :param scores: scores from the model 62 | :param targets: true labels 63 | :param k: k in top-k accuracy 64 | :return: top-k accuracy 65 | """ 66 | 67 | batch_size = targets.size(0) 68 | _, ind = scores.topk(k, 1, True, True) 69 | correct = ind.eq(targets.view(-1, 1).expand_as(ind)) 70 | correct_total = correct.view(-1).float().sum() # 0D tensor 71 | return correct_total.item() * (100.0 / batch_size) 72 | 73 | 74 | def batch_ids2words(batch_ids, vocab): 75 | 76 | batch_words = [] 77 | for i in range(batch_ids.size(0)): 78 | sampled_caption = [] 79 | ids = batch_ids[i,::].cpu().data.numpy() 80 | 81 | for j in range(len(ids)): 82 | id = ids[j] 83 | word = vocab.idx2word[id] 84 | # if word == '.': 85 | # print ('.: ', id) 86 | if word == '': 87 | break 88 | if '' not in word: 89 | sampled_caption.append(word) 90 | 91 | for k in sampled_caption: 92 | if k==sampled_caption[0]: 93 | sentence = k 94 | else: 95 | sentence = sentence + ' ' + k 96 | 97 | sentence = u'{}'.format(sentence) if sampled_caption!=[] else u'.' 98 | batch_words.append(sentence) 99 | 100 | return batch_words 101 | 102 | -------------------------------------------------------------------------------- /validate.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from utils import * 3 | import cv2 4 | import pdb 5 | from metrics import calc_psnr, calc_rmse, calc_ergas, calc_sam 6 | import args_parser 7 | args = args_parser.args_parser() 8 | 9 | def validate(test_list, arch, model, epoch, n_epochs): 10 | test_ref, test_lr, test_hr = test_list 11 | model.eval() 12 | 13 | psnr = 0 14 | with torch.no_grad(): 15 | # Set mini-batch dataset 16 | ref = to_var(test_ref).detach() 17 | lr = to_var(test_lr).detach() 18 | hr = to_var(test_hr).detach() 19 | if arch == 'SSRNet': 20 | out, _, _, _, _, _ = model(lr, hr) 21 | elif arch == 'SSRSpat': 22 | _, out, _, _, _, _ = model(lr, hr) 23 | elif arch == 'SSRSpec': 24 | _, _, out, _, _, _ = model(lr, hr) 25 | elif arch == 'SwinCGAN': 26 | out, _, _, _, _, _ = model(lr, hr, args.scale_ratio) 27 | else: 28 | out, _, _, _, _, _ = model(lr, hr) 29 | 30 | ref = ref.detach().cpu().numpy() 31 | out = out.detach().cpu().numpy() 32 | 33 | rmse = calc_rmse(ref, out) 34 | psnr = calc_psnr(ref, out) 35 | ergas = calc_ergas(ref, out) 36 | sam = calc_sam(ref, out) 37 | 38 | with open('ConSSFCNN.txt', 'a') as f: 39 | f.write(str(epoch) + ',' + str(rmse) + ',' + str(psnr) + ',' + str(ergas) + ',' + str(sam) + ',' + '\n') 40 | 41 | return psnr --------------------------------------------------------------------------------
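Note on models/MCIFNet.py: the `query` method above implements the implicit-neural-representation sampling step. For every high-resolution pixel centre it gathers the four nearest low-resolution HSI feature vectors with nearest-neighbour `grid_sample`, concatenates each with the high-resolution MSI guide feature and the relative coordinate (rescaled by the LR grid size), and feeds the result to the MLP `imnet`, whose output holds the predicted band values plus one extra logit; the four logits are softmax-normalised and used to blend the four neighbour predictions. The sketch below is a minimal, self-contained illustration of that idea only; `TinyMLP`, `inr_query`, and this `make_coord` are hypothetical stand-ins written for clarity, not the repository's exact classes.

import torch
import torch.nn as nn
import torch.nn.functional as F

def make_coord(shape):
    # Normalised pixel-centre coordinates in [-1, 1]; returns [H*W, 2] in (row, col) order.
    ranges = [(torch.arange(n, dtype=torch.float32) + 0.5) / n * 2 - 1 for n in shape]
    grid = torch.stack(torch.meshgrid(*ranges, indexing='ij'), dim=-1)
    return grid.reshape(-1, 2)

class TinyMLP(nn.Module):
    # Hypothetical stand-in for the repo's MLP(imnet_in_dim, out_dim=in_chans_HSI + 1, hidden_list=...).
    def __init__(self, in_dim, out_dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU(), nn.Linear(hidden, out_dim))

    def forward(self, x):
        return self.net(x)

def inr_query(feat_lr, guide_hr, imnet, n_bands):
    # feat_lr: [B, C, h, w] low-res HSI features; guide_hr: [B, C, H, W] high-res MSI guide features.
    b, c, h, w = feat_lr.shape
    _, _, H, W = guide_hr.shape
    coord = make_coord((H, W)).to(feat_lr.device).unsqueeze(0).expand(b, H * W, 2)
    # Pixel-centre coordinates of the LR grid, laid out as [B, 2, h, w] for grid_sample.
    feat_coord = make_coord((h, w)).reshape(h, w, 2).permute(2, 0, 1).unsqueeze(0).expand(b, 2, h, w).to(feat_lr.device)
    # HR guide feature at every query location (grid_sample expects (x, y) order, hence flip(-1)).
    q_guide = F.grid_sample(guide_hr, coord.flip(-1).unsqueeze(1), mode='nearest', align_corners=False)[:, :, 0].permute(0, 2, 1)
    preds = []
    for vx in (-1, 1):
        for vy in (-1, 1):
            c_ = coord.clone()
            c_[..., 0] += vx / h  # step one LR cell vertically
            c_[..., 1] += vy / w  # step one LR cell horizontally
            q_feat = F.grid_sample(feat_lr, c_.flip(-1).unsqueeze(1), mode='nearest', align_corners=False)[:, :, 0].permute(0, 2, 1)
            q_coord = F.grid_sample(feat_coord, c_.flip(-1).unsqueeze(1), mode='nearest', align_corners=False)[:, :, 0].permute(0, 2, 1)
            rel = coord - q_coord  # offset from the sampled LR centre ...
            rel[..., 0] *= h       # ... rescaled by the LR grid size
            rel[..., 1] *= w
            inp = torch.cat([q_feat, q_guide, rel], dim=-1)                  # [B, N, 2C + 2]
            preds.append(imnet(inp.reshape(b * H * W, -1)).view(b, H * W, -1))
    preds = torch.stack(preds, dim=-1)                                       # [B, N, n_bands + 1, 4]
    weight = F.softmax(preds[:, :, -1, :], dim=-1)                           # learned blend over the 4 neighbours
    out = (preds[:, :, :-1, :] * weight.unsqueeze(-2)).sum(-1)               # [B, N, n_bands]
    return out.permute(0, 2, 1).view(b, n_bands, H, W)

if __name__ == '__main__':
    imnet = TinyMLP(in_dim=2 * 16 + 2, out_dim=8 + 1)
    feat_lr, guide_hr = torch.randn(1, 16, 8, 8), torch.randn(1, 16, 32, 32)
    print(inr_query(feat_lr, guide_hr, imnet, n_bands=8).shape)  # expected: torch.Size([1, 8, 32, 32])

Predicting a blending logit alongside the band values lets the network learn the neighbour weights rather than fixing them by distance, which is what allows the same trained `imnet` to be queried at arbitrary target resolutions.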