├── datasets └── .gitkeep ├── results └── .gitkeep ├── checkpoints └── .gitkeep ├── csi ├── utils │ ├── __init__.py │ ├── file_io.py │ ├── utils.py │ └── schedulers.py ├── engine │ ├── __init__.py │ └── defaults.py ├── config │ ├── __init__.py │ ├── config.py │ └── defaults.py ├── architectures │ ├── __init__.py │ └── dernn_lnlt.py ├── losses │ ├── __init__.py │ └── l1_loss.py ├── metrics │ ├── __init__.py │ └── metrics.py └── data │ ├── __init__.py │ └── data.py ├── visualization ├── real_results │ └── results │ │ └── .gitkeep ├── spectral_density.png ├── show_simulation.asv ├── show_simulation.m ├── show_real.m ├── show_real.asv ├── createfigure.m ├── show_line.m └── show_line.asv ├── .gitignore ├── Quality_Metrics ├── results │ └── README.md ├── csnr.m ├── Cal_quality_assessment.m ├── CC.m ├── SpectAngMapper.m ├── quality_assessment.m ├── img_qi.m └── cal_ssim.m ├── figures ├── teaser.png └── architecture.png ├── scripts ├── test_dernn_lnlt_5stg_real.sh ├── test_dernn_lnlt_5stg_simu.sh ├── test_dernn_lnlt_7stg_simu.sh ├── test_dernn_lnlt_9stg_simu.sh ├── test_dernn_lnlt_9stg_star_simu.sh ├── train_dernn_lnlt_5stg_real.sh ├── train_dernn_lnlt_5stg_simu.sh ├── train_dernn_lnlt_7stg_simu.sh ├── train_dernn_lnlt_9stg_simu.sh └── train_dernn_lnlt_9stg_star_simu.sh ├── configs ├── dernn_lnlt_7stg_simu.yaml ├── dernn_lnlt_9stg_simu.yaml ├── dernn_lnlt_9stg_star_simu.yaml ├── dernn_lnlt_5stg_real.yaml └── dernn_lnlt_5stg_simu.yaml ├── tools ├── test_real.py ├── test_simu.py └── train.py └── README.md /datasets/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /results/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /checkpoints/.gitkeep: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /csi/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /visualization/real_results/results/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /csi/engine/__init__.py: -------------------------------------------------------------------------------- 1 | from .defaults import * 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | datasets/* 3 | .DS_Store 4 | exp/* -------------------------------------------------------------------------------- /csi/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import CfgNode, get_cfg -------------------------------------------------------------------------------- /csi/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | from .dernn_lnlt import DERNN_LNLT -------------------------------------------------------------------------------- /csi/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from csi.losses.l1_loss import CharbonnierLoss, TVLoss -------------------------------------------------------------------------------- /Quality_Metrics/results/README.md: -------------------------------------------------------------------------------- 1 | Please put the reconstructed HSI in this folder. 
-------------------------------------------------------------------------------- /csi/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from csi.metrics.metrics import torch_psnr, torch_ssim, sam -------------------------------------------------------------------------------- /figures/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShawnDong98/DERNN-LNLT/HEAD/figures/teaser.png -------------------------------------------------------------------------------- /figures/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShawnDong98/DERNN-LNLT/HEAD/figures/architecture.png -------------------------------------------------------------------------------- /scripts/test_dernn_lnlt_5stg_real.sh: -------------------------------------------------------------------------------- 1 | python tools/test_real.py \ 2 | --config-file configs/dernn_lnlt_5stg_real.yaml -------------------------------------------------------------------------------- /scripts/test_dernn_lnlt_5stg_simu.sh: -------------------------------------------------------------------------------- 1 | python tools/test_simu.py \ 2 | --config-file configs/dernn_lnlt_5stg_simu.yaml -------------------------------------------------------------------------------- /scripts/test_dernn_lnlt_7stg_simu.sh: -------------------------------------------------------------------------------- 1 | python tools/test_simu.py \ 2 | --config-file configs/dernn_lnlt_7stg_simu.yaml -------------------------------------------------------------------------------- /scripts/test_dernn_lnlt_9stg_simu.sh: -------------------------------------------------------------------------------- 1 | python tools/test_simu.py \ 2 | --config-file configs/dernn_lnlt_9stg_simu.yaml 
-------------------------------------------------------------------------------- /visualization/spectral_density.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShawnDong98/DERNN-LNLT/HEAD/visualization/spectral_density.png -------------------------------------------------------------------------------- /scripts/test_dernn_lnlt_9stg_star_simu.sh: -------------------------------------------------------------------------------- 1 | python tools/test_simu.py \ 2 | --config-file configs/dernn_lnlt_9stg_star_simu.yaml -------------------------------------------------------------------------------- /scripts/train_dernn_lnlt_5stg_real.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=5 2 | python tools/train.py \ 3 | --config-file configs/dernn_lnlt_5stg_real.yaml -------------------------------------------------------------------------------- /scripts/train_dernn_lnlt_5stg_simu.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=5 2 | python tools/train.py \ 3 | --config-file configs/dernn_lnlt_5stg_simu.yaml -------------------------------------------------------------------------------- /scripts/train_dernn_lnlt_7stg_simu.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=5 2 | python tools/train.py \ 3 | --config-file configs/dernn_lnlt_7stg_simu.yaml -------------------------------------------------------------------------------- /scripts/train_dernn_lnlt_9stg_simu.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=5 2 | python tools/train.py \ 3 | --config-file configs/dernn_lnlt_9stg_simu.yaml -------------------------------------------------------------------------------- /scripts/train_dernn_lnlt_9stg_star_simu.sh: 
-------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=5 2 | python tools/train.py \ 3 | --config-file configs/dernn_lnlt_9stg_star_simu.yaml -------------------------------------------------------------------------------- /csi/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import CSITrainDataset, LoadTraining, LoadVal, shuffle_crop, LoadTSATestMeas 2 | from .data import generate_mask_3d, generate_mask_3d_shift 3 | from .data import shift, shift_back, gen_meas_torch 4 | from .data import shift_batch, shift_back_batch, gen_meas_torch_batch -------------------------------------------------------------------------------- /Quality_Metrics/csnr.m: -------------------------------------------------------------------------------- 1 | function s=csnr(A,B,row,col) 2 | % CSNR PSNR (peak value 255) of B against A after cropping row rows and col columns from each border; for multi-channel input, the mean of the per-channel PSNRs. 3 | [n,m,ch]=size(A); 4 | summa = 0; 5 | if ch==1 6 | e=A-B; 7 | e=e(row+1:n-row,col+1:m-col); 8 | me=mean(mean(e.^2)); 9 | s=10*log10(255^2/me); 10 | else 11 | d=A-B; % the difference does not depend on the channel, so compute it once 12 | for i=1:ch 13 | e=d(row+1:n-row,col+1:m-col,i); 14 | mse = mean(mean(e.^2)); 15 | s = 10*log10(255^2/mse); 16 | summa = summa + s; 17 | end 18 | s = summa/ch; 19 | end 20 | 21 | return; 22 | 23 | -------------------------------------------------------------------------------- /csi/utils/file_io.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler, PathHandler 3 | from iopath.common.file_io import PathManager as PathManagerBase 4 | 5 | __all__ = ["PathManager", "PathHandler"] 6 | 7 | 8 | PathManager = PathManagerBase() 9 | """ 10 | This is a project-specific PathManager (adapted from detectron2). 11 | We try to stay away from global PathManager in fvcore as it 12 | introduces potential conflicts among other libraries.
13 | """ -------------------------------------------------------------------------------- /csi/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def checkpoint(model, ema, optimizer, scheduler, epoch, model_path, logger): 4 | save_dict = {} 5 | save_dict['model'] = model.state_dict() 6 | save_dict['ema'] = ema.state_dict() 7 | save_dict['optimizer'] = optimizer.state_dict() 8 | save_dict['scheduler'] = scheduler.state_dict() 9 | save_dict['epoch'] = epoch 10 | model_out_path = model_path + "/model_epoch_{}.pth".format(epoch) 11 | torch.save(save_dict, model_out_path) 12 | logger.info("Checkpoint saved to {}".format(model_out_path)) -------------------------------------------------------------------------------- /visualization/show_simulation.asv: -------------------------------------------------------------------------------- 1 | %% plot color pics 2 | clear; clc; 3 | load('./simulation_results/results/Test_result_9stage_share.mat'); 4 | save_file = './simulation_results/dluf_mixs2/dluf_mixs2_9stage_share/rgb_results/'; 5 | % load('./scene1_meas_H_truth.mat'); 6 | % save_file = './H_file/'; 7 | 8 | mkdir(save_file); 9 | close all; 10 | frame = 1; 11 | for i = 1:10 12 | recon = squeeze(H(i,:,:,:)); 13 | intensity = 5; 14 | for channel=1:28 15 | img_nb = [channel]; % channel number 16 | row_num = 1; col_num = 1; 17 | lam28 = [453.5 457.5 462.0 466.0 471.5 476.5 481.5 487.0 492.5 498.0 504.0 510.0... 18 | 516.0 522.5 529.5 536.5 544.0 551.5 558.5 567.5 575.5 584.5 594.5 604.0... 
19 | 614.5 625.0 636.5 648.0]; 20 | recon(recon>1)=1; 21 | name = [save_file 'frame' num2str(frame) 'channel' num2str(channel)]; 22 | dispCubeAshwin(recon(:,:,img_nb),intensity,lam28(img_nb), [] ,col_num,row_num,0,1,name); 23 | end 24 | frame = frame+1; 25 | end 26 | close all; 27 | 28 | 29 | -------------------------------------------------------------------------------- /visualization/show_simulation.m: -------------------------------------------------------------------------------- 1 | %% plot color pics 2 | clear; clc; 3 | load('./simulation_results/results/dernn_lnlt_9stg_star_simu.mat'); 4 | save_file = './simulation_results/results/rgb_results/'; 5 | % save_file = './simulation_results/truth/'; 6 | % load('./scene1_meas_H_truth.mat'); 7 | % save_file = './H_file/'; 8 | 9 | mkdir(save_file); 10 | close all; 11 | for i = 1:10 12 | recon = squeeze(pred(i,:,:,:)); 13 | intensity = 5; 14 | for channel=1:28 15 | img_nb = [channel]; % channel number 16 | row_num = 1; col_num = 1; 17 | lam28 = [453.5 457.5 462.0 466.0 471.5 476.5 481.5 487.0 492.5 498.0 504.0 510.0... 18 | 516.0 522.5 529.5 536.5 544.0 551.5 558.5 567.5 575.5 584.5 594.5 604.0...
19 | 614.5 625.0 636.5 648.0]; 20 | recon(recon>1)=1; 21 | name = [save_file 'frame' num2str(i) 'channel' num2str(channel)]; 22 | dispCubeAshwin(recon(:,:,img_nb),intensity,lam28(img_nb), [] ,col_num,row_num,0,1,name); 23 | end 24 | end 25 | close all; 26 | 27 | 28 | -------------------------------------------------------------------------------- /Quality_Metrics/Cal_quality_assessment.m: -------------------------------------------------------------------------------- 1 | clear;clc; 2 | res_path = '../results/dernn_lnlt_9stg_star_simu.mat'; 3 | load(res_path); 4 | metrics = zeros(6,10); % per-scene table; rows: psnr, rmse, ergas, sam, uiqi, ssim 5 | psnr_total=0.0; 6 | ssim_total=0.0; 7 | psnr_list = zeros(10,1); 8 | ssim_list = zeros(10,1); 9 | for i=1:10 10 | Z = squeeze(pred(i,:,:,:)); 11 | Z = double(Z); 12 | S = squeeze(truth(i,:,:,:)); 13 | S = double(S); 14 | 15 | Z(Z>1.0) = 1.0; 16 | Z(Z<0.0) = 0.0; 17 | 18 | [psnr, rmse, ergas, sam, uiqi, ssim] = quality_assessment(double(im2uint8(S)), double(im2uint8(Z)), 0, 1); 19 | % store into metrics, not pred: pred(k,i) on the 4-D pred would overwrite pred(k,i,1,1) of scenes not yet evaluated 20 | metrics(1,i) = psnr; 21 | metrics(2,i) = rmse; 22 | metrics(3,i) = ergas; 23 | metrics(4,i) = sam; 24 | metrics(5,i) = uiqi; 25 | metrics(6,i) = ssim; 26 | psnr_list(i,1) = psnr; 27 | ssim_list(i,1) = ssim; 28 | psnr 29 | ssim 30 | 31 | psnr_total = psnr_total+psnr; 32 | ssim_total = ssim_total+ssim; 33 | end 34 | psnr = mean(psnr_list); 35 | ssim = mean(ssim_list); 36 | fprintf('The PSNR=%f\n',psnr); 37 | fprintf('The SSIM=%f\n',ssim); 38 | 39 | -------------------------------------------------------------------------------- /csi/losses/l1_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class CharbonnierLoss(nn.Module): 7 | """Charbonnier Loss (L1)""" 8 | 9 | def __init__(self, eps=1e-3): 10 | super(CharbonnierLoss, self).__init__() 11 | self.eps = eps 12 | 13 | def forward(self, x, y): 14 | diff = x - y 15 | # loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 16 | loss =
torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps))) 17 | return loss 18 | 19 | 20 | class TVLoss(nn.Module): 21 | def __init__(self, weight: float=1) -> None: 22 | """Total Variation Loss 23 | Args: 24 | weight (float): weight of TV loss 25 | """ 26 | super().__init__() 27 | self.weight = weight 28 | 29 | def forward(self, x): 30 | batch_size, c, h, w = x.size() 31 | tv_h = torch.abs(x[:,:,1:,:] - x[:,:,:-1,:]).sum() 32 | tv_w = torch.abs(x[:,:,:,1:] - x[:,:,:,:-1]).sum() 33 | return self.weight * (tv_h + tv_w) / (batch_size * c * h * w) -------------------------------------------------------------------------------- /Quality_Metrics/CC.m: -------------------------------------------------------------------------------- 1 | function out = CC(ref,tar,mask) 2 | %-------------------------------------------------------------------------- 3 | % Cross Correlation 4 | % 5 | % USAGE 6 | % out = CC(ref,tar,mask) 7 | % 8 | % INPUT 9 | % ref : reference HS data (rows,cols,bands) 10 | % tar : target HS data (rows,cols,bands) 11 | % mask: binary mask (rows,cols) (optional) 12 | % 13 | % OUTPUT 14 | % out : mean over all bands of the per-band correlation coefficients (scalar) 15 | % 16 | %-------------------------------------------------------------------------- 17 | 18 | if nargin==2 19 | [~,~,bands] = size(tar); 20 | 21 | out = zeros(1,bands); 22 | for i = 1:bands 23 | tar_tmp = tar(:,:,i); 24 | ref_tmp = ref(:,:,i); 25 | cc = corrcoef(tar_tmp(:),ref_tmp(:)); 26 | out(1,i) = cc(1,2); 27 | end 28 | 29 | else 30 | [~,~,bands] = size(tar); 31 | 32 | out = zeros(1,bands); 33 | mask = find(mask~=0); 34 | for i = 1:bands 35 | tar_tmp = tar(:,:,i); 36 | ref_tmp = ref(:,:,i); 37 | cc = corrcoef(tar_tmp(mask),ref_tmp(mask)); 38 | out(1,i) = cc(1,2); 39 | end 40 | end 41 | out=mean(out); 42 | -------------------------------------------------------------------------------- /visualization/show_real.m: -------------------------------------------------------------------------------- 1 | %% plot color pics 2 |
clear; clc; 3 | close all 4 | 5 | load('../results/dernn_lnlt_5stg_real.mat'); 6 | x_result_1 = flip(flip(squeeze(pred(1, :, :, :)),1),2); 7 | x_result_2 = flip(flip(squeeze(pred(2, :, :, :)),1),2); 8 | x_result_3 = flip(flip(squeeze(pred(3, :, :, :)),1),2); 9 | x_result_4 = flip(flip(squeeze(pred(4, :, :, :)),1),2); 10 | x_result_5 = flip(flip(squeeze(pred(5, :, :, :)),1),2); 11 | 12 | 13 | save_file = './real_results/results/rgb_results/'; 14 | mkdir(save_file); 15 | 16 | frame = 1; 17 | for recon = {x_result_1,x_result_2,x_result_3,x_result_4,x_result_5} 18 | recon = cell2mat(recon); 19 | intensity = 5; 20 | for channel=1:28 21 | img_nb = [channel]; % channel number 22 | row_num = 1; col_num = 1; 23 | lam28 = [453.5 457.5 462.0 466.0 471.5 476.5 481.5 487.0 492.5 498.0 504.0 510.0 516.0 522.5 529.5 536.5 544.0 551.5 558.5 567.5 575.5 584.5 594.5 604.0 614.5 625.0 636.5 648.0]; 24 | recon(recon>1)=1; 25 | name = [save_file 'frame' num2str(frame) 'channel' num2str(channel)]; 26 | dispCubeAshwin(recon(:,:,img_nb),intensity,lam28(img_nb), [] ,col_num,row_num,0,1,name); 27 | end 28 | frame = frame+1; 29 | end -------------------------------------------------------------------------------- /Quality_Metrics/SpectAngMapper.m: -------------------------------------------------------------------------------- 1 | function sam = SpectAngMapper(imagery1, imagery2) 2 | 3 | %========================================================================== 4 | % Evaluates the mean Spectral Angle Mapper (SAM)[1] for two MSIs, in radians, averaged over all pixels. 5 | % 6 | % Syntax: 7 | % sam = SpectAngMapper(imagery1, imagery2) 8 | % 9 | % Input: 10 | % imagery1 - the reference MSI data array 11 | % imagery2 - the target MSI data array 12 | % NOTE: MSI data array is a M*N*K array for imagery with M*N spatial 13 | % pixels, K bands and DYNAMIC RANGE [0, 255]. imagery1 and imagery2 14 | % must have the same size; no truncation to a common size is done 15 | % here. 16 | % 17 | % [1] R.
YUHAS, J. BOARDMAN, and A. GOETZ, "Determination of semi-arid 18 | % landscape endmembers and seasonal trends using convex geometry 19 | % spectral unmixing techniques", JPL, Summaries of the 4 th Annual JPL 20 | % Airborne Geoscience Workshop. 1993. 21 | % 22 | % See also StructureSIM, FeatureSIM and ErrRelGlobAdimSyn 23 | % 24 | % by Yi Peng 25 | %========================================================================== 26 | % real() presumably guards against tmp slightly exceeding 1, where acos returns a complex value 27 | tmp = (sum(imagery1.*imagery2, 3) + eps) ... 28 | ./ (sqrt(sum(imagery1.^2, 3)) + eps) ./ (sqrt(sum(imagery2.^2, 3)) + eps); 29 | sam = mean2(real(acos(tmp))); -------------------------------------------------------------------------------- /configs/dernn_lnlt_7stg_simu.yaml: -------------------------------------------------------------------------------- 1 | DATASETS: 2 | STEP: 2 3 | WAVE_LENS: 28 4 | MASK_TYPE: "mask_3d_shift" 5 | 6 | TRAIN: 7 | ITERATION: 8000 8 | CROP_SIZE: [256, 256] 9 | WITH_NOISE: False 10 | PATHS: 11 | - "./datasets/CSI/cave_1024_28" 12 | MASK_PATH: "./datasets/CSI/TSA_simu_data/mask_3d_shift.mat" 13 | RANDOM_MASK: False 14 | VAL: 15 | PATH: "./datasets/CSI/TSA_simu_data/Truth/" 16 | MASK_PATH: "./datasets/CSI/TSA_simu_data/mask_3d_shift.mat" 17 | TEST: 18 | PATH: "./datasets/CSI/TSA_real_data/Measurements/" 19 | MASK_PATH: "./datasets/CSI/TSA_real_data/mask_3d_shift.mat" 20 | 21 | 22 | DATALOADER: 23 | BATCH_SIZE: 8 24 | NUM_WORKERS: 8 25 | 26 | MODEL: 27 | DENOISER: 28 | TYPE: "DERNN_LNLT" 29 | DERNN_LNLT: 30 | IN_DIM: 29 31 | DIM: 28 32 | OUT_DIM: 28 33 | WINDOW_SIZE: [8, 8] 34 | WINDOW_NUM: [8, 8] 35 | LOCAL: True 36 | NON_LOCAL: True 37 | NUM_BLOCKS: [1, 1, 1, 1, 1] 38 | LAYERNORM_TYPE: "WithBias" 39 | FFN_NAME: "Gated_Dconv_FeedForward" 40 | STAGE: 7 41 | SHARE_PARAMS: True 42 | 43 | 44 | OPTIMIZER: 45 | LR: 1e-3 46 | 47 | 48 | DEBUG: True 49 | OUTPUT_DIR: "./exp/DERNN_LNLT_7stg_simu/" 50 | PRETRAINED_CKPT_PATH: "./checkpoints/dernn_lnlt_7stg_simu.pth"
-------------------------------------------------------------------------------- /configs/dernn_lnlt_9stg_simu.yaml: -------------------------------------------------------------------------------- 1 | DATASETS: 2 | STEP: 2 3 | WAVE_LENS: 28 4 | MASK_TYPE: "mask_3d_shift" 5 | 6 | TRAIN: 7 | ITERATION: 3000 8 | CROP_SIZE: [256, 256] 9 | WITH_NOISE: False 10 | PATHS: 11 | - "./datasets/CSI/cave_1024_28" 12 | MASK_PATH: "./datasets/CSI/TSA_simu_data/mask_3d_shift.mat" 13 | RANDOM_MASK: False 14 | VAL: 15 | PATH: "./datasets/CSI/TSA_simu_data/Truth/" 16 | MASK_PATH: "./datasets/CSI/TSA_simu_data/mask_3d_shift.mat" 17 | TEST: 18 | PATH: "./datasets/CSI/TSA_real_data/Measurements/" 19 | MASK_PATH: "./datasets/CSI/TSA_real_data/mask_3d_shift.mat" 20 | 21 | 22 | DATALOADER: 23 | BATCH_SIZE: 3 24 | NUM_WORKERS: 8 25 | 26 | MODEL: 27 | DENOISER: 28 | TYPE: "DERNN_LNLT" 29 | DERNN_LNLT: 30 | IN_DIM: 29 31 | DIM: 28 32 | OUT_DIM: 28 33 | WINDOW_SIZE: [8, 8] 34 | WINDOW_NUM: [8, 8] 35 | LOCAL: True 36 | NON_LOCAL: True 37 | NUM_BLOCKS: [1, 1, 1, 1, 1] 38 | LAYERNORM_TYPE: "WithBias" 39 | FFN_NAME: "Gated_Dconv_FeedForward" 40 | STAGE: 9 41 | SHARE_PARAMS: True 42 | 43 | 44 | OPTIMIZER: 45 | LR: 1e-3 46 | 47 | 48 | DEBUG: True 49 | OUTPUT_DIR: "./exp/DERNN_LNLT_9stg_simu/" 50 | RESUME_CKPT_PATH: "" 51 | PRETRAINED_CKPT_PATH: "./checkpoints/dernn_lnlt_9stg_simu.pth" -------------------------------------------------------------------------------- /configs/dernn_lnlt_9stg_star_simu.yaml: -------------------------------------------------------------------------------- 1 | DATASETS: 2 | STEP: 2 3 | WAVE_LENS: 28 4 | MASK_TYPE: "mask_3d_shift" 5 | 6 | TRAIN: 7 | ITERATION: 3000 8 | CROP_SIZE: [256, 256] 9 | WITH_NOISE: False 10 | PATHS: 11 | - "./datasets/CSI/cave_1024_28" 12 | MASK_PATH: "./datasets/CSI/TSA_simu_data/mask_3d_shift.mat" 13 | RANDOM_MASK: False 14 | VAL: 15 | PATH: "./datasets/CSI/TSA_simu_data/Truth/" 16 | MASK_PATH: "./datasets/CSI/TSA_simu_data/mask_3d_shift.mat" 17 
| TEST: 18 | PATH: "./datasets/CSI/TSA_real_data/Measurements/" 19 | MASK_PATH: "./datasets/CSI/TSA_real_data/mask_3d_shift.mat" 20 | 21 | 22 | DATALOADER: 23 | BATCH_SIZE: 3 24 | NUM_WORKERS: 8 25 | 26 | MODEL: 27 | DENOISER: 28 | TYPE: "DERNN_LNLT" 29 | DERNN_LNLT: 30 | IN_DIM: 29 31 | DIM: 28 32 | OUT_DIM: 28 33 | WINDOW_SIZE: [8, 8] 34 | WINDOW_NUM: [8, 8] 35 | LOCAL: True 36 | NON_LOCAL: True 37 | NUM_BLOCKS: [2, 2, 2, 2, 2] 38 | LAYERNORM_TYPE: "WithBias" 39 | FFN_NAME: "Gated_Dconv_FeedForward" 40 | STAGE: 9 41 | SHARE_PARAMS: True 42 | 43 | 44 | OPTIMIZER: 45 | LR: 1e-3 46 | 47 | 48 | DEBUG: True 49 | OUTPUT_DIR: "./exp/DERNN_LNLT_9stg_star_simu/" 50 | RESUME_CKPT_PATH: "" 51 | PRETRAINED_CKPT_PATH: "./checkpoints/dernn_lnlt_9stg_star_simu.pth" -------------------------------------------------------------------------------- /configs/dernn_lnlt_5stg_real.yaml: -------------------------------------------------------------------------------- 1 | DATASETS: 2 | STEP: 2 3 | WAVE_LENS: 28 4 | MASK_TYPE: "mask_3d_shift" 5 | 6 | TRAIN: 7 | ITERATION: 1000 8 | CROP_SIZE: [660, 660] 9 | WITH_NOISE: True 10 | PATHS: 11 | - "./datasets/CSI/cave_1024_28" 12 | - "./datasets/CSI/KAIST_CVPR2021" 13 | MASK_PATH: "./datasets/CSI/TSA_real_data/mask_3d_shift.mat" 14 | RANDOM_MASK: False 15 | VAL: 16 | PATH: "./datasets/CSI/TSA_real_data/Truth/" 17 | MASK_PATH: "./datasets/CSI/TSA_real_data/mask_3d_shift.mat" 18 | TEST: 19 | PATH: "./datasets/CSI/TSA_real_data/Measurements/" 20 | MASK_PATH: "./datasets/CSI/TSA_real_data/mask_3d_shift.mat" 21 | 22 | 23 | DATALOADER: 24 | BATCH_SIZE: 1 25 | NUM_WORKERS: 8 26 | 27 | MODEL: 28 | DENOISER: 29 | TYPE: "DERNN_LNLT" 30 | DERNN_LNLT: 31 | IN_DIM: 29 32 | DIM: 28 33 | OUT_DIM: 28 34 | WINDOW_SIZE: [14, 14] 35 | WINDOW_NUM: [14, 14] 36 | LOCAL: True 37 | NON_LOCAL: True 38 | NUM_BLOCKS: [1, 1, 1, 1, 1] 39 | LAYERNORM_TYPE: "WithBias" 40 | FFN_NAME: "Gated_Dconv_FeedForward" 41 | STAGE: 5 42 | SHARE_PARAMS: True 43 | 44 | 45 | OPTIMIZER: 46 
| LR: 1e-3 47 | 48 | 49 | DEBUG: True 50 | OUTPUT_DIR: "./exp/DERNN_LNLT_5stg_real/" 51 | PRETRAINED_CKPT_PATH: "./checkpoints/dernn_lnlt_5stg_real.pth" -------------------------------------------------------------------------------- /configs/dernn_lnlt_5stg_simu.yaml: -------------------------------------------------------------------------------- 1 | DATASETS: 2 | STEP: 2 3 | WAVE_LENS: 28 4 | MASK_TYPE: "mask_3d_shift" 5 | 6 | TRAIN: 7 | ITERATION: 10000 8 | CROP_SIZE: [256, 256] 9 | WITH_NOISE: False 10 | PATHS: 11 | - "./datasets/CSI/cave_1024_28" 12 | MASK_PATH: "./datasets/CSI/TSA_simu_data/mask_3d_shift.mat" 13 | RANDOM_MASK: False 14 | VAL: 15 | PATH: "./datasets/CSI/TSA_simu_data/Truth/" 16 | MASK_PATH: "./datasets/CSI/TSA_simu_data/mask_3d_shift.mat" 17 | TEST: 18 | PATH: "./datasets/CSI/TSA_real_data/Measurements/" 19 | MASK_PATH: "./datasets/CSI/TSA_real_data/mask_3d_shift.mat" 20 | 21 | 22 | DATALOADER: 23 | BATCH_SIZE: 10 24 | NUM_WORKERS: 8 25 | 26 | MODEL: 27 | DENOISER: 28 | TYPE: "DERNN_LNLT" 29 | DERNN_LNLT: 30 | IN_DIM: 29 31 | DIM: 28 32 | OUT_DIM: 28 33 | WINDOW_SIZE: [8, 8] 34 | WINDOW_NUM: [8, 8] 35 | LOCAL: True 36 | NON_LOCAL: True 37 | WITH_DL: True 38 | WITH_MU: True 39 | WITH_NOISE_LEVEL: True 40 | NUM_BLOCKS: [1, 1, 1, 1, 1] 41 | LAYERNORM_TYPE: "WithBias" 42 | FFN_NAME: "Gated_Dconv_FeedForward" 43 | STAGE: 5 44 | SHARE_PARAMS: True 45 | 46 | 47 | OPTIMIZER: 48 | LR: 1e-3 49 | 50 | 51 | DEBUG: False 52 | OUTPUT_DIR: "./exp/DERNN_LNLT_5stg_simu/" 53 | RESUME_CKPT_PATH: "" 54 | PRETRAINED_CKPT_PATH: "./checkpoints/dernn_lnlt_5stg_simu.pth" -------------------------------------------------------------------------------- /Quality_Metrics/quality_assessment.m: -------------------------------------------------------------------------------- 1 | function [psnr,rmse, ergas, sam, uiqi,ssim,DD,CCS] = quality_assessment(ground_truth, estimated, ignore_edges, ratio_ergas) 2 | % QUALITY_ASSESSMENT Full-reference metrics of estimated vs. ground_truth; ignore_edges pixels are cropped from every spatial border before all metrics. 3 | % Ignore borders 4 | y =
ground_truth(ignore_edges+1:end-ignore_edges, ignore_edges+1:end-ignore_edges, :); 5 | x = estimated(ignore_edges+1:end-ignore_edges, ignore_edges+1:end-ignore_edges, :); 6 | 7 | % Size, bands, samples 8 | sz_x = size(x); 9 | n_bands = sz_x(3); 10 | n_samples = sz_x(1)*sz_x(2); 11 | 12 | % RMSE 13 | aux = sum(sum((x - y).^2, 1), 2)/n_samples; 14 | rmse_per_band = sqrt(aux); 15 | rmse = sqrt(sum(aux, 3)/n_bands); 16 | 17 | % ERGAS 18 | mean_y = sum(sum(y, 1), 2)/n_samples; 19 | ergas = 100*ratio_ergas*sqrt(sum((rmse_per_band ./ mean_y).^2)/n_bands); 20 | 21 | % SAM -- like RMSE/ERGAS, this and every metric below use the border-cropped y (truth) and x (estimate) 22 | sam= SpectAngMapper( y, x ); 23 | sam=sam*180/pi; 24 | % num = sum(x .* y, 3); 25 | % den = sqrt(sum(x.^2, 3) .* sum(y.^2, 3)); 26 | % sam = sum(sum(acosd(num ./ den)))/(n_samples); 27 | 28 | % UIQI - calls the method described in "A Universal Image Quality Index" 29 | % by Zhou Wang and Alan C. Bovik 30 | q_band = zeros(1, n_bands); 31 | for idx1=1:n_bands 32 | q_band(idx1)=img_qi(y(:,:,idx1), x(:,:,idx1), 32); 33 | end 34 | uiqi = mean(q_band); 35 | ssim=cal_ssim(y, x,0,0); 36 | DD=norm(y(:)-x(:),1)/numel(y); 37 | CCS = CC(y,x); 38 | CCS=mean(CCS); 39 | psnr=csnr(y, x,0,0); -------------------------------------------------------------------------------- /visualization/show_real.asv: -------------------------------------------------------------------------------- 1 | %% plot color pics 2 | clear; clc; 3 | close all 4 | 5 | load('../results/real_gapnet.mat'); 6 | x_result_1 = flip(flip(squeeze(pred(1, :, :, :)),1),2); 7 | x_result_2 = flip(flip(squeeze(pred(2, :, :, :)),1),2); 8 | x_result_3 = flip(flip(squeeze(pred(3, :, :, :)),1),2); 9 | x_result_4 = flip(flip(squeeze(pred(4, :, :, :)),1),2); 10 | x_result_5 = flip(flip(squeeze(pred(5, :, :, :)),1),2); 11 | 12 | % x_result_1 = flip(flip(squeeze(test(1, :, :, :)),1),2); 13 | % x_result_2 = flip(flip(squeeze(test(2, :, :,
:)),1),2); 14 | % x_result_3 = flip(flip(squeeze(test(3, :, :, :)),1),2); 15 | % x_result_4 = flip(flip(squeeze(test(4, :, :, :)),1),2); 16 | % x_result_5 = flip(flip(squeeze(test(5, :, :, :)),1),2); 17 | 18 | % x_result_1 = x_TSA_1; 19 | % x_result_2 = x_TSA_2; 20 | % x_result_3 = x_TSA_3; 21 | % x_result_4 = x_TSA_4; 22 | % x_result_5 = x_TSA_5; 23 | 24 | % x_result_1 = x_TSA_1; 25 | % x_result_2 = x_TSA_2; 26 | % x_result_3 = x_TSA_3; 27 | % x_result_4 = x_TSA_4; 28 | % x_result_5 = x_TSA_5; 29 | 30 | 31 | save_file = './real_results/results/rgb_results/'; 32 | mkdir(save_file); 33 | 34 | frame = 5; 35 | % for recon = {x_result_1,x_result_2,x_result_3,x_result_4,x_result_5} 36 | for recon = {x_result_5} 37 | recon = cell2mat(recon); 38 | intensity = 5; 39 | for channel=1:28 40 | img_nb = [channel]; % channel number 41 | row_num = 1; col_num = 1; 42 | lam28 = [453.5 457.5 462.0 466.0 471.5 476.5 481.5 487.0 492.5 498.0 504.0 510.0 516.0 522.5 529.5 536.5 544.0 551.5 558.5 567.5 575.5 584.5 594.5 604.0 614.5 625.0 636.5 648.0]; 43 | recon(recon>1)=1; 44 | name = [save_file 'frame' num2str(frame) 'channel' num2str(channel)]; 45 | dispCubeAshwin(recon(:,:,img_nb),intensity,lam28(img_nb), [] ,col_num,row_num,0,1,name); 46 | end 47 | frame = frame+1; 48 | end -------------------------------------------------------------------------------- /visualization/createfigure.m: -------------------------------------------------------------------------------- 1 | function createfigure(X1, YMatrix1, Corr) 2 | %CREATEFIGURE(X1, YMatrix1, Corr) 3 | % X1: vector of x data 4 | % YMATRIX1: matrix of y data 5 | % CORR: correlation values for the legend; Corr(k) labels curve k+1 for k = 1..8, curve 10 uses Corr(end) 6 | % Auto-generated by MATLAB on 19-Feb-2022 11:12:35 7 | 8 | % Create figure 9 | figure1 = figure('PaperOrientation','landscape',...
10 | 'PaperSize',[29.69999902 20.99999864]); 11 | 12 | % Create axes 13 | % axes1 = axes('Parent',figure1,'Position',[0.13 0.11 0.385625 0.815]); 14 | axes1 = axes('Parent',figure1,'Position',[0.15 0.15 0.8 0.8]); 15 | hold(axes1,'on'); 16 | 17 | % Create multiple lines using matrix input to plot 18 | plot1 = plot(X1,YMatrix1,'MarkerSize',16,'Marker','.','LineWidth',2.5,... 19 | 'Parent',axes1); 20 | set(plot1(1),'DisplayName',' Ground Truth','Color',[1 0 0]); 21 | 22 | % Curve 10 reads Corr(end) so it no longer repeats curve 9's Corr(8) when nine correlations are passed 23 | set(plot1(2),'DisplayName',' TwIST, corr: '+string(roundn(Corr(1),-4)),'Color',[0.01 0.33 0.62]); 24 | set(plot1(3),'DisplayName',' GAP-TV, corr: '+string(roundn(Corr(2),-4)),'Color',[0.01 0.66 0.62]); 25 | set(plot1(4),'DisplayName',' DeSCI, corr: '+string(roundn(Corr(3),-4)),'Color',[0.01 0.66 0.33]); 26 | set(plot1(5),'DisplayName',' TSANet, corr: '+string(roundn(Corr(4),-4)),'Color',[0.66 0.66 0.62]); 27 | set(plot1(6),'DisplayName',' GAP-Net, corr: '+string(roundn(Corr(5),-4)),'Color',[0.01 0.66 0.01]); 28 | set(plot1(7),'DisplayName',' MST-L, corr: '+string(roundn(Corr(6),-4)),'Color',[0 0.5 0.5]); 29 | set(plot1(8),'DisplayName',' DAUHST, corr: '+string(roundn(Corr(7),-4)),'Color', [1 1 0]); 30 | set(plot1(9),'DisplayName',' Ours 9stg , corr: '+string(roundn(Corr(8),-4)),'Color', [0 1 0]); 31 | set(plot1(10),'DisplayName',' Ours 9stg* , corr: '+string(roundn(Corr(end),-4)),'Color', [0 1 0]); 32 | 33 | % Preserve the Y-limits of the axes 34 | ylim(axes1,[0 1]); 35 | box(axes1,'on'); 36 | hold(axes1,'off'); 37 | % Set the remaining axes properties 38 | set(axes1,'FontName','Arial','FontSize',22,'LineWidth',3.5); 39 | 40 | % Create ylabel 41 | ylabel('Density','FontSize',28,'FontName','Arial'); 42 | 43 | % Create xlabel 44 | xlabel('Wavelength (nm)','FontSize',28,'FontName','Arial'); 45 | % Create legend 46 | legend1 = legend(axes1,'show'); 47 | % 'Position',[0.320670220276361 0.124725505873052 0.187369795342287 0.36915888702758],... 48 | set(legend1,... 49 | 'Position',[0.335 0.25 0.187369795342287 0.36915888702758],... 50 | 'FontSize',22,...
51 | 'EdgeColor',[1 1 1]); 52 | 53 | print(figure1,'-djpeg','-r300', "spectral_density.png") 54 | -------------------------------------------------------------------------------- /csi/config/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | import functools 5 | import inspect 6 | import logging 7 | from fvcore.common.config import CfgNode as _CfgNode 8 | 9 | from csi.utils.file_io import PathManager 10 | 11 | class CfgNode(_CfgNode): 12 | """ 13 | The same as `fvcore.common.config.CfgNode`, but different in: 14 | 1. Use unsafe yaml loading by default. 15 | Note that this may lead to arbitrary code execution: you must not 16 | load a config file from untrusted sources before manually inspecting 17 | the content of the file. 18 | 2. Support config versioning. 19 | When attempting to merge an old config, it will convert the old config automatically. 20 | .. automethod:: clone 21 | .. automethod:: freeze 22 | .. automethod:: defrost 23 | .. automethod:: is_frozen 24 | .. automethod:: load_yaml_with_base 25 | .. automethod:: merge_from_list 26 | .. automethod:: merge_from_other_cfg 27 | """ 28 | 29 | @classmethod 30 | def _open_cfg(cls, filename): 31 | return PathManager.open(filename, "r") 32 | 33 | # Note that the default value of allow_unsafe is changed to True 34 | def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = True) -> None: 35 | """ 36 | Load content from the given config file and merge it into self. 37 | Args: 38 | cfg_filename: config filename 39 | allow_unsafe: allow unsafe yaml syntax 40 | """ 41 | assert PathManager.isfile(cfg_filename), f"Config file '{cfg_filename}' does not exist!" 
42 | loaded_cfg = self.load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe) 43 | loaded_cfg = type(self)(loaded_cfg) 44 | 45 | # defaults.py needs to import CfgNode 46 | from .defaults import _C 47 | 48 | logger = logging.getLogger(__name__) 49 | 50 | self.merge_from_other_cfg(loaded_cfg) 51 | 52 | 53 | def dump(self, *args, **kwargs): 54 | """ 55 | Returns: 56 | str: a yaml string representation of the config 57 | """ 58 | # to make it show up in docs 59 | return super().dump(*args, **kwargs) 60 | 61 | 62 | def get_cfg() -> CfgNode: 63 | """ 64 | Get a copy of the default config. 65 | Returns: 66 | a detectron2 CfgNode instance. 67 | """ 68 | from .defaults import _C 69 | 70 | return _C.clone() -------------------------------------------------------------------------------- /csi/utils/schedulers.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch import nn 3 | from torch.optim import Optimizer, Adam 4 | from torch.optim.lr_scheduler import LambdaLR 5 | 6 | import numpy as np 7 | from tqdm import tqdm 8 | from matplotlib import pyplot as plt 9 | import seaborn as sns 10 | 11 | 12 | def get_cosine_schedule_with_warmup( 13 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1, eta_min=1e-6 14 | ): 15 | """ 16 | Create a schedule with a learning rate that decreases following the values of the cosine function between the 17 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the 18 | initial lr set in the optimizer. 19 | Args: 20 | optimizer ([`~torch.optim.Optimizer`]): 21 | The optimizer for which to schedule the learning rate. 22 | num_warmup_steps (`int`): 23 | The number of steps for the warmup phase. 24 | num_training_steps (`int`): 25 | The total number of training steps. 
26 | num_cycles (`float`, *optional*, defaults to 0.5): 27 | The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 28 | following a half-cosine). 29 | last_epoch (`int`, *optional*, defaults to -1): 30 | The index of the last epoch when resuming training. 31 | Return: 32 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 33 | """ 34 | 35 | def lr_lambda(current_step): 36 | if current_step < num_warmup_steps: 37 | return float(current_step) / float(max(1, num_warmup_steps)) 38 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 39 | return max(eta_min, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 40 | 41 | return LambdaLR(optimizer, lr_lambda, last_epoch) 42 | 43 | 44 | if __name__ == "__main__": 45 | model = nn.Linear(512, 256) 46 | optimizer = Adam(model.parameters(), lr=2e-4) 47 | num_warmup_steps = int(np.floor(5000 / 1)) 48 | num_training_steps = int(np.floor(5000/ 1)) * 300 49 | scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) 50 | 51 | lrs = [] 52 | for i in tqdm(range(num_training_steps)): 53 | lrs.append(optimizer.state_dict()['param_groups'][0]['lr']) 54 | scheduler.step() 55 | 56 | sns.lineplot(x=range(len(lrs)), y=lrs) 57 | plt.savefig('lr.png') 58 | -------------------------------------------------------------------------------- /tools/test_real.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | # add python path of PadleDetection to sys.path 6 | parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2))) 7 | sys.path.insert(0, parent_path) 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | from torch import optim 13 | from torch.utils.data import DataLoader 14 | from torch.nn.utils import 
clip_grad_norm_ 15 | from torchvision.utils import make_grid 16 | from torch_ema import ExponentialMovingAverage 17 | 18 | import cv2 19 | import numpy as np 20 | from scipy import io as sio 21 | from tqdm import tqdm 22 | 23 | from csi.config import get_cfg 24 | from csi.engine import default_argument_parser, default_setup 25 | from csi.data import CSITrainDataset, LoadVal, LoadTSATestMeas, shift_back_batch, generate_mask_3d, generate_mask_3d_shift, gen_meas_torch 26 | from csi.architectures import DERNN_LNLT 27 | from csi.utils.schedulers import get_cosine_schedule_with_warmup 28 | from csi.losses import CharbonnierLoss, TVLoss 29 | from csi.metrics import torch_psnr, torch_ssim, sam 30 | from csi.utils.utils import checkpoint 31 | 32 | args = default_argument_parser().parse_args() 33 | cfg = get_cfg() 34 | cfg.merge_from_file(args.config_file) 35 | cfg.freeze() 36 | 37 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 38 | 39 | mask_test = generate_mask_3d_shift(mask_path=cfg.DATASETS.TEST.MASK_PATH).to(device) 40 | 41 | test_meas = LoadTSATestMeas(cfg.DATASETS.TEST.PATH).to(device) 42 | 43 | model = eval(cfg.MODEL.DENOISER.TYPE)(cfg).to(device) 44 | 45 | ema = ExponentialMovingAverage(model.parameters(), decay=cfg.MODEL.EMA.DECAY) 46 | 47 | if cfg.PRETRAINED_CKPT_PATH: 48 | print(f"===> Loading Checkpoint from {cfg.PRETRAINED_CKPT_PATH}") 49 | save_state = torch.load(cfg.PRETRAINED_CKPT_PATH, map_location=device) 50 | model.load_state_dict(save_state['model']) 51 | ema.load_state_dict(save_state['ema']) 52 | 53 | def test(test_meas, name="test_a"): 54 | model.eval() 55 | model_out = [] 56 | data = {} 57 | data['Y'] = test_meas / test_meas.max() * 0.8 58 | 59 | B, _, _ = test_meas.shape 60 | data['mask'] = mask_test.unsqueeze(0).tile((B, 1, 1, 1)) 61 | data['H'] = shift_back_batch(test_meas, step=cfg.DATASETS.STEP, nC=cfg.DATASETS.WAVE_LENS) 62 | 63 | with torch.no_grad(): 64 | with ema.average_parameters(): 65 | model_out = model(data) 66 
| 67 | 68 | for i in range(B): 69 | out_plot = F.interpolate(model_out[i:i+1, :, :, :], size=(128, 128)) 70 | if name == "TSA": out_plot = torch.flip(out_plot, dims=(2, 3)) 71 | 72 | 73 | model_out = np.transpose(model_out.detach().cpu().numpy(), (0, 2, 3, 1)).astype(np.float32) 74 | model.train() 75 | 76 | return model_out 77 | 78 | 79 | def main(): 80 | test_out = test(test_meas, "TSA") 81 | sio.savemat("./results/dernn_lnlt_5stg_real.mat", {"pred": test_out}) 82 | 83 | 84 | if __name__ == "__main__": 85 | main() -------------------------------------------------------------------------------- /csi/config/defaults.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from .config import CfgNode as CN 3 | 4 | # NOTE: given the new config system 5 | # (https://detectron2.readthedocs.io/en/latest/tutorials/lazyconfigs.html), 6 | # we will stop adding new functionalities to default CfgNode. 7 | 8 | # ----------------------------------------------------------------------------- 9 | # Convention about Training / Test specific parameters 10 | # ----------------------------------------------------------------------------- 11 | # Whenever an argument can be either used for training or for testing, the 12 | # corresponding name will be post-fixed by a _TRAIN for a training parameter, 13 | # or _TEST for a test-specific parameter. 
14 | # For example, the number of images during training will be 15 | # IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be 16 | # IMAGES_PER_BATCH_TEST 17 | 18 | # ----------------------------------------------------------------------------- 19 | # Config definition 20 | # ----------------------------------------------------------------------------- 21 | 22 | _C = CN() 23 | 24 | _C.IMG_SIZE = [256, 256] 25 | _C.DEBUG = False 26 | _C.OUTPUT_DIR = "" 27 | _C.SEED = 3407 28 | _C.DETERMINISTIC = True 29 | _C.RESUME_CKPT_PATH = None 30 | _C.PRETRAINED_CKPT_PATH = None 31 | 32 | _C.DATASETS = CN() 33 | _C.DATASETS.STEP = 2 34 | _C.DATASETS.WAVE_LENS = 28 35 | _C.DATASETS.MASK_TYPE = "mask_3d" 36 | 37 | _C.DATASETS.TRAIN = CN() 38 | _C.DATASETS.TRAIN.MASK_PATH = "" 39 | _C.DATASETS.TRAIN.RANDOM_MASK = True 40 | _C.DATASETS.TRAIN.PATHS = [] 41 | _C.DATASETS.TRAIN.ITERATION = 1000 42 | _C.DATASETS.TRAIN.CROP_SIZE = [256, 256] 43 | _C.DATASETS.TRAIN.AUGMENT = True 44 | _C.DATASETS.TRAIN.WITH_NOISE = True 45 | 46 | _C.DATASETS.VAL = CN() 47 | _C.DATASETS.VAL.MASK_PATH = "" 48 | _C.DATASETS.VAL.PATH = "" 49 | 50 | 51 | _C.DATASETS.TEST = CN() 52 | _C.DATASETS.TEST.PATH = None 53 | _C.DATASETS.TEST.MASK_PATH = None 54 | 55 | # DATALOADER 56 | _C.DATALOADER = CN() 57 | _C.DATALOADER.BATCH_SIZE = 4 58 | _C.DATALOADER.NUM_WORKERS = 8 59 | _C.DATALOADER.PIN_MEMORY = False 60 | 61 | 62 | # MODEL 63 | _C.MODEL = CN() 64 | _C.MODEL.DENOISER = CN() 65 | _C.MODEL.DENOISER.TYPE = "DERNN_LNLT" 66 | 67 | _C.MODEL.DENOISER.DERNN_LNLT = CN() 68 | _C.MODEL.DENOISER.DERNN_LNLT.LOCAL = True 69 | _C.MODEL.DENOISER.DERNN_LNLT.NON_LOCAL = True 70 | _C.MODEL.DENOISER.DERNN_LNLT.IN_DIM = 28 71 | _C.MODEL.DENOISER.DERNN_LNLT.OUT_DIM = 28 72 | _C.MODEL.DENOISER.DERNN_LNLT.DIM = 28 73 | _C.MODEL.DENOISER.DERNN_LNLT.WINDOW_SIZE = [8, 8] 74 | _C.MODEL.DENOISER.DERNN_LNLT.WINDOW_NUM = [8, 8] 75 | _C.MODEL.DENOISER.DERNN_LNLT.NUM_BLOCKS = [1, 1, 1, 1, 1] 76 | 
_C.MODEL.DENOISER.DERNN_LNLT.FFN_NAME = "Gated_Dconv_FeedForward" 77 | _C.MODEL.DENOISER.DERNN_LNLT.FFN_EXPAND = 2.66 78 | _C.MODEL.DENOISER.DERNN_LNLT.LAYERNORM_TYPE = "WithBias" 79 | _C.MODEL.DENOISER.DERNN_LNLT.STAGE = 5 80 | _C.MODEL.DENOISER.DERNN_LNLT.SHARE_PARAMS = True 81 | _C.MODEL.DENOISER.DERNN_LNLT.WITH_DL = True 82 | _C.MODEL.DENOISER.DERNN_LNLT.WITH_MU = True 83 | _C.MODEL.DENOISER.DERNN_LNLT.WITH_NOISE_LEVEL = True 84 | 85 | 86 | 87 | 88 | # EMA 89 | _C.MODEL.EMA = CN() 90 | _C.MODEL.EMA.ENABLE = True 91 | _C.MODEL.EMA.DECAY = 0.999 92 | 93 | # OPTIMIZER 94 | _C.OPTIMIZER = CN() 95 | _C.OPTIMIZER.MAX_EPOCH = 300 96 | _C.OPTIMIZER.LR = 2e-4 97 | _C.OPTIMIZER.GRAD_CLIP = True 98 | 99 | _C.LOSSES = CN() 100 | _C.LOSSES.L1_LOSS = True 101 | _C.LOSSES.TV_LOSS = False 102 | -------------------------------------------------------------------------------- /csi/metrics/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | 8 | def gaussian(window_size, sigma): 9 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 10 | return gauss / gauss.sum() 11 | 12 | 13 | def create_window(window_size, channel): 14 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 15 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 16 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 17 | return window 18 | 19 | 20 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 21 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 22 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 23 | 24 | mu1_sq = mu1.pow(2) 25 | mu2_sq = mu2.pow(2) 26 | mu1_mu2 = mu1 * mu2 27 | 28 | sigma1_sq = F.conv2d(img1 * img1, window, 
padding=window_size // 2, groups=channel) - mu1_sq 29 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 30 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 31 | 32 | C1 = 0.01 ** 2 33 | C2 = 0.03 ** 2 34 | 35 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 36 | 37 | if size_average: 38 | return ssim_map.mean() 39 | else: 40 | return ssim_map.mean(1).mean(1).mean(1) 41 | 42 | 43 | class SSIM(torch.nn.Module): 44 | def __init__(self, window_size=11, size_average=True): 45 | super(SSIM, self).__init__() 46 | self.window_size = window_size 47 | self.size_average = size_average 48 | self.channel = 1 49 | self.window = create_window(window_size, self.channel) 50 | 51 | def forward(self, img1, img2): 52 | (_, channel, _, _) = img1.size() 53 | 54 | if channel == self.channel and self.window.data.type() == img1.data.type(): 55 | window = self.window 56 | else: 57 | window = create_window(self.window_size, channel) 58 | 59 | if img1.is_cuda: 60 | window = window.cuda(img1.get_device()) 61 | window = window.type_as(img1) 62 | 63 | self.window = window 64 | self.channel = channel 65 | 66 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 67 | 68 | 69 | def ssim(img1, img2, window_size=11, size_average=True): 70 | (_, channel, _, _) = img1.size() 71 | window = create_window(window_size, channel) 72 | 73 | if img1.is_cuda: 74 | window = window.cuda(img1.get_device()) 75 | window = window.type_as(img1) 76 | 77 | return _ssim(img1, img2, window, window_size, channel, size_average) 78 | 79 | 80 | # We find that this calculation method is more close to DGSMP's. 
81 | def torch_psnr(img, ref): # input [28,256,256] 82 | img = (img*256).round() 83 | ref = (ref*256).round() 84 | nC = img.shape[0] 85 | psnr = 0 86 | for i in range(nC): 87 | mse = torch.mean((img[i, :, :] - ref[i, :, :]) ** 2) 88 | psnr += 10 * torch.log10((255*255)/mse) 89 | return psnr / nC 90 | 91 | def torch_ssim(img, ref): # input [28,256,256] 92 | return ssim(torch.unsqueeze(img, 0), torch.unsqueeze(ref, 0)) 93 | 94 | 95 | def sam(x_true, x_pred): 96 | """ 97 | :param x_true: 高光谱图像:格式:(H, W, C) 98 | :param x_pred: 高光谱图像:格式:(H, W, C) 99 | :return: 计算原始高光谱数据与重构高光谱数据的光谱角相似度 100 | """ 101 | num = 0 102 | sum_sam = 0 103 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32) 104 | for x in range(x_true.shape[0]): 105 | for y in range(x_true.shape[1]): 106 | tmp_pred = x_pred[x, y].ravel() 107 | tmp_true = x_true[x, y].ravel() 108 | if np.linalg.norm(tmp_true) != 0 and np.linalg.norm(tmp_pred) != 0: 109 | sum_sam += np.arccos( 110 | np.inner(tmp_pred, tmp_true) / (np.linalg.norm(tmp_true) * np.linalg.norm(tmp_pred))) 111 | num += 1 112 | sam_deg = (sum_sam / num) * 180 / np.pi 113 | # 114 | return sam_deg -------------------------------------------------------------------------------- /tools/test_simu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | # add python path of PadleDetection to sys.path 6 | parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2))) 7 | sys.path.insert(0, parent_path) 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | from torch import optim 13 | from torch.utils.data import DataLoader 14 | from torch.nn.utils import clip_grad_norm_ 15 | from torchvision.utils import make_grid 16 | from torch_ema import ExponentialMovingAverage 17 | 18 | import cv2 19 | import numpy as np 20 | from scipy import io as sio 21 | from tqdm import tqdm 22 | 23 | from csi.config import get_cfg 24 | from 
csi.engine import default_argument_parser, default_setup 25 | from csi.data import CSITrainDataset, LoadVal, LoadTSATestMeas, shift_back_batch, generate_mask_3d, generate_mask_3d_shift, gen_meas_torch 26 | from csi.architectures import DERNN_LNLT 27 | from csi.utils.schedulers import get_cosine_schedule_with_warmup 28 | from csi.losses import CharbonnierLoss, TVLoss 29 | from csi.metrics import torch_psnr, torch_ssim, sam 30 | from csi.utils.utils import checkpoint 31 | 32 | args = default_argument_parser().parse_args() 33 | cfg = get_cfg() 34 | cfg.merge_from_file(args.config_file) 35 | cfg.freeze() 36 | 37 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 38 | 39 | mask = generate_mask_3d_shift(mask_path=cfg.DATASETS.VAL.MASK_PATH).to(device) 40 | 41 | val_datas = LoadVal(cfg.DATASETS.VAL.PATH) 42 | 43 | model = eval(cfg.MODEL.DENOISER.TYPE)(cfg).to(device) 44 | 45 | ema = ExponentialMovingAverage(model.parameters(), decay=cfg.MODEL.EMA.DECAY) 46 | 47 | if cfg.PRETRAINED_CKPT_PATH: 48 | print(f"===> Loading Checkpoint from {cfg.PRETRAINED_CKPT_PATH}") 49 | save_state = torch.load(cfg.PRETRAINED_CKPT_PATH, map_location=device) 50 | model.load_state_dict(save_state['model']) 51 | ema.load_state_dict(save_state['ema']) 52 | 53 | 54 | 55 | def eval(): 56 | psnr_list, ssim_list, sam_list = [], [], [] 57 | val_H = [] 58 | val_Y = [] 59 | val_gt = [] 60 | for val_label in val_datas['hsi']: 61 | val_label = torch.from_numpy(val_label).permute(2, 0, 1).to(device).float() 62 | YH = gen_meas_torch(val_label, mask, step=cfg.DATASETS.STEP, wave_len=cfg.DATASETS.WAVE_LENS, mask_type=cfg.DATASETS.MASK_TYPE) 63 | val_H.append(YH['H'].to(device)) 64 | val_Y.append(YH['Y'].to(device)) 65 | val_gt.append(val_label) 66 | val_gt = torch.stack(val_gt) 67 | val_H = torch.stack(val_H) 68 | val_Y = torch.stack(val_Y) 69 | data = {} 70 | data['hsi'] = val_gt 71 | data['H'] = val_H 72 | B, _, _, _ = val_H.shape 73 | data['mask'] = mask.unsqueeze(0).tile((B, 1, 1, 1)) 
74 | data['Y'] = val_Y 75 | 76 | model.eval() 77 | begin = time.time() 78 | with torch.no_grad(): 79 | with ema.average_parameters(): 80 | out = model(data) 81 | model_out = out 82 | 83 | for i in range(len(model_out)): 84 | psnr_val = torch_psnr(model_out[i, :, :, :], val_gt[i, :, :, :]) 85 | ssim_val = torch_ssim(model_out[i, :, :, :], val_gt[i, :, :, :]) 86 | sam_val = sam(model_out[i, :, :, :].permute(1, 2, 0).cpu().numpy(), val_gt[i, :, :, :].permute(1, 2, 0).cpu().numpy()) 87 | psnr_list.append(psnr_val.detach().cpu().numpy()) 88 | ssim_list.append(ssim_val.detach().cpu().numpy()) 89 | sam_list.append(sam_val) 90 | 91 | pred = np.transpose(model_out.detach().cpu().numpy(), (0, 2, 3, 1)).astype(np.float32) 92 | truth = np.transpose(val_gt.cpu().numpy(), (0, 2, 3, 1)).astype(np.float32) 93 | psnr_mean = np.mean(np.asarray(psnr_list)) 94 | ssim_mean = np.mean(np.asarray(ssim_list)) 95 | sam_mean = np.mean(np.asarray(sam_list)) 96 | 97 | end = time.time() 98 | 99 | print('===> testing psnr = {:.2f}, ssim = {:.3f}, sam = {:.3f}, time: {:.2f}' 100 | .format(psnr_mean, ssim_mean, sam_mean, (end - begin))) 101 | model.train() 102 | return pred, truth, psnr_list, ssim_list, sam_list, psnr_mean, ssim_mean, sam_mean 103 | 104 | 105 | 106 | def main(): 107 | (pred, truth, psnr_all, ssim_all, sam_all, psnr_mean, ssim_mean, sam_mean) = eval() 108 | sio.savemat("./results/dernn_lnlt_9stg_star_simu.mat", {"pred": pred, "truth" : truth}) 109 | 110 | 111 | 112 | if __name__ == "__main__": 113 | main() -------------------------------------------------------------------------------- /Quality_Metrics/img_qi.m: -------------------------------------------------------------------------------- 1 | function [quality, quality_map] = img_qi(img1, img2, block_size) 2 | 3 | %======================================================================== 4 | % 5 | %Copyright (c) 2001 The University of Texas at Austin 6 | %All Rights Reserved. 
7 | % 8 | %This program is free software; you can redistribute it and/or modify 9 | %it under the terms of the GNU General Public License as published by 10 | %the Free Software Foundation; either version 2 of the License, or 11 | %(at your option) any later version. 12 | % 13 | %This program is distributed in the hope that it will be useful, 14 | %but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | %MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | %GNU General Public License for more details. 17 | % 18 | %The GNU Public License is available in the file LICENSE, or you 19 | %can write to the Free Software Foundation, Inc., 59 Temple Place - 20 | %Suite 330, Boston, MA 02111-1307, USA, or you can find it on the 21 | %World Wide Web at http://www.fsf.org. 22 | % 23 | %Author : Zhou Wang 24 | %Version : 1.0 25 | % 26 | %The authors are with the Laboratory for Image and Video Engineering 27 | %(LIVE), Department of Electrical and Computer Engineering, The 28 | %University of Texas at Austin, Austin, TX. 29 | % 30 | %Kindly report any suggestions or corrections to zwang@ece.utexas.edu 31 | % 32 | %Acknowledgement: 33 | %The author would like to thank Mr. Umesh Rajashekar, the Matlab master 34 | %in our lab, for spending his precious time and giving his kind help 35 | %on writing this program. Without his help, this program would not 36 | %achieve its current efficiency. 37 | % 38 | %======================================================================== 39 | % 40 | %This is an efficient implementation of the algorithm for calculating 41 | %the universal image quality index proposed by Zhou Wang and Alan C. 42 | %Bovik. Please refer to the paper "A Universal Image Quality Index" 43 | %by Zhou Wang and Alan C. Bovik, published in IEEE Signal Processing 44 | %Letters, 2001. In order to run this function, you must have Matlab's 45 | %Image Processing Toobox. 
46 | % 47 | %Input : an original image and a test image of the same size 48 | %Output: (1) an overall quality index of the test image, with a value 49 | % range of [-1, 1]. 50 | % (2) a quality map of the test image. The map has a smaller 51 | % size than the input images. The actual size is 52 | % img_size - BLOCK_SIZE + 1. 53 | % 54 | %Usage: 55 | % 56 | %1. Load the original and the test images into two matrices 57 | % (say img1 and img2) 58 | % 59 | %2. Run this function in one of the two ways: 60 | % 61 | % % Choice 1 (suggested): 62 | % [qi qi_map] = img_qi(img1, img2); 63 | % 64 | % % Choice 2: 65 | % [qi qi_map] = img_qi(img1, img2, BLOCK_SIZE); 66 | % 67 | % The default BLOCK_SIZE is 8 (Choice 1). Otherwise, you can specify 68 | % it by yourself (Choice 2). 69 | % 70 | %3. See the results: 71 | % 72 | % qi %Gives the over quality index. 73 | % imshow((qi_map+1)/2) %Shows the quality map as an image. 74 | % 75 | %======================================================================== 76 | 77 | if (nargin == 1 | nargin > 3) 78 | quality = -Inf; 79 | quality_map = -1*ones(size(img1)); 80 | return; 81 | end 82 | 83 | if (size(img1) ~= size(img2)) 84 | quality = -Inf; 85 | quality_map = -1*ones(size(img1)); 86 | return; 87 | end 88 | 89 | if (nargin == 2) 90 | block_size = 8; 91 | end 92 | 93 | N = block_size.^2; 94 | sum2_filter = ones(block_size); 95 | 96 | img1_sq = img1.*img1; 97 | img2_sq = img2.*img2; 98 | img12 = img1.*img2; 99 | 100 | img1_sum = filter2(sum2_filter, img1, 'valid'); 101 | img2_sum = filter2(sum2_filter, img2, 'valid'); 102 | img1_sq_sum = filter2(sum2_filter, img1_sq, 'valid'); 103 | img2_sq_sum = filter2(sum2_filter, img2_sq, 'valid'); 104 | img12_sum = filter2(sum2_filter, img12, 'valid'); 105 | 106 | img12_sum_mul = img1_sum.*img2_sum; 107 | img12_sq_sum_mul = img1_sum.*img1_sum + img2_sum.*img2_sum; 108 | numerator = 4*(N*img12_sum - img12_sum_mul).*img12_sum_mul; 109 | denominator1 = N*(img1_sq_sum + img2_sq_sum) - 
img12_sq_sum_mul; 110 | denominator = denominator1.*img12_sq_sum_mul; 111 | 112 | quality_map = ones(size(denominator)); 113 | index = (denominator1 == 0) & (img12_sq_sum_mul ~= 0); 114 | quality_map(index) = 2*img12_sum_mul(index)./img12_sq_sum_mul(index); 115 | index = (denominator ~= 0); 116 | quality_map(index) = numerator(index)./denominator(index); 117 | 118 | quality = mean2(quality_map); -------------------------------------------------------------------------------- /visualization/show_line.m: -------------------------------------------------------------------------------- 1 | %% plot color pics 2 | clear; clc; 3 | load(['./simulation_results/results/','truth','.mat']); 4 | 5 | load(['./simulation_results/results/','TwIST','.mat']); 6 | pred_block_twist = pred; 7 | 8 | load(['./simulation_results/results/','GAP-TV','.mat']); 9 | pred_block_gaptv = pred; 10 | 11 | load(['./simulation_results/results/','DeSCI','.mat']); 12 | pred_block_desci = pred; 13 | 14 | load(['./simulation_results/results/','tsa_net','.mat']); 15 | pred_block_tsanet = pred; 16 | 17 | load(['./simulation_results/results/','gap_net','.mat']); 18 | pred_block_gapnet = pred; 19 | 20 | load(['./simulation_results/results/','mst_l','.mat']); 21 | pred_block_mst_l = pred; 22 | 23 | load(['./simulation_results/results/','dauhst_9stg','.mat']); 24 | pred_block_dauhst = pred; 25 | 26 | load(['./simulation_results/results/','DERNN_LNLT_9stg','.mat']); 27 | pred_block_dernn_lnlt_9stg = pred; 28 | 29 | load(['./simulation_results/results/','DERNN_LNLT_9stg_plus','.mat']); 30 | pred_block_dernn_lnlt_9stg_plus = pred; 31 | 32 | lam28 = [453.5 457.5 462.0 466.0 471.5 476.5 481.5 487.0 492.5 498.0 504.0 510.0... 33 | 516.0 522.5 529.5 536.5 544.0 551.5 558.5 567.5 575.5 584.5 594.5 604.0... 
34 | 614.5 625.0 636.5 648.0]; 35 | 36 | truth(find(truth>0.7))=0.7; 37 | pred_block_twist(find(pred_block_twist>0.7))=0.7; 38 | pred_block_gaptv(find(pred_block_gaptv>0.7))=0.7; 39 | pred_block_desci(find(pred_block_desci>0.7))=0.7; 40 | pred_block_tsanet(find(pred_block_tsanet>0.7))=0.7; 41 | pred_block_gapnet(find(pred_block_gapnet>0.7))=0.7; 42 | pred_block_mst_l(find(pred_block_mst_l>0.7))=0.7; 43 | pred_block_dauhst(find(pred_block_dauhst>0.7))=0.7; 44 | pred_block_dernn_lnlt_9stg(find(pred_block_dernn_lnlt_9stg>0.7))=0.7; 45 | pred_block_dernn_lnlt_9stg_plus(find(pred_block_dernn_lnlt_9stg_plus>0.7))=0.7; 46 | 47 | 48 | f = 8; 49 | 50 | %% plot spectrum 51 | figure(123); 52 | [yx, rect2crop]=imcrop(sum(squeeze(truth(f, :, :, :)), 3), [40 50 40 40]); 53 | rect2crop=round(rect2crop) 54 | % close(123); 55 | imshow(yx / 28) 56 | figure; 57 | 58 | spec_mean_truth = mean(mean(squeeze(truth(f,rect2crop(2):rect2crop(2)+rect2crop(4) , rect2crop(1):rect2crop(1)+rect2crop(3),:)),1),2); 59 | spec_mean_twist = mean(mean(squeeze(pred_block_twist(f,rect2crop(2):rect2crop(2)+rect2crop(4) , rect2crop(1):rect2crop(1)+rect2crop(3),:)),1),2); 60 | spec_mean_gaptv = mean(mean(squeeze(pred_block_gaptv(f,rect2crop(2):rect2crop(2)+rect2crop(4) , rect2crop(1):rect2crop(1)+rect2crop(3),:)),1),2); 61 | spec_mean_desci = mean(mean(squeeze(pred_block_desci(f,rect2crop(2):rect2crop(2)+rect2crop(4) , rect2crop(1):rect2crop(1)+rect2crop(3),:)),1),2); 62 | spec_mean_tsanet = mean(mean(squeeze(pred_block_tsanet(f,rect2crop(2):rect2crop(2)+rect2crop(4) , rect2crop(1):rect2crop(1)+rect2crop(3),:)),1),2); 63 | spec_mean_gapnet = mean(mean(squeeze(pred_block_gapnet(f,rect2crop(2):rect2crop(2)+rect2crop(4) , rect2crop(1):rect2crop(1)+rect2crop(3),:)),1),2); 64 | spec_mean_mst_l = mean(mean(squeeze(pred_block_mst_l(f,rect2crop(2):rect2crop(2)+rect2crop(4) , rect2crop(1):rect2crop(1)+rect2crop(3),:)),1),2); 65 | spec_mean_dauhst = 
mean(mean(squeeze(pred_block_dauhst(f,rect2crop(2):rect2crop(2)+rect2crop(4) , rect2crop(1):rect2crop(1)+rect2crop(3),:)),1),2); 66 | spec_mean_dernn_lnlt_9stg = mean(mean(squeeze(pred_block_dernn_lnlt_9stg(f,rect2crop(2):rect2crop(2)+rect2crop(4) , rect2crop(1):rect2crop(1)+rect2crop(3),:)),1),2); 67 | spec_mean_dernn_lnlt_9stg_plus = mean(mean(squeeze(pred_block_dernn_lnlt_9stg_plus(f,rect2crop(2):rect2crop(2)+rect2crop(4) , rect2crop(1):rect2crop(1)+rect2crop(3),:)),1),2); 68 | 69 | 70 | spec_mean_truth = spec_mean_truth./max(spec_mean_truth); 71 | spec_mean_twist = spec_mean_twist./max(spec_mean_twist); 72 | spec_mean_gaptv = spec_mean_gaptv./max(spec_mean_gaptv); 73 | spec_mean_desci = spec_mean_desci./max(spec_mean_desci); 74 | spec_mean_tsanet = spec_mean_tsanet./max(spec_mean_tsanet); 75 | spec_mean_gapnet = spec_mean_gapnet./max(spec_mean_gapnet); 76 | spec_mean_mst_l = spec_mean_mst_l./max(spec_mean_mst_l); 77 | spec_mean_dauhst = spec_mean_dauhst./max(spec_mean_dauhst); 78 | spec_mean_dernn_lnlt_9stg = spec_mean_dernn_lnlt_9stg./max(spec_mean_dernn_lnlt_9stg); 79 | spec_mean_dernn_lnlt_9stg_plus = spec_mean_dernn_lnlt_9stg_plus./max(spec_mean_dernn_lnlt_9stg_plus); 80 | 81 | 82 | corr_twist = roundn(corr(spec_mean_truth(:),spec_mean_twist(:)),-4); 83 | corr_gaptv = roundn(corr(spec_mean_truth(:),spec_mean_gaptv(:)),-4); 84 | corr_desci = roundn(corr(spec_mean_truth(:),spec_mean_desci(:)),-4); 85 | corr_tsanet = roundn(corr(spec_mean_truth(:),spec_mean_tsanet(:)),-4); 86 | corr_gapnet = roundn(corr(spec_mean_truth(:),spec_mean_gapnet(:)),-4); 87 | corr_mst_l = roundn(corr(spec_mean_truth(:),spec_mean_mst_l(:)),-4); 88 | corr_dauhst = roundn(corr(spec_mean_truth(:),spec_mean_dauhst(:)),-4); 89 | corr_dernn_lnlt_9stg = roundn(corr(spec_mean_truth(:),spec_mean_dernn_lnlt_9stg(:)),-4); 90 | corr_dernn_lnlt_9stg_plus = roundn(corr(spec_mean_truth(:),spec_mean_dernn_lnlt_9stg_plus(:)),-4); 91 | 92 | 93 | 94 | X = lam28; 95 | 96 | Y(1,:) = spec_mean_truth(:); 97 
| Y(2,:) = spec_mean_twist(:); Corr(1)=corr_twist;
98 | Y(3,:) = spec_mean_gaptv(:); Corr(2)=corr_gaptv;
99 | Y(4,:) = spec_mean_desci(:); Corr(3)=corr_desci;
100 | Y(5,:) = spec_mean_tsanet(:); Corr(4)=corr_tsanet;
101 | Y(6,:) = spec_mean_gapnet(:); Corr(5)=corr_gapnet;
102 | Y(7,:) = spec_mean_mst_l(:); Corr(6)=corr_mst_l;
103 | Y(8,:) = spec_mean_dauhst(:); Corr(7)=corr_dauhst;
104 | Y(9,:) = spec_mean_dernn_lnlt_9stg(:); Corr(8)=corr_dernn_lnlt_9stg;
105 | Y(10,:) = spec_mean_dernn_lnlt_9stg_plus(:); Corr(9)=corr_dernn_lnlt_9stg_plus;
106 |
107 |
108 |
109 | createfigure(X,Y,Corr)
110 |
111 |
112 |
--------------------------------------------------------------------------------
/visualization/show_line.asv:
--------------------------------------------------------------------------------
1 | %% plot color pics
2 | clear; clc;
3 | load(['./simulation_results/results/','truth','.mat']);
4 |
5 | % load(['./simulation_results/results/','hdnet','.mat']);
6 | % pred_block_hdnet = pred;
7 |
8 | % load(['./simulation_results/results/','mst_s','.mat']);
9 | % pred_block_mst_s = pred;
10 | %
11 | % load(['./simulation_results/results/','mst_m','.mat']);
12 | % pred_block_mst_m = pred;
13 |
14 | % load(['./simulation_results/results/','mst_l','.mat']);
15 | % pred_block_mst_l = pred;
16 |
17 | % load(['./simulation_results/results/','mst_plus_plus','.mat']);
18 | % pred_block_mst_plus_plus = pred;
19 |
20 | load(['./simulation_results/results/','TwIST','.mat']);
21 | pred_block_twist = pred;
22 |
23 | load(['./simulation_results/results/','GAP-TV','.mat']);
24 | pred_block_gaptv = pred;
25 |
26 | load(['./simulation_results/results/','DeSCI','.mat']);
27 | pred_block_desci = pred;
28 |
29 | load(['./simulation_results/results/','tsa_net','.mat']);
30 | pred_block_tsanet = pred;
31 |
32 | load(['./simulation_results/results/','gap_net','.mat']);
33 | pred_block_gapnet = pred;
34 |
35 | load(['./simulation_results/results/','mst_l','.mat']);
36 | pred_block_mst_l = pred;
37 |
38 | load(['./simulation_results/results/','dauhst_9stg','.mat']);
39 | pred_block_dauhst = pred;
40 |
41 | load(['./simulation_results/results/','DERNN_LNLT_9stg','.mat']);
42 | pred_block_dernn_lnlt_9stg = test;
43 |
44 | load(['./simulation_results/results/','DERNN_LNLT_9stg_plus','.mat']);
45 | pred_block_dernn_lnlt_9stg_plus = test;
46 |
47 | lam28 = [453.5 457.5 462.0 466.0 471.5 476.5 481.5 487.0 492.5 498.0 504.0 510.0...
48 | 516.0 522.5 529.5 536.5 544.0 551.5 558.5 567.5 575.5 584.5 594.5 604.0...
49 | 614.5 625.0 636.5 648.0];
50 |
51 | truth(find(truth>0.7))=0.7;
52 | pred_block_twist(find(pred_block_twist>0.7))=0.7;
53 | pred_block_gaptv(find(pred_block_gaptv>0.7))=0.7;
54 | pred_block_desci(find(pred_block_desci>0.7))=0.7;
55 | pred_block_tsanet(find(pred_block_tsanet>0.7))=0.7;
56 | pred_block_gapnet(find(pred_block_gapnet>0.7))=0.7;
57 | pred_block_mst_l(find(pred_block_mst_l>0.7))=0.7;
58 | pred_block_dauhst(find(pred_block_dauhst>0.7))=0.7;
59 | pred_block_dernn_lnlt_9stg(find(pred_block_dernn_lnlt_9stg>0.7))=0.7;
60 | pred_block_dernn_lnlt_9stg_plus(find(pred_block_dernn_lnlt_9stg_plus>0.7))=0.7;
61 |
62 |
63 | f = 5;
64 |
65 | %% plot spectrum
66 | figure(123);
67 | [yx, rect2crop]=imcrop(sum(squeeze(truth(f, :, :, :)), 3), [170 130 30 30]);
68 | rect2crop=round(rect2crop)
69 | % close(123);
70 | imshow(yx / 28)
71 | figure;
72 |
73 | spec_mean_truth = mean(mean(squeeze(truth(f,rect2crop(2):rect2crop(2)+rect2crop(4) , rect2crop(1):rect2crop(1)+rect2crop(3),:)),1),2);
74 | spec_mean_twist = mean(mean(squeeze(pred_block_twist(f,rect2crop(2):rect2crop(2)+rect2crop(4) , rect2crop(1):rect2crop(1)+rect2crop(3),:)),1),2);
75 | spec_mean_gaptv = mean(mean(squeeze(pred_block_gaptv(f,rect2crop(2):rect2crop(2)+rect2crop(4) , rect2crop(1):rect2crop(1)+rect2crop(3),:)),1),2);
76 | spec_mean_desci = mean(mean(squeeze(pred_block_desci(f,rect2crop(2):rect2crop(2)+rect2crop(4) , rect2crop(1):rect2crop(1)+rect2crop(3),:)),1),2);
77 | spec_mean_tsanet = mean(mean(squeeze(pred_block_tsanet(f,rect2crop(2):rect2crop(2)+rect2crop(4) , rect2crop(1):rect2crop(1)+rect2crop(3),:)),1),2);
78 | spec_mean_gapnet = mean(mean(squeeze(pred_block_gapnet(f,rect2crop(2):rect2crop(2)+rect2crop(4) , rect2crop(1):rect2crop(1)+rect2crop(3),:)),1),2);
79 | spec_mean_mst_l = mean(mean(squeeze(pred_block_mst_l(f,rect2crop(2):rect2crop(2)+rect2crop(4) , rect2crop(1):rect2crop(1)+rect2crop(3),:)),1),2);
80 | spec_mean_dauhst = mean(mean(squeeze(pred_block_dauhst(f,rect2crop(2):rect2crop(2)+rect2crop(4) , rect2crop(1):rect2crop(1)+rect2crop(3),:)),1),2);
81 | spec_mean_dernn_lnlt_9stg = mean(mean(squeeze(pred_block_dernn_lnlt_9stg(f,rect2crop(2):rect2crop(2)+rect2crop(4) , rect2crop(1):rect2crop(1)+rect2crop(3),:)),1),2);
82 | spec_mean_dernn_lnlt_9stg_plus = mean(mean(squeeze(pred_block_dernn_lnlt_9stg_plus(f,rect2crop(2):rect2crop(2)+rect2crop(4) , rect2crop(1):rect2crop(1)+rect2crop(3),:)),1),2);
83 |
84 |
85 | spec_mean_truth = spec_mean_truth./max(spec_mean_truth);
86 | spec_mean_twist = spec_mean_twist./max(spec_mean_twist);
87 | spec_mean_gaptv = spec_mean_gaptv./max(spec_mean_gaptv);
88 | spec_mean_desci = spec_mean_desci./max(spec_mean_desci);
89 | spec_mean_tsanet = spec_mean_tsanet./max(spec_mean_tsanet);
90 | spec_mean_gapnet = spec_mean_gapnet./max(spec_mean_gapnet);
91 | spec_mean_mst_l = spec_mean_mst_l./max(spec_mean_mst_l);
92 | spec_mean_dauhst = spec_mean_dauhst./max(spec_mean_dauhst);
93 | spec_mean_dernn_lnlt_9stg = spec_mean_dernn_lnlt_9stg./max(spec_mean_dernn_lnlt_9stg);
94 | spec_mean_dernn_lnlt_9stg_plus = spec_mean_dernn_lnlt_9stg_plus./max(spec_mean_dernn_lnlt_9stg_plus);
95 |
96 |
97 | corr_twist = roundn(corr(spec_mean_truth(:),spec_mean_twist(:)),-4);
98 | corr_gaptv = roundn(corr(spec_mean_truth(:),spec_mean_gaptv(:)),-4);
99 | corr_desci = roundn(corr(spec_mean_truth(:),spec_mean_desci(:)),-4);
100 | corr_tsanet = roundn(corr(spec_mean_truth(:),spec_mean_tsanet(:)),-4);
101 | corr_gapnet = roundn(corr(spec_mean_truth(:),spec_mean_gapnet(:)),-4);
102 | corr_mst_l = roundn(corr(spec_mean_truth(:),spec_mean_mst_l(:)),-4);
103 | corr_dauhst = roundn(corr(spec_mean_truth(:),spec_mean_dauhst(:)),-4);
104 | corr_dernn_lnlt_9stg = roundn(corr(spec_mean_truth(:),spec_mean_dernn_lnlt_9stg(:)),-4);
105 | corr_dernn_lnlt_9stg_plus = roundn(corr(spec_mean_truth(:),spec_mean_dernn_lnlt_9stg_plus(:)),-4);
106 |
107 | X = lam28;
108 |
109 | Y(1,:) = spec_mean_truth(:);
110 | % Y(2,:) = spec_mean_hdnet(:); Corr(1)=corr_hdnet;
111 | % Y(3,:) = spec_mean_mst_s(:); Corr(2)=corr_mst_s;
112 | % Y(4,:) = spec_mean_mst_m(:); Corr(3)=corr_mst_m;
113 | % Y(5,:) = spec_mean_mst_l(:); Corr(4)=corr_mst_l;
114 | % Y(6,:) = spec_mean_mst_plus_plus(:); Corr(5)=corr_mst_plus_plus;
115 | Y(2,:) = spec_mean_twist(:); Corr(1)=corr_twist;
116 | Y(3,:) = spec_mean_gaptv(:); Corr(2)=corr_gaptv;
117 | Y(4,:) = spec_mean_desci(:); Corr(3)=corr_desci;
118 | Y(5,:) = spec_mean_tsanet(:); Corr(4)=corr_tsanet;
119 | Y(6,:) = spec_mean_gapnet(:); Corr(5)=corr_gapnet;
120 | Y(7,:) = spec_mean_mst_l(:); Corr(6)=corr_mst_l;
121 | Y(8,:) = spec_mean_dauhst(:); Corr(7)=corr_dauhst;
122 | Y(9,:) = spec_mean_dernn_lnlt_9stg(:); Corr(8)=corr_dernn_lnlt_9stg;
123 | Y(10,:) = spec_mean_dernn_lnlt_9stg_plus(:); Corr(9)=corr_dernn_lnlt_9stg_plus;
124 |
125 | createfigure(X,Y,Corr)
126 |
127 |
128 |
--------------------------------------------------------------------------------
/csi/engine/defaults.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import logging
5 | import datetime
6 | import random
7 | from omegaconf import OmegaConf
8 |
9 | import torch
10 | from torch.utils.tensorboard import SummaryWriter
11 |
12 |
13 | import numpy as np
14 |
15 | from csi.config import CfgNode
16 | from csi.utils.file_io import PathManager
17 |
18 | def default_argument_parser(epilog=None):
19 |     """
20 |     Create a parser with the common arguments used by the CSI tools.
21 |
22 |     Only ``--config-file`` is defined here; callers may add more arguments
23 |     to the returned parser before parsing.
24 |
25 |     Args:
26 |         epilog (str): epilog passed to ArgumentParser describing the usage.
27 |             When omitted, a short usage example is generated.
28 |     Returns:
29 |         argparse.ArgumentParser: the configured, not yet parsed, parser.
30 |     """
31 |     # The default epilog only shows flags that this parser actually defines.
32 |     parser = argparse.ArgumentParser(
33 |         epilog=epilog
34 |         or f"""
35 | Examples:
36 |
37 | Run with a config file:
38 |     $ {sys.argv[0]} --config-file cfg.yaml
39 | """,
40 |         formatter_class=argparse.RawDescriptionHelpFormatter,
41 |     )
42 |
43 |     parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
44 |
45 |     return parser
46 |
47 |
48 | def _try_get_key(cfg, *keys, default=None):
49 |     """
50 |     Try select keys from cfg until the first key that exists. Otherwise return default.
51 |
52 |     ``cfg`` may be a ``CfgNode`` (converted to an OmegaConf node first) or an
53 |     OmegaConf config; keys may be dotted paths such as ``"train.seed"``.
54 |     """
55 |     if isinstance(cfg, CfgNode):
56 |         cfg = OmegaConf.create(cfg.dump())
57 |     # Private sentinel so that a key explicitly set to None still counts as present.
58 |     missing = object()
59 |     for k in keys:
60 |         p = OmegaConf.select(cfg, k, default=missing)
61 |         if p is not missing:
62 |             return p
63 |     return default
64 |
65 | def _highlight(code, filename):
66 |     """Return ``code`` colorized for a 256-color terminal when pygments is available.
67 |
68 |     ``.py`` files are lexed as Python, everything else as YAML; without
69 |     pygments the text is returned unchanged.
70 |     """
71 |     try:
72 |         import pygments
73 |     except ImportError:
74 |         return code
75 |
76 |     from pygments.lexers import Python3Lexer, YamlLexer
77 |     from pygments.formatters import Terminal256Formatter
78 |
79 |     lexer = Python3Lexer() if filename.endswith(".py") else YamlLexer()
80 |     code = pygments.highlight(code, lexer, Terminal256Formatter(style="monokai"))
81 |     return code
82 |
83 |
84 | def time2file_name(time):
85 |     """Turn ``str(datetime.datetime.now())`` into ``YYYY_MM_DD_HH_MM_SS``.
86 |
87 |     Fixed character offsets are used, so ``time`` must follow the
88 |     ``YYYY-MM-DD HH:MM:SS`` layout; anything after the seconds is dropped.
89 |     """
90 |     year = time[0:4]
91 |     month = time[5:7]
92 |     day = time[8:10]
93 |     hour = time[11:13]
94 |     minute = time[14:16]
95 |     second = time[17:19]
96 |     time_filename = year + '_' + month + '_' + day + '_' + hour + '_' + minute + '_' + second
97 |     return time_filename
98 |
99 |
100 | def gen_log(model_path):
101 |     """
102 |     Configure the root logger at INFO level and return it.
103 |
104 |     Records go to the console and, when ``model_path`` is given, are also
105 |     appended to ``<model_path>/log.txt``. Handlers created here are tagged,
106 |     so calling this function again does not duplicate output.
107 |
108 |     Args:
109 |         model_path (str or None): directory that receives ``log.txt``; a
110 |             falsy value disables file logging.
111 |     Returns:
112 |         logging.Logger: the root logger.
113 |     """
114 |     logger = logging.getLogger()
115 |     logger.setLevel(logging.INFO)
116 |     formatter = logging.Formatter("%(asctime)s - %(levelname)s: %(message)s")
117 |
118 |     def _has_handler(tag):
119 |         return any(getattr(h, "_csi_log_tag", None) == tag for h in logger.handlers)
120 |
121 |     if model_path:
122 |         log_file = os.path.join(model_path, 'log.txt')
123 |         file_tag = "file:" + os.path.abspath(log_file)
124 |         if not _has_handler(file_tag):
125 |             fh = logging.FileHandler(log_file, mode='a')
126 |             fh.setLevel(logging.INFO)
127 |             fh.setFormatter(formatter)
128 |             fh._csi_log_tag = file_tag
129 |             logger.addHandler(fh)
130 |
131 |     if not _has_handler("console"):
132 |         ch = logging.StreamHandler()
133 |         ch.setLevel(logging.INFO)
134 |         ch.setFormatter(formatter)
135 |         ch._csi_log_tag = "console"
136 |         logger.addHandler(ch)
137 |     return logger
138 |
139 | def default_setup(cfg, args):
140 |     """
141 |     Perform some basic common setups at the beginning of a job, including:
142 |     1. Create a timestamped output directory with ``val``/``test`` sub-folders
143 |     2. Set up the root logger and a TensorBoard ``SummaryWriter``
144 |     3. Log the command line arguments and the config file contents
145 |     4. Back up the config to the output directory and seed all RNGs
146 |     Args:
147 |         cfg (CfgNode or omegaconf.DictConfig): the full config to be used
148 |         args (argparse.NameSpace): the command line arguments to be logged
149 |     Returns:
150 |         tuple: ``(logger, writer, output_dir)``. Without a configured output
151 |         directory, ``writer`` is None and logging goes to the console only.
152 |     """
153 |     output_dir = _try_get_key(cfg, "OUTPUT_DIR", "output_dir", "train.output_dir")
154 |     writer = None
155 |     if output_dir:
156 |         date_time = time2file_name(str(datetime.datetime.now()))
157 |         # The timestamp is appended verbatim; no path separator is inserted.
158 |         output_dir = output_dir + date_time
159 |         PathManager.mkdirs(output_dir)
160 |         PathManager.mkdirs(os.path.join(output_dir, "val"))
161 |         PathManager.mkdirs(os.path.join(output_dir, "test"))
162 |         writer = SummaryWriter(log_dir=output_dir)
163 |
164 |     # Always bound (console-only without an output directory) because the
165 |     # logging below runs unconditionally.
166 |     logger = gen_log(output_dir)
167 |
168 |     logger.info("Command line arguments: " + str(args))
169 |     if hasattr(args, "config_file") and args.config_file != "":
170 |         with PathManager.open(args.config_file, "r") as f:
171 |             config_text = f.read()
172 |         logger.info(
173 |             "Contents of args.config_file={}:\n{}".format(
174 |                 args.config_file,
175 |                 _highlight(config_text, args.config_file),
176 |             )
177 |         )
178 |
179 |     if output_dir:
180 |         # Note: some of our scripts may expect the existence of
181 |         # config.yaml in output directory
182 |         path = os.path.join(output_dir, "config.yaml")
183 |         if isinstance(cfg, CfgNode):
184 |             logger.info("Running with full config:\n{}".format(_highlight(cfg.dump(), ".yaml")))
185 |             with PathManager.open(path, "w") as f:
186 |                 f.write(cfg.dump())
187 |             logger.info("Full config saved to {}".format(path))
188 |
189 |     # Seed every RNG from the configured seed (3407 if unset) for reproducibility.
190 |     seed = _try_get_key(cfg, "SEED", "train.seed", default=3407)
191 |     deterministic = _try_get_key(cfg, "DETERMINISTIC", "train.deterministic", default=False)
192 |     seed_everything(seed, deterministic=deterministic)
193 |
194 |     return logger, writer, output_dir
195 |
196 |
197 | def seed_everything(
198 |     seed = 3407,
199 |     deterministic = False,
200 | ):
201 |     """Set random seed.
202 |     Args:
203 |         seed (int): Seed to be used, default seed 3407, from the paper
204 |             "Torch.manual_seed(3407) is all you need: On the influence of random seeds in deep learning architectures for computer vision", arXiv preprint arXiv:2109.08203, 2021.
205 |         deterministic (bool): Whether to set the deterministic option for
206 |             CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
207 |             to True and `torch.backends.cudnn.benchmark` to False.
208 |             Default: False.
209 |     """
210 |     random.seed(seed)
211 |     np.random.seed(seed)
212 |     torch.manual_seed(seed)
213 |     torch.cuda.manual_seed(seed)
214 |     torch.cuda.manual_seed_all(seed)
215 |     # Hash randomization is fixed when an interpreter starts, so this only
216 |     # affects Python processes launched after this point.
217 |     os.environ['PYTHONHASHSEED'] = str(seed)
218 |     if deterministic:
219 |         torch.backends.cudnn.deterministic = True
220 |         torch.backends.cudnn.benchmark = False
221 |
--------------------------------------------------------------------------------
/Quality_Metrics/cal_ssim.m:
--------------------------------------------------------------------------------
1 | function ssim = cal_ssim( im1, im2, b_row, b_col )
2 |
3 | [h w ch] = size( im1 );
4 | ssim = 0;
5 | if ch==1
6 |     ssim = ssim_index( im1(b_row+1:h-b_row, b_col+1:w-b_col), im2( b_row+1:h-b_row, b_col+1:w-b_col) );
7 | else
8 |     for i = 1:ch
9 |         ssim = ssim + ssim_index( im1(b_row+1:h-b_row, b_col+1:w-b_col, i), im2( b_row+1:h-b_row, b_col+1:w-b_col, i) );
10 |     end
11 |     ssim = ssim/ch;
12 | end
13 | return;
14 |
15 |
16 |
17 |
18 | function [mssim, ssim_map] = ssim_index(img1, img2, K, window, L)
19 |
20 | %========================================================================
21 | %SSIM Index, Version 1.0
22 | %Copyright(c) 2003 Zhou Wang
23 | %All Rights Reserved.
24 | %
25 | %The author was with Howard Hughes Medical Institute, and Laboratory
26 | %for Computational Vision at Center for Neural Science and Courant
27 | %Institute of Mathematical Sciences, New York University, USA. He is
28 | %currently with Department of Electrical and Computer Engineering,
29 | %University of Waterloo, Canada.
30 | %
31 | %----------------------------------------------------------------------
32 | %Permission to use, copy, or modify this software and its documentation
33 | %for educational and research purposes only and without fee is hereby
34 | %granted, provided that this copyright notice and the original authors'
35 | %names appear on all copies and supporting documentation. This program
36 | %shall not be used, rewritten, or adapted as the basis of a commercial
37 | %software or hardware product without first obtaining permission of the
38 | %authors. The authors make no representations about the suitability of
39 | %this software for any purpose. It is provided "as is" without express
40 | %or implied warranty.
41 | %----------------------------------------------------------------------
42 | %
43 | %This is an implementation of the algorithm for calculating the
44 | %Structural SIMilarity (SSIM) index between two images. Please refer
45 | %to the following paper:
46 | %
47 | %Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
48 | %quality assessment: From error measurement to structural similarity"
49 | %IEEE Transactios on Image Processing, vol. 13, no. 4, Apr. 2004.
50 | %
51 | %Kindly report any suggestions or corrections to zhouwang@ieee.org
52 | %
53 | %----------------------------------------------------------------------
54 | %
55 | %Input : (1) img1: the first image being compared
56 | % (2) img2: the second image being compared
57 | % (3) K: constants in the SSIM index formula (see the above
58 | % reference). defualt value: K = [0.01 0.03]
59 | % (4) window: local window for statistics (see the above
60 | % reference). default widnow is Gaussian given by
61 | % window = fspecial('gaussian', 11, 1.5);
62 | % (5) L: dynamic range of the images. default: L = 255
63 | %
64 | %Output: (1) mssim: the mean SSIM index value between 2 images.
65 | % If one of the images being compared is regarded as
66 | % perfect quality, then mssim can be considered as the
67 | % quality measure of the other image.
68 | % If img1 = img2, then mssim = 1.
69 | % (2) ssim_map: the SSIM index map of the test image. The map
70 | % has a smaller size than the input images. The actual size:
71 | % size(img1) - size(window) + 1.
72 | %
73 | %Default Usage:
74 | % Given 2 test images img1 and img2, whose dynamic range is 0-255
75 | %
76 | % [mssim ssim_map] = ssim_index(img1, img2);
77 | %
78 | %Advanced Usage:
79 | % User defined parameters. For example
80 | %
81 | % K = [0.05 0.05];
82 | % window = ones(8);
83 | % L = 100;
84 | % [mssim ssim_map] = ssim_index(img1, img2, K, window, L);
85 | %
86 | %See the results:
87 | %
88 | % mssim %Gives the mssim value
89 | % imshow(max(0, ssim_map).^4) %Shows the SSIM index map
90 | %
91 | %========================================================================
92 |
93 |
94 | if (nargin < 2 | nargin > 5)
95 |     mssim = -Inf;
96 |     ssim_map = -Inf;
97 |     return;
98 | end
99 |
100 | if ~isequal(size(img1), size(img2)) % reject any dimension mismatch, not only a mismatch in every dimension
101 |     mssim = -Inf;
102 |     ssim_map = -Inf;
103 |     return;
104 | end
105 |
106 | [M N] = size(img1);
107 |
108 | if (nargin == 2)
109 |     if ((M < 11) | (N < 11))
110 |         mssim = -Inf;
111 |         ssim_map = -Inf;
112 |         return
113 |     end
114 |     window = fspecial('gaussian', 11, 1.5); %
115 |     K(1) = 0.01; % default settings
116 |     K(2) = 0.03; %
117 |     L = 255; %
118 | end
119 |
120 | if (nargin == 3)
121 |     if ((M < 11) | (N < 11))
122 |         mssim = -Inf;
123 |         ssim_map = -Inf;
124 |         return
125 |     end
126 |     window = fspecial('gaussian', 11, 1.5);
127 |     L = 255;
128 |     if (length(K) == 2)
129 |         if (K(1) < 0 | K(2) < 0)
130 |             mssim = -Inf;
131 |             ssim_map = -Inf;
132 |             return;
133 |         end
134 |     else
135 |         mssim = -Inf;
136 |         ssim_map = -Inf;
137 |         return;
138 |     end
139 | end
140 |
141 | if (nargin == 4)
142 |     [H W] = size(window);
143 |     if ((H*W) < 4 | (H > M) | (W > N))
144 |         mssim = -Inf;
145 |         ssim_map = -Inf;
146 |         return
147 |     end
148 |     L = 255;
149 |     if (length(K) == 2)
150 |         if (K(1) < 0 | K(2) < 0)
151 |             mssim = -Inf;
152 |             ssim_map = -Inf;
153 |             return;
154 |         end
155 |     else
156 |         mssim = -Inf;
157 |         ssim_map = -Inf;
158 |         return;
159 |     end
160 | end
161 |
162 | if (nargin == 5)
163 |     [H W] = size(window);
164 |     if ((H*W) < 4 | (H > M) | (W > N))
165 |         mssim = -Inf;
166 |         ssim_map = -Inf;
167 |         return
168 |     end
169 |     if (length(K) == 2)
170 |         if (K(1) < 0 | K(2) < 0)
171 |             mssim = -Inf;
172 |             ssim_map = -Inf;
173 |             return;
174 |         end
175 |     else
176 |         mssim = -Inf;
177 |         ssim_map = -Inf;
178 |         return;
179 |     end
180 | end
181 |
182 | C1 = (K(1)*L)^2;
183 | C2 = (K(2)*L)^2;
184 | window = window/sum(sum(window));
185 | img1 = double(img1);
186 | img2 = double(img2);
187 |
188 | mu1 = filter2(window, img1, 'valid');
189 | mu2 = filter2(window, img2, 'valid');
190 | mu1_sq = mu1.*mu1;
191 | mu2_sq = mu2.*mu2;
192 | mu1_mu2 = mu1.*mu2;
193 | sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq;
194 | sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq;
195 | sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2;
196 |
197 | if (C1 > 0 & C2 > 0)
198 |     ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2));
199 | else
200 |     numerator1 = 2*mu1_mu2 + C1;
201 |     numerator2 = 2*sigma12 + C2;
202 |     denominator1 = mu1_sq + mu2_sq + C1;
203 |     denominator2 = sigma1_sq + sigma2_sq + C2;
204 |     ssim_map = ones(size(mu1));
205 |     index = (denominator1.*denominator2 > 0);
206 |     ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index));
207 |     index = (denominator1 ~= 0) & (denominator2 == 0);
208 |     ssim_map(index) = numerator1(index)./denominator1(index);
209 | end
210 |
211 | mssim = mean2(ssim_map);
212 |
213 | return
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DERNN-LNLT for CASSI
2 |
3 | This repo is the implementation of the paper "Degradation Estimation Recurrent Neural Network with Local and Non-Local Priors for Compressive Spectral Imaging"
4 |
5 | # Abstract
6 |
7 |
recovering 3D hyperspectral images (HSI) from 2D measurements. However, some noticeable gaps exist between the imaging model used in DUNs and the real CASSI imaging process, such as the sensing error as well as photon and dark current noise, compromising the accuracy of solving the data subproblem and the prior subproblem in DUNs. To address this issue, we propose a Degradation Estimation Network (DEN) to correct the imaging model used in DUNs by simultaneously estimating the sensing error and the noise level, thereby improving the performance of DUNs. Additionally, we propose an efficient Local and Non-local Transformer (LNLT) to solve the prior subproblem, which not only effectively models local and non-local similarities but also reduces the computational cost of the window-based global Multi-head Self-attention (MSA). Furthermore, we transform the DUN into a Recurrent Neural Network (RNN) by sharing parameters of DNNs across stages, which not only allows DNN to be trained more adequately but also significantly reduces the number of parameters. The proposed DERNN-LNLT achieves state-of-the-art (SOTA) performance with fewer parameters on both simulation and real datasets. 8 | 9 | 10 | # Comparison with other Deep Unfolding Networks 11 | 12 |
13 | 14 |
15 | 16 | Comparison of PSNR versus the number of parameters for previous HSI DUNs. The PSNR (in dB) is plotted on the vertical axis, and the number of parameters on the horizontal axis. The proposed DERNN-LNLT outperforms the previous DUNs while requiring far fewer parameters. 17 | 18 | # Architecture 19 | 20 |
21 | 22 |
23 | 24 | The DERNN-LNLT alternately solves a data subproblem and a prior subproblem in each recurrent step. First, the DERNN-LNLT unfolds the HQS algorithm within the MAP framework and transforms the DUN into an RNN by sharing parameters across stages. Then, the DERNN-LNLT integrates the Degradation Estimation Network into the RNN, which estimates the degradation matrix for the data subproblem and the noise level for the prior subproblem by residual learning with reference to the sensing matrix. Subsequently, the Local and Non-Local Transformer (LNLT) utilizes the Local and Non-Local Multi-head Self-Attention (MSA) to effectively exploit both local and non-local HSI priors. Finally, incorporating the LNLT into the DERNN as the denoiser for the prior subproblem leads to the proposed DERNN-LNLT. 25 | 26 | # Usage 27 | 28 | ## Prepare Dataset: 29 | 30 | Download cave_1024_28 ([Baidu Disk](https://pan.baidu.com/s/1X_uXxgyO-mslnCTn4ioyNQ), code: `fo0q` | [One Drive](https://bupteducn-my.sharepoint.com/:f:/g/personal/mengziyi_bupt_edu_cn/EmNAsycFKNNNgHfV9Kib4osB7OD4OSu-Gu6Qnyy5PweG0A?e=5NrM6S)), CAVE_512_28 ([Baidu Disk](https://pan.baidu.com/s/1ue26weBAbn61a7hyT9CDkg), code: `ixoe` | [One Drive](https://mailstsinghuaeducn-my.sharepoint.com/:f:/g/personal/lin-j21_mails_tsinghua_edu_cn/EjhS1U_F7I1PjjjtjKNtUF8BJdsqZ6BSMag_grUfzsTABA?e=sOpwm4)), KAIST_CVPR2021 ([Baidu Disk](https://pan.baidu.com/s/1LfPqGe0R_tuQjCXC_fALZA), code: `5mmn` | [One Drive](https://mailstsinghuaeducn-my.sharepoint.com/:f:/g/personal/lin-j21_mails_tsinghua_edu_cn/EkA4B4GU8AdDu0ZkKXdewPwBd64adYGsMPB8PNCuYnpGlA?e=VFb3xP)), TSA_simu_data ([Baidu Disk](https://pan.baidu.com/s/1LI9tMaSprtxT8PiAG1oETA), code: `efu8` | [One Drive](https://1drv.ms/u/s!Au_cHqZBKiu2gYFDwE-7z1fzeWCRDA?e=ofvwrD)), TSA_real_data ([Baidu Disk](https://pan.baidu.com/s/1RoOb1CKsUPFu0r01tRi5Bg), code: `eaqe` | [One Drive](https://1drv.ms/u/s!Au_cHqZBKiu2gYFTpCwLdTi_eSw6ww?e=uiEToT)), and then put them into the corresponding folders
of `datasets/` and arrange them in the following structure: 31 | 32 | 33 | ``` 34 | |--DERNN_LNLT 35 | |--datasets 36 | |--CSI 37 | |--cave_1024_28 38 | |--scene1.mat 39 | |--scene2.mat 40 | : 41 | |--scene205.mat 42 | |--CAVE_512_28 43 | |--scene1.mat 44 | |--scene2.mat 45 | : 46 | |--scene30.mat 47 | |--KAIST_CVPR2021 48 | |--1.mat 49 | |--2.mat 50 | : 51 | |--30.mat 52 | |--TSA_simu_data 53 | |--mask_3d_shift.mat 54 | |--mask.mat 55 | |--Truth 56 | |--scene01.mat 57 | |--scene02.mat 58 | : 59 | |--scene10.mat 60 | |--TSA_real_data 61 | |--mask_3d_shift.mat 62 | |--mask.mat 63 | |--Measurements 64 | |--scene1.mat 65 | |--scene2.mat 66 | : 67 | |--scene5.mat 68 | |--checkpoints 69 | |--csi 70 | |--scripts 71 | |--tools 72 | |--results 73 | |--Quality_Metrics 74 | |--visualization 75 | ``` 76 | 77 | We use the CAVE dataset (cave_1024_28) as the simulation training set. Both the CAVE (cave_1024_28) and KAIST (KAIST_CVPR2021) datasets are used as the real training set. 78 | 79 | ## Pretrained weights 80 | 81 | Download the pretrained weights ([Baidu Disk](https://pan.baidu.com/s/1BBQbFnYXx-glqYkrZ9DCSQ), code: `lnlt` | [Google Drive](https://drive.google.com/drive/folders/1aVGHRcHB2svBoYkM3pPCbZO765EXpLd8?usp=sharing)) and put them into `DERNN_LNLT/checkpoints/` 82 | 83 | ## Simulation Experiment: 84 | 85 | ### Training 86 | 87 | ``` 88 | cd DERNN_LNLT/ 89 | 90 | # DERNN-LNLT 5stage 91 | bash ./scripts/train_dernn_lnlt_5stg_simu.sh 92 | 93 | # DERNN-LNLT 7stage 94 | bash ./scripts/train_dernn_lnlt_7stg_simu.sh 95 | 96 | # DERNN-LNLT 9stage 97 | bash ./scripts/train_dernn_lnlt_9stg_simu.sh 98 | 99 | # DERNN-LNLT 9stage* 100 | bash ./scripts/train_dernn_lnlt_9stg_star_simu.sh 101 | ``` 102 | 103 | The training log, trained model, and reconstructed HSIs will be available in `DERNN_LNLT/exp/`. 104 | 105 | ### Testing 106 | 107 | Place the pretrained models in `DERNN_LNLT/checkpoints/` 108 | 109 | Run the following commands to test the models on the simulation dataset.
110 | 111 | ``` 112 | cd DERNN_LNLT/ 113 | 114 | # DERNN_LNLT 5stage 115 | bash ./scripts/test_dernn_lnlt_5stg_simu.sh 116 | 117 | # DERNN_LNLT 7stage 118 | bash ./scripts/test_dernn_lnlt_7stg_simu.sh 119 | 120 | # DERNN_LNLT 9stage 121 | bash ./scripts/test_dernn_lnlt_9stg_simu.sh 122 | 123 | # DERNN_LNLT 9stage* 124 | bash ./scripts/test_dernn_lnlt_9stg_star_simu.sh 125 | ``` 126 | 127 | The reconstructed HSIs will be written to `DERNN_LNLT/results/` 128 | 129 | ``` 130 | Run Cal_quality_assessment.m 131 | ``` 132 | 133 | to calculate the PSNR and SSIM of the reconstructed HSIs. 134 | 135 | 136 | ### Visualization 137 | 138 | - Put the reconstructed HSI in `DERNN_LNLT/visualization/simulation_results/results` and rename it as `<method>.mat`, e.g., DERNN_LNLT_9stg_simu.mat 139 | - Generate the RGB images of the reconstructed HSIs 140 | 141 | ``` 142 | cd DERNN_LNLT/visualization/ 143 | Run show_simulation.m 144 | ``` 145 | 146 | 147 | ## Real Experiment: 148 | 149 | ### Training 150 | 151 | ``` 152 | cd DERNN_LNLT/ 153 | 154 | # DERNN-LNLT 5stage 155 | bash ./scripts/train_dernn_lnlt_5stg_real.sh 156 | ``` 157 | 158 | The training log and trained model will be available in `DERNN_LNLT/exp/` 159 | 160 | ### Testing 161 | 162 | ``` 163 | cd DERNN_LNLT/ 164 | 165 | # DERNN-LNLT 5stage 166 | bash ./scripts/test_dernn_lnlt_5stg_real.sh 167 | ``` 168 | 169 | The reconstructed HSI will be written to `DERNN_LNLT/results/` 170 | 171 | ### Visualization 172 | 173 | Generate the RGB images of the reconstructed HSI 174 | 175 | ``` 176 | cd DERNN_LNLT/visualization/ 177 | Run show_real.m 178 | ``` 179 | 180 | ## Acknowledgements 181 | 182 | Our code is based on the following repositories; many thanks to their authors for generously open-sourcing them: 183 | 184 | - [https://github.com/ShawnDong98/RDLUF_MixS2](https://github.com/ShawnDong98/RDLUF_MixS2) 185 | - [https://github.com/caiyuanhao1998/MST](https://github.com/caiyuanhao1998/MST) 186 | -
[https://github.com/TaoHuang95/DGSMP](https://github.com/TaoHuang95/DGSMP) 187 | - [https://github.com/mengziyi64/TSA-Net](https://github.com/mengziyi64/TSA-Net) 188 | - [https://github.com/facebookresearch/detectron2](https://github.com/facebookresearch/detectron2) 189 | 190 | 191 | ## Citation 192 | 193 | If this code helps you, please consider citing our works: 194 | 195 | ```shell 196 | @article{dernn_lnlt, 197 | title={Degradation Estimation Recurrent Neural Network with Local and Non-Local Priors for Compressive Spectral Imaging}, 198 | author={Dong, Yubo and Gao, Dahua and Li, Yuyan and Shi, Guangming and Liu, Danhua}, 199 | journal={arXiv preprint arXiv:2311.08808}, 200 | year={2023} 201 | } 202 | 203 | @inproceedings{rdluf_mixs2, 204 | title={Residual Degradation Learning Unfolding Framework with Mixing Priors across Spectral and Spatial for Compressive Spectral Imaging}, 205 | author={Dong, Yubo and Gao, Dahua and Qiu, Tian and Li, Yuyan and Yang, Minxi and Shi, Guangming}, 206 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 207 | pages={22262--22271}, 208 | year={2023} 209 | } 210 | ``` -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | # add python path of PadleDetection to sys.path 6 | parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2))) 7 | sys.path.insert(0, parent_path) 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | from torch import optim 13 | from torch.utils.data import DataLoader 14 | from torch.nn.utils import clip_grad_norm_ 15 | from torchvision.utils import make_grid 16 | from torch_ema import ExponentialMovingAverage 17 | 18 | import cv2 19 | import numpy as np 20 | from scipy import io as sio 21 | from tqdm import tqdm 22 | 23 | from csi.config import 
get_cfg 24 | from csi.engine import default_argument_parser, default_setup 25 | from csi.data import CSITrainDataset, LoadVal, LoadTSATestMeas, shift_back_batch, generate_mask_3d, generate_mask_3d_shift, gen_meas_torch_batch 26 | from csi.architectures import DERNN_LNLT 27 | from csi.utils.schedulers import get_cosine_schedule_with_warmup 28 | from csi.losses import CharbonnierLoss, TVLoss 29 | from csi.metrics import torch_psnr, torch_ssim, sam 30 | from csi.utils.utils import checkpoint 31 | 32 | args = default_argument_parser().parse_args() 33 | cfg = get_cfg() 34 | cfg.merge_from_file(args.config_file) 35 | cfg.freeze() 36 | logger, writer, output_dir = default_setup(cfg, args) 37 | 38 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 39 | 40 | # init mask 41 | mask = generate_mask_3d_shift(mask_path=cfg.DATASETS.VAL.MASK_PATH).to(device) 42 | mask_test = generate_mask_3d_shift(mask_path=cfg.DATASETS.TEST.MASK_PATH).to(device) 43 | 44 | val_datas = LoadVal(cfg.DATASETS.VAL.PATH) 45 | 46 | 47 | test_meas = LoadTSATestMeas(cfg.DATASETS.TEST.PATH).to(device) 48 | 49 | 50 | model = eval(cfg.MODEL.DENOISER.TYPE)(cfg).to(device) 51 | 52 | 53 | ema = ExponentialMovingAverage(model.parameters(), decay=cfg.MODEL.EMA.DECAY) 54 | 55 | # optimizing 56 | optimizer = optim.Adam(model.parameters(), lr=cfg.OPTIMIZER.LR, betas=(0.9, 0.999)) 57 | 58 | scheduler = get_cosine_schedule_with_warmup( 59 | optimizer, 60 | num_warmup_steps=int(np.floor(cfg.DATASETS.TRAIN.ITERATION / cfg.DATALOADER.BATCH_SIZE)), 61 | num_training_steps=int(np.floor(cfg.DATASETS.TRAIN.ITERATION / cfg.DATALOADER.BATCH_SIZE)) * cfg.OPTIMIZER.MAX_EPOCH, 62 | eta_min=1e-6) 63 | 64 | if cfg.LOSSES.L1_LOSS: l1_loss = CharbonnierLoss().to(device) 65 | if cfg.LOSSES.TV_LOSS: tv_loss = TVLoss().to(device) 66 | 67 | start_epoch = 0 68 | 69 | if cfg.RESUME_CKPT_PATH: 70 | print(f"===> Loading Checkpoint from {cfg.RESUME_CKPT_PATH}") 71 | save_state = torch.load(cfg.RESUME_CKPT_PATH) 72 | 
model.load_state_dict(save_state['model']) 73 | ema.load_state_dict(save_state['ema']) 74 | optimizer.load_state_dict(save_state['optimizer']) 75 | scheduler.load_state_dict(save_state['scheduler']) 76 | start_epoch = save_state['epoch'] 77 | 78 | 79 | def train(epoch, train_loader): 80 | model.train() 81 | epoch_loss = 0 82 | begin = time.time() 83 | batch_num = int(np.floor(cfg.DATASETS.TRAIN.ITERATION / train_loader.batch_size)) 84 | train_tqdm = tqdm(range(batch_num)[:5]) if cfg.DEBUG else tqdm(range(batch_num)) 85 | 86 | loss_dict = {} 87 | for i in train_tqdm: 88 | data_time = time.time() 89 | try: 90 | data = next(data_iter) 91 | except: 92 | data_iter = iter(train_loader) 93 | data = next(data_iter) 94 | 95 | data = {k:v.to(device) for k, v in data.items()} 96 | 97 | data_time = time.time() - data_time 98 | 99 | model_time = time.time() 100 | # model_out = model(meas_batch) 101 | model_out = model(data) 102 | model_time = time.time() - model_time 103 | 104 | loss = 0 105 | if cfg.LOSSES.L1_LOSS: 106 | loss_l1 = l1_loss(model_out, data['hsi']) 107 | loss_dict['loss_l1'] = f"{loss_l1.item():.4f}" 108 | loss += loss_l1 109 | if cfg.LOSSES.TV_LOSS: 110 | loss_tv = tv_loss(model_out) 111 | loss_dict['loss_tv'] = f"{loss_tv.item():.4f}" 112 | loss += loss_tv 113 | 114 | loss.backward() 115 | if cfg.OPTIMIZER.GRAD_CLIP: 116 | clip_grad_norm_(model.parameters(), max_norm=0.2) 117 | 118 | optimizer.step() 119 | optimizer.zero_grad() 120 | ema.update() 121 | loss_dict['data_time'] = data_time 122 | loss_dict['model_time'] = model_time 123 | train_tqdm.set_postfix(loss_dict) 124 | epoch_loss += loss.data 125 | writer.add_scalar('LR/train',optimizer.state_dict()['param_groups'][0]['lr'], epoch * batch_num + i) 126 | scheduler.step() 127 | end = time.time() 128 | train_loss = epoch_loss / batch_num 129 | logger.info("===> Epoch {} Complete: Avg. Loss: {:.6f} time: {:.2f}". 
130 | format(epoch, train_loss, (end - begin))) 131 | return train_loss 132 | 133 | 134 | def eval(epoch): 135 | psnr_list, ssim_list, sam_list = [], [], [] 136 | # val_H = [] 137 | # val_Y = [] 138 | # val_gt = [] 139 | # for val_label in val_datas['hsi']: 140 | # val_label = torch.from_numpy(val_label).permute(2, 0, 1).to(device).float() 141 | # YH = gen_meas_torch(val_label, mask, step=cfg.DATASETS.STEP, wave_len=cfg.DATASETS.WAVE_LENS, mask_type=cfg.DATASETS.MASK_TYPE) 142 | # val_H.append(YH['H'].to(device)) 143 | # val_Y.append(YH['Y'].to(device)) 144 | # val_gt.append(val_label) 145 | # val_gt = torch.stack(val_gt) 146 | # val_H = torch.stack(val_H) 147 | # val_Y = torch.stack(val_Y) 148 | 149 | val_gt = torch.stack([torch.from_numpy(val_label).permute(2, 0, 1).to(device).float() for val_label in val_datas['hsi']]) 150 | B, _, _, _ = val_gt.shape 151 | val_mask = mask.unsqueeze(0).tile((B, 1, 1, 1)) 152 | YH = gen_meas_torch_batch(val_gt, val_mask, step=cfg.DATASETS.STEP, wave_len=cfg.DATASETS.WAVE_LENS, mask_type=cfg.DATASETS.MASK_TYPE, with_noise=cfg.DATASETS.TRAIN.WITH_NOISE) 153 | 154 | data = {} 155 | data['hsi'] = val_gt 156 | data['H'] = YH['H'] 157 | data['mask'] = val_mask 158 | data['Y'] = YH['Y'] 159 | 160 | model.eval() 161 | begin = time.time() 162 | with torch.no_grad(): 163 | with ema.average_parameters(): 164 | out = model(data) 165 | model_out = out 166 | 167 | for i in range(len(model_out)): 168 | psnr_val = torch_psnr(model_out[i, :, :, :], val_gt[i, :, :, :]) 169 | ssim_val = torch_ssim(model_out[i, :, :, :], val_gt[i, :, :, :]) 170 | sam_val = sam(model_out[i, :, :, :].permute(1, 2, 0).cpu().numpy(), val_gt[i, :, :, :].permute(1, 2, 0).cpu().numpy()) 171 | psnr_list.append(psnr_val.detach().cpu().numpy()) 172 | ssim_list.append(ssim_val.detach().cpu().numpy()) 173 | sam_list.append(sam_val) 174 | 175 | pred = np.transpose(model_out.detach().cpu().numpy(), (0, 2, 3, 1)).astype(np.float32) 176 | truth = np.transpose(val_gt.cpu().numpy(), 
(0, 2, 3, 1)).astype(np.float32) 177 | psnr_mean = np.mean(np.asarray(psnr_list)) 178 | ssim_mean = np.mean(np.asarray(ssim_list)) 179 | sam_mean = np.mean(np.asarray(sam_list)) 180 | 181 | model_out = F.interpolate(model_out, size=(128, 128)) 182 | for i, out in enumerate(model_out): 183 | out_grid = make_grid(out.unsqueeze(1).clip(0, 1), nrow=7) 184 | writer.add_image(f'images/val_scene{i}', out_grid, epoch) 185 | end = time.time() 186 | 187 | logger.info('===> Epoch {}: testing psnr = {:.2f}, ssim = {:.3f}, sam = {:.3f}, time: {:.2f}' 188 | .format(epoch, psnr_mean, ssim_mean, sam_mean, (end - begin))) 189 | model.train() 190 | return pred, truth, psnr_list, ssim_list, sam_list, psnr_mean, ssim_mean, sam_mean 191 | 192 | def test(epoch, test_meas, name="test_a"): 193 | model.eval() 194 | model_out = [] 195 | data = {} 196 | data['Y'] = test_meas / test_meas.max() * 0.8 197 | # data['Y'] = test_meas / (test_meas.max() + 1e-7) * 0.9 198 | B, _, _ = test_meas.shape 199 | data['mask'] = mask_test.unsqueeze(0).tile((B, 1, 1, 1)) 200 | data['H'] = shift_back_batch(test_meas, step=cfg.DATASETS.STEP, nC=cfg.DATASETS.WAVE_LENS) 201 | 202 | with torch.no_grad(): 203 | with ema.average_parameters(): 204 | model_out = model(data) 205 | 206 | 207 | for i in range(B): 208 | out_plot = F.interpolate(model_out[i:i+1, :, :, :], size=(128, 128)) 209 | if name == "TSA": out_plot = torch.flip(out_plot, dims=(2, 3)) 210 | grid = make_grid(out_plot.permute(1, 0, 2, 3).clip(0, 1), nrow=7) 211 | writer.add_image('images/' + name + f'_scene{i}', grid, epoch) 212 | 213 | model_out = np.transpose(model_out.detach().cpu().numpy(), (0, 2, 3, 1)).astype(np.float32) 214 | model.train() 215 | 216 | return model_out 217 | 218 | 219 | 220 | def main(): 221 | psnr_max = 0 222 | sam_min = 9999 223 | dataset = CSITrainDataset(cfg, crop_size=cfg.DATASETS.TRAIN.CROP_SIZE) 224 | for epoch in range(start_epoch+1, cfg.OPTIMIZER.MAX_EPOCH): 225 | train_loader = DataLoader( 226 | dataset = dataset, 227 | 
batch_size = cfg.DATALOADER.BATCH_SIZE, 228 | shuffle = True, 229 | num_workers = cfg.DATALOADER.NUM_WORKERS, 230 | pin_memory = False, 231 | drop_last = True 232 | ) 233 | train_loss = train(epoch, train_loader) 234 | torch.cuda.empty_cache() 235 | (pred, truth, psnr_all, ssim_all, sam_all, psnr_mean, ssim_mean, sam_mean) = eval(epoch) 236 | test_out = test(epoch, test_meas, "TSA") 237 | 238 | if cfg.DATASETS.TRAIN.WITH_NOISE: 239 | checkpoint(model, ema, optimizer, scheduler, epoch, output_dir, logger) 240 | sio.savemat(os.path.join(output_dir, "val", f"epoch{epoch}_SAM{sam_mean}.mat"), {"pred": pred, "truth": truth}) 241 | sio.savemat(os.path.join(output_dir, "test", f"test_epoch{epoch}_SAM{sam_mean}.mat"), {"pred": test_out}) 242 | else: 243 | if sam_mean < sam_min: 244 | sam_min = sam_mean 245 | checkpoint(model, ema, optimizer, scheduler, epoch, output_dir, logger) 246 | sio.savemat(os.path.join(output_dir, "val", f"epoch{epoch}_SAM{sam_mean}.mat"), {"pred": pred, "truth": truth}) 247 | sio.savemat(os.path.join(output_dir, "test", f"test_epoch{epoch}_SAM{sam_mean}.mat"), {"pred": test_out}) 248 | 249 | 250 | 251 | 252 | if __name__ == '__main__': 253 | main() -------------------------------------------------------------------------------- /csi/data/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from copy import deepcopy 4 | 5 | import torch 6 | from torch.utils.data import Dataset, DataLoader 7 | 8 | import cv2 9 | import numpy as np 10 | from scipy import io as sio 11 | 12 | from glob import glob 13 | 14 | from box import Box 15 | 16 | 17 | def generate_mask_3d_shift(mask_path): 18 | mask = sio.loadmat(mask_path) 19 | mask3d_shift = mask['mask_3d_shift'] 20 | mask3d_shift = np.transpose(mask3d_shift, [2, 0, 1]) 21 | mask3d_shift = torch.from_numpy(mask3d_shift).to(torch.float32) 22 | 23 | return mask3d_shift 24 | 25 | def generate_mask_3d(mask_path, wave_len): 26 | mask = 
sio.loadmat(mask_path) 27 | mask = mask['mask'] 28 | mask3d = np.tile(mask[:, :, np.newaxis], (1, 1, wave_len)) 29 | mask3d = np.transpose(mask3d, [2, 0, 1]) 30 | mask3d = torch.from_numpy(mask3d).to(torch.float32) 31 | 32 | return mask3d 33 | 34 | def LoadTraining(paths, debug=False): 35 | imgs = [] 36 | scene_list = [] 37 | for path in paths: 38 | scene_list.extend(glob(os.path.join(path, "*"))) 39 | scene_list.sort() 40 | print('training sences:', len(scene_list)) 41 | for scene_path in scene_list if not debug else scene_list[:20]: 42 | img_dict = sio.loadmat(scene_path) 43 | if "img_expand" in img_dict: 44 | img = img_dict['img_expand'] / 65536. 45 | elif "img" in img_dict: 46 | img = img_dict['img'] / 65536. 47 | elif "hsi" in img_dict: 48 | img = img_dict['hsi'] 49 | elif "HSI" in img_dict: 50 | img = img_dict['HSI'] 51 | img = img.astype(np.float32) 52 | imgs.append(img) 53 | print('Sence {} is loaded.'.format(scene_path.split('/')[-1])) 54 | return imgs 55 | 56 | 57 | def LoadVal(path_val): 58 | images = [] 59 | data = {} 60 | scene_list = os.listdir(path_val) 61 | scene_list.sort() 62 | for i in range(len(scene_list)): 63 | scene_path = os.path.join(path_val, scene_list[i]) 64 | img = sio.loadmat(scene_path) 65 | if "img_expand" in img: 66 | img = img['img_expand'] 67 | elif "img" in img: 68 | img = img['img'] 69 | elif "hsi" in img: 70 | img = img['hsi'] 71 | images.append(img) 72 | 73 | data["hsi"] = images 74 | return data 75 | 76 | def LoadTSATestMeas(path_test): 77 | measurements = [] 78 | meas_paths = sorted(glob(os.path.join(path_test, '*.mat'))) 79 | for meas_path in meas_paths: 80 | meas = sio.loadmat(meas_path)['meas_real'] 81 | meas[meas < 0] = 0.0 82 | meas[meas > 1] = 1.0 83 | measurements.append(meas) 84 | 85 | measurements = torch.from_numpy(np.stack(measurements, axis=0)).to(torch.float32) 86 | 87 | return measurements 88 | 89 | 90 | class CSITrainDataset(Dataset): 91 | def __init__( 92 | self, 93 | cfg, 94 | crop_size=(256, 256) 95 | ): 96 
| super().__init__() 97 | self.cfg = cfg 98 | self.iteration = cfg.DATASETS.TRAIN.ITERATION 99 | self.crop_size = crop_size 100 | self.augment = cfg.DATASETS.TRAIN.AUGMENT 101 | self.imgs = LoadTraining(cfg.DATASETS.TRAIN.PATHS, cfg.DEBUG) 102 | if cfg.DATASETS.MASK_TYPE == "mask_3d": 103 | self.mask = generate_mask_3d(cfg.DATASETS.TRAIN.MASK_PATH, cfg.DATASETS.WAVE_LENS) 104 | if cfg.DATASETS.MASK_TYPE == "mask_3d_shift": 105 | self.mask = generate_mask_3d_shift(cfg.DATASETS.TRAIN.MASK_PATH) 106 | _, self.mask_h, self.mask_w = self.mask.shape 107 | 108 | self.len_images = len(self.imgs) 109 | 110 | def __getitem__(self, idx): 111 | data = {} 112 | if self.augment: 113 | flag = random.randint(0, 1) 114 | if flag: 115 | index = np.random.randint(0, self.len_images-1) 116 | img = self.imgs[index] 117 | processed_image = np.zeros((self.crop_size[0], self.crop_size[1], self.cfg.DATASETS.WAVE_LENS), dtype=np.float32) 118 | 119 | h, w, _ = img.shape 120 | if h > w: 121 | img = np.transpose(img, (1, 0, 2)) 122 | h, w = w, h 123 | 124 | x_index = np.random.randint(0, h - self.crop_size[0]) 125 | y_index = np.random.randint(0, w - self.crop_size[1]) 126 | processed_image = img[x_index:x_index + self.crop_size[0], y_index:y_index + self.crop_size[1], :] 127 | 128 | processed_image = torch.from_numpy(np.transpose(processed_image, (2, 0, 1))) 129 | 130 | 131 | processed_image = augment_1(processed_image) 132 | else: 133 | processed_image = np.zeros((4, self.crop_size[0]//2, self.crop_size[1]//2, self.cfg.DATASETS.WAVE_LENS), dtype=np.float32) 134 | sample_list = np.random.randint(0, self.len_images, 4) 135 | for j in range(4): 136 | h, w, _ = self.imgs[sample_list[j]].shape 137 | x_index = np.random.randint(0, h-self.crop_size[0]//2) 138 | y_index = np.random.randint(0, w-self.crop_size[1]//2) 139 | processed_image[j] = self.imgs[sample_list[j]][x_index:x_index+self.crop_size[0]//2,y_index:y_index+self.crop_size[1]//2,:] 140 | 141 | processed_image = 
torch.from_numpy(np.transpose(processed_image, (0, 3, 1, 2))) # [4,28,128,128] 142 | processed_image = augment_2(processed_image, self.crop_size) 143 | 144 | if self.cfg.DATASETS.TRAIN.RANDOM_MASK: 145 | mask_x_index = np.random.randint(0, self.mask_h - self.crop_size[0]) 146 | if self.cfg.DATASETS.MASK_TYPE == "mask_3d": 147 | mask_y_index = np.random.randint(0, self.mask_w - self.crop_size[1]) 148 | mask = self.mask[:, mask_x_index:mask_x_index + self.crop_size[0], mask_y_index:mask_y_index + self.crop_size[1]] 149 | else: 150 | mask_y_index = np.random.randint(0, self.mask_w - (self.crop_size[1] + (self.cfg.DATASETS.WAVE_LENS - 1) * self.cfg.DATASETS.STEP)) 151 | mask = self.mask[:, mask_x_index:mask_x_index + self.crop_size[0], mask_y_index:mask_y_index + (self.crop_size[1] + (self.cfg.DATASETS.WAVE_LENS - 1) * self.cfg.DATASETS.STEP)] 152 | else: 153 | mask = self.mask 154 | 155 | 156 | 157 | data['hsi'] = processed_image 158 | data['mask'] = mask 159 | 160 | return data 161 | 162 | def __len__(self): 163 | return self.iteration 164 | 165 | 166 | 167 | def shuffle_crop(train_data, batch_size, crop_size=256, augment=True): 168 | if augment: 169 | flag = random.randint(0, 1) 170 | if flag: 171 | index = np.random.choice(range(len(train_data)), batch_size) 172 | processed_data = np.zeros((batch_size, crop_size, crop_size, 28), dtype=np.float32) 173 | for i in range(batch_size): 174 | h, w, _ = train_data[index[i]].shape 175 | x_index = np.random.randint(0, h - crop_size) 176 | y_index = np.random.randint(0, w - crop_size) 177 | processed_data[i, :, :, :] = train_data[index[i]][x_index:x_index + crop_size, y_index:y_index + crop_size, :] 178 | gt_batch = torch.from_numpy(np.transpose(processed_data, (0, 3, 1, 2))) 179 | for i in range(gt_batch.shape[0]): 180 | gt_batch[i] = augment_1(gt_batch[i]) 181 | else: 182 | gt_batch = [] 183 | processed_data = np.zeros((4, 128, 128, 28), dtype=np.float32) 184 | for i in range(batch_size): 185 | sample_list = 
np.random.randint(0, len(train_data), 4) 186 | for j in range(4): 187 | h, w, _ = train_data[sample_list[j]].shape 188 | x_index = np.random.randint(0, h-crop_size//2) 189 | y_index = np.random.randint(0, w-crop_size//2) 190 | processed_data[j] = train_data[sample_list[j]][x_index:x_index+crop_size//2,y_index:y_index+crop_size//2,:] 191 | generated_sample = torch.from_numpy(np.transpose(processed_data, (0, 3, 1, 2))) # [4,28,128,128] 192 | gt_batch.append(augment_2(generated_sample, crop_size=(crop_size, crop_size))) 193 | gt_batch = torch.stack(gt_batch, dim=0) 194 | return gt_batch 195 | else: 196 | index = np.random.choice(range(len(train_data)), batch_size) 197 | processed_data = np.zeros((batch_size, crop_size, crop_size, 28), dtype=np.float32) 198 | for i in range(batch_size): 199 | h, w, _ = train_data[index[i]].shape 200 | x_index = np.random.randint(0, h - crop_size) 201 | y_index = np.random.randint(0, w - crop_size) 202 | processed_data[i, :, :, :] = train_data[index[i]][x_index:x_index + crop_size, y_index:y_index + crop_size, :] 203 | gt_batch = torch.from_numpy(np.transpose(processed_data, (0, 3, 1, 2))) 204 | 205 | return gt_batch 206 | 207 | 208 | 209 | def augment_1(x): 210 | """ 211 | :param x: c,h,w 212 | :return: c,h,w 213 | """ 214 | rotTimes = random.randint(0, 3) 215 | vFlip = random.randint(0, 1) 216 | hFlip = random.randint(0, 1) 217 | # Random rotation 218 | for j in range(rotTimes): 219 | x = torch.rot90(x, dims=(1, 2)) 220 | # Random vertical Flip 221 | for j in range(vFlip): 222 | x = torch.flip(x, dims=(2,)) 223 | # Random horizontal Flip 224 | for j in range(hFlip): 225 | x = torch.flip(x, dims=(1,)) 226 | return x 227 | 228 | 229 | def augment_2(generate_gt, crop_size): 230 | c, h, w = generate_gt.shape[1], crop_size[0], crop_size[1] 231 | divid_point_h = crop_size[0] // 2 232 | divid_point_w = crop_size[1] // 2 233 | output_img = torch.zeros(c,h,w) 234 | output_img[:, :divid_point_h, :divid_point_w] = generate_gt[0] 235 | 
output_img[:, :divid_point_h, divid_point_w:] = generate_gt[1] 236 | output_img[:, divid_point_h:, :divid_point_w] = generate_gt[2] 237 | output_img[:, divid_point_h:, divid_point_w:] = generate_gt[3] 238 | return output_img 239 | 240 | 241 | def shift(inputs, step, nC): 242 | [nC, row, col] = inputs.shape 243 | for i in range(nC): 244 | inputs[i,:,:] = torch.roll(inputs[i,:,:], shifts=step*i, dims=1) 245 | return inputs 246 | 247 | def shift_back(inputs, step, nC): # input [bs,256,310] output [bs, 28, 256, 256] 248 | [row, col] = inputs.shape 249 | output = torch.zeros(nC, row, col - (nC - 1) * step).float() 250 | for i in range(nC): 251 | output[i, :, :] = inputs[:, step * i:step * i + col - (nC - 1) * step] 252 | return output 253 | 254 | def gen_meas_torch(inputs, Phi, step, wave_len, mask_type="mask_3d_shift"): 255 | data = {} 256 | [nC, H, W] = inputs.shape 257 | gt = torch.zeros(nC, H, W+step*(nC-1)).to(inputs.device) 258 | if mask_type == "mask_3d": 259 | gt[:,:,0:W] = Phi * inputs 260 | gt_shift = shift(gt, step=step, nC=wave_len) 261 | if mask_type == "mask_3d_shift": 262 | gt[:,:,0:W] = inputs 263 | gt_shift = shift(gt, step=step, nC=wave_len) 264 | gt_shift = Phi * gt_shift 265 | y = torch.sum(gt_shift, 0) 266 | meas = y / nC * 2 267 | H = shift_back(meas, step=step, nC=wave_len) 268 | 269 | data['Y'] = y 270 | data['H'] = H 271 | 272 | return data 273 | 274 | def shift_batch(inputs, nC = 28, step=2): 275 | [B, nC, row, col] = inputs.shape 276 | outputs = torch.zeros((B, nC, row, col + (nC - 1) * step)).float().to(inputs.device) 277 | for i in range(nC): 278 | outputs[:, i, :, step * i:step * i + col] = inputs[:, i, :, :] 279 | return outputs 280 | 281 | def shift_back_batch(inputs, nC=28, step=2): # input [bs,256,310] output [bs, 28, 256, 256] 282 | [B, row, col] = inputs.shape 283 | output = torch.zeros(B, nC, row, col - (nC - 1) * step).float().to(inputs.device) 284 | for i in range(nC): 285 | output[:, i, :, :] = inputs[:, :, step * i:step * i + col 
- (nC - 1) * step] 286 | return output 287 | 288 | 289 | def gen_meas_torch_batch(inputs, Phi, step, wave_len, mask_type="mask_3d_shift", with_noise=False): 290 | data = {} 291 | [B, nC, H, W] = inputs.shape 292 | if mask_type == "mask_3d": 293 | modulated_hsi = Phi * inputs 294 | modulated_hsi_shift = shift_batch(modulated_hsi, step=step, nC=wave_len) 295 | if mask_type == "mask_3d_shift": 296 | hsi_shift = shift_batch(inputs, step=step, nC=wave_len) 297 | modulated_hsi_shift = Phi * hsi_shift 298 | 299 | y = torch.sum(modulated_hsi_shift, 1) 300 | 301 | if with_noise: 302 | input = y / nC * 2 * 1.2 303 | # input = y / (y.max() + 1e-7) * 0.9 304 | QE, bit = 0.4, 2048 305 | input_noise = torch.tensor(np.random.binomial((input.cpu().numpy() * bit / QE).astype(int), QE)).float() 306 | input = input_noise / bit 307 | input = input.to(inputs.device) 308 | H = shift_back_batch(input, step=step) 309 | data['Y'] = input 310 | data['H'] = H 311 | else: 312 | meas = y / nC * 2 313 | H = shift_back_batch(meas, step=step) 314 | data['Y'] = y 315 | data['H'] = H 316 | 317 | 318 | 319 | 320 | return data 321 | 322 | 323 | if __name__ == '__main__': 324 | cfg = Box( 325 | { 326 | "DEBUG": True, 327 | "DATASETS": 328 | { 329 | "STEP": 2, 330 | "WAVE_LENS": 28, 331 | "MASK_TYPE": "mask_3d_shift", 332 | "WITH_PAN": False, 333 | "TRAIN": 334 | { 335 | "PATHS" : ["../../../datasets/CSI/cave_1024_28"], 336 | "ITERATION": 1000, 337 | "MASK_PATH": "../../../datasets/CSI/TSA_simu_data/mask_3d_shift.mat", 338 | "RANDOM_MASK": False, 339 | "AUGMENT": True, 340 | }, 341 | "VAL": 342 | { 343 | "PATH": "../../../datasets/CSI/TSA_simu_data/Truth", 344 | "MASK_PATH": "../../../datasets/CSI/TSA_simu_data/mask_3d_shift.mat" 345 | } 346 | }, 347 | "DATALOADER": 348 | { 349 | "BATCH_SIZE": 1 350 | } 351 | } 352 | ) 353 | 354 | train_mask_path = "/Users/shawn/Documents/Code/datasets/CSI/TSA_simu_data/mask_3d_shift.mat" 355 | train_mask = generate_mask_3d_shift(train_mask_path) 356 | 
print("train_mask: ", train_mask.shape) 357 | 358 | dataset = CSITrainDataset(cfg) 359 | 360 | data = dataset[0] 361 | print("data['hsi'].shape: ", data['hsi'].shape) 362 | print("data['mask'].shape: ", data['mask'].shape) 363 | cv2.imshow("data['hsi'][0]", data['hsi'][0].numpy()) 364 | cv2.imshow("data['mask'][0]", data['mask'][0].numpy()) 365 | 366 | 367 | 368 | val_mask = generate_mask_3d_shift(mask_path=cfg.DATASETS.VAL.MASK_PATH) 369 | val_data = LoadVal(cfg.DATASETS.VAL.PATH) 370 | 371 | val_img = torch.from_numpy(val_data['hsi'][0]).permute(2, 0, 1).float() 372 | data = gen_meas_torch(val_img, val_mask, step=cfg.DATASETS.STEP, wave_len=cfg.DATASETS.WAVE_LENS, mask_type=cfg.DATASETS.MASK_TYPE) 373 | 374 | print("torch.mean(data['H']): ", torch.mean(data['H'])) 375 | print("data['H'].shape: ", data['H'].shape) 376 | cv2.imshow("data['H'][0]", data['H'][0].numpy()) 377 | cv2.imshow("data['H'][-1]", data['H'][-1].numpy()) 378 | cv2.imshow("val img", val_img[0].numpy()) 379 | 380 | 381 | 382 | # path_test = "/Users/shawn/Documents/Code/datasets/CSI/508_real_indoor/Measurements/" 383 | # pan_test = "/Users/shawn/Documents/Code/datasets/CSI/508_real_indoor/Panchromatic/" 384 | # measurements, pans = LoadTestMeas(path_test, pan_test) 385 | # print("test measurement shape: ", measurements.shape) 386 | # print("test pan shape: ", pans.shape) 387 | # meas = measurements[0] 388 | # cv2.imshow("test meas", meas.numpy()) 389 | # shifted_meas = shift_back(meas, cfg.DATASETS.STEP) 390 | # print("shifted_meas.shape: ", shifted_meas.shape) 391 | # cv2.imshow("test real meas[0]", shifted_meas[0].numpy()) 392 | # cv2.imshow("test real meas[-1]", shifted_meas[-1].numpy()) 393 | # cv2.imshow("test real pan", pans[0].numpy()) 394 | 395 | 396 | cv2.waitKey(0) -------------------------------------------------------------------------------- /csi/architectures/dernn_lnlt.py: -------------------------------------------------------------------------------- 1 | import math 2 | import 
warnings 3 | 4 | import numbers 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import functional as F 9 | from torch.nn.init import _calculate_fan_in_and_fan_out 10 | 11 | from einops import rearrange 12 | from torch import einsum 13 | 14 | from box import Box 15 | from fvcore.nn import FlopCountAnalysis 16 | 17 | from csi.data import shift_batch, shift_back_batch, gen_meas_torch_batch 18 | 19 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 20 | def norm_cdf(x): 21 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 22 | 23 | if (mean < a - 2 * std) or (mean > b + 2 * std): 24 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 25 | "The distribution of values may be incorrect.", 26 | stacklevel=2) 27 | with torch.no_grad(): 28 | l = norm_cdf((a - mean) / std) 29 | u = norm_cdf((b - mean) / std) 30 | tensor.uniform_(2 * l - 1, 2 * u - 1) 31 | tensor.erfinv_() 32 | tensor.mul_(std * math.sqrt(2.)) 33 | tensor.add_(mean) 34 | tensor.clamp_(min=a, max=b) 35 | return tensor 36 | 37 | 38 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 39 | # type: (Tensor, float, float, float, float) -> Tensor 40 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 41 | 42 | 43 | 44 | class LocalMSA(nn.Module): 45 | """ 46 | The Local MSA partitions the input into non-overlapping windows of size M × M, treating each pixel within the window as a token, and computes self-attention within the window. 
47 | """ 48 | def __init__(self, 49 | dim, 50 | num_heads, 51 | window_size, 52 | ): 53 | super().__init__() 54 | self.dim = dim 55 | self.window_size = window_size 56 | self.num_heads = num_heads 57 | head_dim = dim // num_heads 58 | self.scale = head_dim**-0.5 59 | 60 | 61 | self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=False) 62 | self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=False) 63 | 64 | self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=True) 65 | 66 | self.pos_emb = nn.Parameter(torch.Tensor(1, num_heads, window_size[0]*window_size[1], window_size[0]*window_size[1])) 67 | trunc_normal_(self.pos_emb) 68 | 69 | 70 | def forward(self, x): 71 | """ 72 | x: [b,c,h,w] 73 | return out: [b,c,h,w] 74 | """ 75 | b, c, h, w = x.shape 76 | 77 | q, k, v = self.qkv_dwconv(self.qkv(x)).chunk(3, dim=1) 78 | 79 | q, k, v = map(lambda t: rearrange(t, 'b c (h b0) (w b1) -> (b h w) (b0 b1) c', 80 | b0=self.window_size[0], b1=self.window_size[1]), (q, k, v)) 81 | 82 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads), (q, k, v)) 83 | q *= self.scale 84 | sim = einsum('b h i d, b h j d -> b h i j', q, k) 85 | sim = sim + self.pos_emb 86 | attn = sim.softmax(dim=-1) 87 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 88 | out = rearrange(out, 'b h n d -> b n (h d)') 89 | 90 | out = rearrange(out, '(b h w) (b0 b1) c -> b c (h b0) (w b1)', h=h // self.window_size[0], w=w // self.window_size[1], 91 | b0=self.window_size[0]) 92 | out = self.project_out(out) 93 | 94 | return out 95 | 96 | 97 | 98 | class NonLocalMSA(nn.Module): 99 | """ 100 | The Non-Local MSA divides the input into N × N non-overlapping windows, treating each window as a token, and computes self-attention across the windows. 
101 | """ 102 | def __init__(self, 103 | dim, 104 | num_heads, 105 | window_num 106 | ): 107 | super().__init__() 108 | self.dim = dim 109 | self.window_num = window_num 110 | self.num_heads = num_heads 111 | head_dim = dim // num_heads 112 | self.scale = head_dim**-0.5 113 | 114 | 115 | self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=False) 116 | self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=False) 117 | self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=True) 118 | 119 | 120 | self.pos_emb = nn.Parameter(torch.Tensor(1, num_heads, window_num[0]*window_num[1], window_num[0]*window_num[1])) 121 | trunc_normal_(self.pos_emb) 122 | 123 | 124 | def forward(self, x): 125 | """ 126 | x: [b,c,h,w] 127 | return out: [b,c,h,w] 128 | """ 129 | b, c, h, w = x.shape 130 | 131 | q, k, v = self.qkv_dwconv(self.qkv(x)).chunk(3, dim=1) 132 | 133 | q, k, v = map(lambda t: rearrange(t, 'b c (h b0) (w b1)-> b (h w) (b0 b1 c)', 134 | h=self.window_num[0], w=self.window_num[1]), (q, k, v)) 135 | 136 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads), (q, k, v)) 137 | 138 | head_dim = ((h // self.window_num[0]) * (w // self.window_num[1]) * c) / self.num_heads 139 | scale = head_dim ** -0.5 140 | 141 | q *= scale 142 | sim = einsum('b h i d, b h j d -> b h i j', q, k) 143 | 144 | sim = sim + self.pos_emb 145 | attn = sim.softmax(dim=-1) 146 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 147 | 148 | out = rearrange(out, 'b h n d -> b n (h d)') 149 | out = rearrange(out, 'b (h w) (b0 b1 c) -> (b h w) (b0 b1) c', h=self.window_num[0], b0=h // self.window_num[0], b1=w // self.window_num[1]) 150 | 151 | out = rearrange(out, '(b h w) (b0 b1) c -> b c (h b0) (w b1)', h=self.window_num[0], w= self.window_num[1], 152 | b0=h//self.window_num[0]) 153 | out = self.project_out(out) 154 | 155 | 156 | return out 157 | 158 | 159 | class GELU(nn.Module): 160 | def forward(self, x): 161 | return 
F.gelu(x) 162 | 163 | 164 | class FeedForward(nn.Module): 165 | def __init__(self, dim, mult=4): 166 | super().__init__() 167 | self.net = nn.Sequential( 168 | nn.Conv2d(dim, dim * mult, 1, 1, bias=False), 169 | GELU(), 170 | nn.Conv2d(dim * mult, dim * mult, 3, 1, 1, bias=False, groups=dim * mult), 171 | GELU(), 172 | nn.Conv2d(dim * mult, dim, 1, 1, bias=False), 173 | ) 174 | 175 | def forward(self, x): 176 | """ 177 | x: [b, h, w, c] 178 | return out: [b, h, w, c] 179 | """ 180 | out = self.net(x) 181 | return out 182 | 183 | 184 | ## Gated-Dconv Feed-Forward Network (GDFN) 185 | class Gated_Dconv_FeedForward(nn.Module): 186 | def __init__(self, 187 | dim, 188 | ffn_expansion_factor = 2.66 189 | ): 190 | super(Gated_Dconv_FeedForward, self).__init__() 191 | 192 | hidden_features = int(dim*ffn_expansion_factor) 193 | 194 | self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=False) 195 | 196 | self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=True) 197 | 198 | self.act_fn = nn.GELU() 199 | 200 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=False) 201 | 202 | def forward(self, x): 203 | """ 204 | x: [b, c, h, w] 205 | return out: [b, c, h, w] 206 | """ 207 | x = self.project_in(x) 208 | x1, x2 = self.dwconv(x).chunk(2, dim=1) 209 | x = self.act_fn(x1) * x2 210 | x = self.project_out(x) 211 | return x 212 | 213 | 214 | def FFN_FN( 215 | cfg, 216 | ffn_name, 217 | dim 218 | ): 219 | if ffn_name == "Gated_Dconv_FeedForward": 220 | return Gated_Dconv_FeedForward( 221 | dim, 222 | ffn_expansion_factor=cfg.MODEL.DENOISER.DERNN_LNLT.FFN_EXPAND, 223 | ) 224 | elif ffn_name == "FeedForward": 225 | return FeedForward(dim = dim) 226 | 227 | 228 | def to_3d(x): 229 | return rearrange(x, 'b c h w -> b (h w) c') 230 | 231 | def to_4d(x,h,w): 232 | return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) 233 | 234 | 235 | class BiasFree_LayerNorm(nn.Module): 236 
| def __init__(self, normalized_shape): 237 | super(BiasFree_LayerNorm, self).__init__() 238 | if isinstance(normalized_shape, numbers.Integral): 239 | normalized_shape = (normalized_shape,) 240 | normalized_shape = torch.Size(normalized_shape) 241 | 242 | assert len(normalized_shape) == 1 243 | 244 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 245 | self.normalized_shape = normalized_shape 246 | 247 | def forward(self, x): 248 | sigma = x.var(-1, keepdim=True, unbiased=False) 249 | return x / torch.sqrt(sigma+1e-5) * self.weight 250 | 251 | 252 | class WithBias_LayerNorm(nn.Module): 253 | def __init__(self, normalized_shape): 254 | super(WithBias_LayerNorm, self).__init__() 255 | if isinstance(normalized_shape, numbers.Integral): 256 | normalized_shape = (normalized_shape,) 257 | normalized_shape = torch.Size(normalized_shape) 258 | 259 | assert len(normalized_shape) == 1 260 | 261 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 262 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 263 | self.normalized_shape = normalized_shape 264 | 265 | def forward(self, x): 266 | mu = x.mean(-1, keepdim=True) 267 | sigma = x.var(-1, keepdim=True, unbiased=False) 268 | return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias 269 | 270 | 271 | class LayerNorm(nn.Module): 272 | def __init__(self, dim, LayerNorm_type): 273 | super(LayerNorm, self).__init__() 274 | if LayerNorm_type =='BiasFree': 275 | self.body = BiasFree_LayerNorm(dim) 276 | else: 277 | self.body = WithBias_LayerNorm(dim) 278 | 279 | def forward(self, x): 280 | # x: (b, c, h, w) 281 | h, w = x.shape[-2:] 282 | return to_4d(self.body(to_3d(x)), h, w) 283 | 284 | 285 | class PreNorm(nn.Module): 286 | def __init__(self, dim, fn, layernorm_type='WithBias'): 287 | super().__init__() 288 | self.fn = fn 289 | self.layernorm_type = layernorm_type 290 | if layernorm_type == 'BiasFree' or layernorm_type == 'WithBias': 291 | self.norm = LayerNorm(dim, layernorm_type) 292 | else: 293 
| self.norm = nn.LayerNorm(dim) 294 | 295 | def forward(self, x, *args, **kwargs): 296 | if self.layernorm_type == 'BiasFree' or self.layernorm_type == 'WithBias': 297 | x = self.norm(x) 298 | else: 299 | h, w = x.shape[-2:] 300 | x = to_4d(self.norm(to_3d(x)), h, w) 301 | return self.fn(x, *args, **kwargs) 302 | 303 | 304 | 305 | class LocalNonLocalBlock(nn.Module): 306 | """ 307 | The Local and Non-Local Transformer Block (LNLB) is the most important component. Each LNLB consists of three layer-normalizations (LNs), a Local MSA, a Non-Local MSA, and a GDFN (Zamir et al. 2022). 308 | """ 309 | def __init__(self, 310 | cfg, 311 | dim, 312 | num_heads, 313 | window_size:tuple, 314 | window_num:tuple, 315 | layernorm_type, 316 | num_blocks, 317 | ): 318 | super().__init__() 319 | self.cfg = cfg 320 | self.window_size = window_size 321 | self.window_num = window_num 322 | 323 | self.blocks = nn.ModuleList([]) 324 | for _ in range(num_blocks): 325 | self.blocks.append(nn.ModuleList([ 326 | PreNorm(dim, LocalMSA( 327 | dim = dim, 328 | window_size = window_size, 329 | num_heads = num_heads, 330 | ), 331 | layernorm_type = layernorm_type) if self.cfg.MODEL.DENOISER.DERNN_LNLT.LOCAL else nn.Identity(), 332 | PreNorm(dim, NonLocalMSA( 333 | dim = dim, 334 | num_heads = num_heads, 335 | window_num = window_num, 336 | ), 337 | layernorm_type = layernorm_type) if self.cfg.MODEL.DENOISER.DERNN_LNLT.NON_LOCAL else nn.Identity(), 338 | PreNorm(dim, FFN_FN( 339 | cfg, 340 | ffn_name = cfg.MODEL.DENOISER.DERNN_LNLT.FFN_NAME, 341 | dim = dim 342 | ), 343 | layernorm_type = layernorm_type) 344 | ])) 345 | 346 | 347 | def forward(self, x): 348 | for (local_msa, nonlocal_msa, ffn) in self.blocks: 349 | x = x + local_msa(x) 350 | x = x + nonlocal_msa(x) 351 | x = x + ffn(x) 352 | 353 | return x 354 | 355 | 356 | class DownSample(nn.Module): 357 | def __init__(self, in_channels, bias=False): 358 | super(DownSample, self).__init__() 359 | self.down = nn.Sequential( 360 | 
nn.Conv2d(in_channels, in_channels * 2, 4, 2, 1, bias=False) 361 | ) 362 | 363 | def forward(self, x): 364 | x = self.down(x) 365 | return x 366 | 367 | class UpSample(nn.Module): 368 | def __init__(self, in_channels, bias=False): 369 | super(UpSample, self).__init__() 370 | self.up = nn.Sequential( 371 | nn.ConvTranspose2d(in_channels, in_channels // 2, stride=2, kernel_size=2, padding=0, output_padding=0) 372 | ) 373 | 374 | def forward(self, x): 375 | x = self.up(x) 376 | return x 377 | 378 | 379 | class LNLT(nn.Module): 380 | """ 381 | The Local and Non-Local Transformer (LNLT) adopts a three-level U-shaped structure, and each level consists of multiple basic units called Local and Non-Local Transformer Blocks (LNLBs). Up- and down-sampling modules are positioned between LNLBs. 382 | """ 383 | def __init__(self, cfg): 384 | super().__init__() 385 | self.cfg = cfg 386 | self.embedding = nn.Conv2d(cfg.MODEL.DENOISER.DERNN_LNLT.IN_DIM, cfg.MODEL.DENOISER.DERNN_LNLT.DIM, kernel_size=3, stride=1, padding=1, bias=False) 387 | 388 | 389 | self.Encoder = nn.ModuleList([ 390 | LocalNonLocalBlock( 391 | cfg = cfg, 392 | dim = cfg.MODEL.DENOISER.DERNN_LNLT.DIM * 2 ** 0, 393 | num_heads = 2 ** 0, 394 | window_size = cfg.MODEL.DENOISER.DERNN_LNLT.WINDOW_SIZE, 395 | window_num = cfg.MODEL.DENOISER.DERNN_LNLT.WINDOW_NUM, 396 | layernorm_type = cfg.MODEL.DENOISER.DERNN_LNLT.LAYERNORM_TYPE, 397 | num_blocks = cfg.MODEL.DENOISER.DERNN_LNLT.NUM_BLOCKS[0], 398 | ), 399 | LocalNonLocalBlock( 400 | cfg = cfg, 401 | dim = cfg.MODEL.DENOISER.DERNN_LNLT.DIM * 2 ** 1, 402 | num_heads = 2 ** 1, 403 | window_size = cfg.MODEL.DENOISER.DERNN_LNLT.WINDOW_SIZE, 404 | window_num = cfg.MODEL.DENOISER.DERNN_LNLT.WINDOW_NUM, 405 | layernorm_type = cfg.MODEL.DENOISER.DERNN_LNLT.LAYERNORM_TYPE, 406 | num_blocks = cfg.MODEL.DENOISER.DERNN_LNLT.NUM_BLOCKS[1], 407 | ), 408 | ]) 409 | 410 | self.BottleNeck = LocalNonLocalBlock( 411 | cfg = cfg, 412 | dim = cfg.MODEL.DENOISER.DERNN_LNLT.DIM * 2 ** 2, 
413 | num_heads = 2 ** 2, 414 | window_size = cfg.MODEL.DENOISER.DERNN_LNLT.WINDOW_SIZE, 415 | window_num = cfg.MODEL.DENOISER.DERNN_LNLT.WINDOW_NUM, 416 | layernorm_type = cfg.MODEL.DENOISER.DERNN_LNLT.LAYERNORM_TYPE, 417 | num_blocks = cfg.MODEL.DENOISER.DERNN_LNLT.NUM_BLOCKS[2], 418 | ) 419 | 420 | self.Decoder = nn.ModuleList([ 421 | LocalNonLocalBlock( 422 | cfg = cfg, 423 | dim = cfg.MODEL.DENOISER.DERNN_LNLT.DIM * 2 ** 1, 424 | num_heads = 2 ** 1, 425 | window_size = cfg.MODEL.DENOISER.DERNN_LNLT.WINDOW_SIZE, 426 | window_num = cfg.MODEL.DENOISER.DERNN_LNLT.WINDOW_NUM, 427 | layernorm_type = cfg.MODEL.DENOISER.DERNN_LNLT.LAYERNORM_TYPE, 428 | num_blocks = cfg.MODEL.DENOISER.DERNN_LNLT.NUM_BLOCKS[3], 429 | ), 430 | LocalNonLocalBlock( 431 | cfg = cfg, 432 | dim = cfg.MODEL.DENOISER.DERNN_LNLT.DIM * 2 ** 0, 433 | num_heads = 2 ** 0, 434 | window_size = cfg.MODEL.DENOISER.DERNN_LNLT.WINDOW_SIZE, 435 | window_num = cfg.MODEL.DENOISER.DERNN_LNLT.WINDOW_NUM, 436 | layernorm_type = cfg.MODEL.DENOISER.DERNN_LNLT.LAYERNORM_TYPE, 437 | num_blocks = cfg.MODEL.DENOISER.DERNN_LNLT.NUM_BLOCKS[4], 438 | ) 439 | ]) 440 | 441 | self.Downs = nn.ModuleList([ 442 | DownSample(cfg.MODEL.DENOISER.DERNN_LNLT.DIM * 2 ** 0), 443 | DownSample(cfg.MODEL.DENOISER.DERNN_LNLT.DIM * 2 ** 1) 444 | ]) 445 | 446 | self.Ups = nn.ModuleList([ 447 | UpSample(cfg.MODEL.DENOISER.DERNN_LNLT.DIM * 2 ** 2), 448 | UpSample(cfg.MODEL.DENOISER.DERNN_LNLT.DIM * 2 ** 1) 449 | ]) 450 | 451 | self.fusions = nn.ModuleList([ 452 | nn.Conv2d( 453 | in_channels = cfg.MODEL.DENOISER.DERNN_LNLT.DIM * 2 ** 2, 454 | out_channels = cfg.MODEL.DENOISER.DERNN_LNLT.DIM * 2 ** 1, 455 | kernel_size = 1, 456 | stride = 1, 457 | padding = 0, 458 | bias = False 459 | ), 460 | nn.Conv2d( 461 | in_channels = cfg.MODEL.DENOISER.DERNN_LNLT.DIM * 2 ** 1, 462 | out_channels = cfg.MODEL.DENOISER.DERNN_LNLT.DIM * 2 ** 0, 463 | kernel_size = 1, 464 | stride = 1, 465 | padding = 0, 466 | bias = False 467 | ) 468 | ]) 469 | 470 | 
self.mapping = nn.Conv2d(cfg.MODEL.DENOISER.DERNN_LNLT.DIM, cfg.MODEL.DENOISER.DERNN_LNLT.OUT_DIM, kernel_size=3, stride=1, padding=1, bias=False) 471 | 472 | 473 | def forward(self, x): 474 | b, c, h_inp, w_inp = x.shape 475 | hb, wb = 16, 16 476 | pad_h = (hb - h_inp % hb) % hb 477 | pad_w = (wb - w_inp % wb) % wb 478 | x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect') 479 | 480 | 481 | x1 = self.embedding(x) 482 | res1 = self.Encoder[0](x1) 483 | 484 | x2 = self.Downs[0](res1) 485 | res2 = self.Encoder[1](x2) 486 | 487 | x4 = self.Downs[1](res2) 488 | res4 = self.BottleNeck(x4) 489 | 490 | dec_res2 = self.Ups[0](res4) # dim * 2 ** 2 -> dim * 2 ** 1 491 | dec_res2 = torch.cat([dec_res2, res2], dim=1) # dim * 2 ** 2 492 | dec_res2 = self.fusions[0](dec_res2) # dim * 2 ** 2 -> dim * 2 ** 1 493 | dec_res2 = self.Decoder[0](dec_res2) 494 | 495 | dec_res1 = self.Ups[1](dec_res2) # dim * 2 ** 1 -> dim * 2 ** 0 496 | dec_res1 = torch.cat([dec_res1, res1], dim=1) # dim * 2 ** 1 497 | dec_res1 = self.fusions[1](dec_res1) # dim * 2 ** 1 -> dim * 2 ** 0 498 | dec_res1 = self.Decoder[1](dec_res1) 499 | 500 | if self.cfg.MODEL.DENOISER.DERNN_LNLT.WITH_NOISE_LEVEL: 501 | out = self.mapping(dec_res1) + x[:, 1:, :, :] 502 | else: 503 | out = self.mapping(dec_res1) + x 504 | 505 | 506 | return out[:, :, :h_inp, :w_inp] 507 | 508 | 509 | def PWDWPWConv(in_channels, out_channels): 510 | return nn.Sequential( 511 | nn.Conv2d(in_channels, 64, 1, 1, 0, bias=True), 512 | nn.GELU(), 513 | nn.Conv2d(64, 64, 3, 1, 1, bias=True, groups=64), 514 | nn.GELU(), 515 | nn.Conv2d(64, out_channels, 1, 1, 0, bias=False) 516 | ) 517 | 518 | def A(x, Phi): 519 | B, nC, H, W = x.shape 520 | temp = x * Phi 521 | y = torch.sum(temp, 1) 522 | return y 523 | 524 | def At(y, Phi): 525 | temp = torch.unsqueeze(y, 1).repeat(1, Phi.shape[1], 1, 1) 526 | x = temp * Phi 527 | return x 528 | 529 | 530 | def shift_3d(inputs, step=2): 531 | [B, C, H, W] = inputs.shape 532 | temp = torch.zeros((B, C, H, 
W+(C-1)*step)).to(inputs.device) 533 | temp[:, :, :, :W] = inputs 534 | for i in range(C): 535 | temp[:,i,:,:] = torch.roll(temp[:,i,:,:], shifts=step*i, dims=2) 536 | return temp 537 | 538 | def shift_back_3d(inputs,step=2): 539 | [bs, nC, row, col] = inputs.shape 540 | for i in range(nC): 541 | inputs[:,i,:,:] = torch.roll(inputs[:,i,:,:], shifts=(-1)*step*i, dims=2) 542 | return inputs 543 | 544 | 545 | class DegradationEstimation(nn.Module): 546 | """ 547 | The Degradation Estimation Network (DEN) is proposed to estimate degradation-related parameters from the input of the current recurrent step and with reference to the sensing matrix. 548 | """ 549 | def __init__(self, cfg): 550 | super().__init__() 551 | self.cfg = cfg 552 | self.DL = nn.Sequential( 553 | PWDWPWConv(self.cfg.DATASETS.WAVE_LENS*2, self.cfg.DATASETS.WAVE_LENS*2), 554 | PWDWPWConv(self.cfg.DATASETS.WAVE_LENS*2, self.cfg.DATASETS.WAVE_LENS), 555 | ) 556 | self.down_sample = nn.Conv2d(self.cfg.DATASETS.WAVE_LENS, self.cfg.DATASETS.WAVE_LENS*2, 3, 2, 1, bias=True) # (B, 64, H, W) -> (B, 64, H//2, W//2) 557 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 558 | self.mlp = nn.Sequential( 559 | nn.Conv2d(self.cfg.DATASETS.WAVE_LENS*2, self.cfg.DATASETS.WAVE_LENS*2, 1, padding=0, bias=True), 560 | nn.ReLU(inplace=True), 561 | nn.Conv2d(self.cfg.DATASETS.WAVE_LENS*2, self.cfg.DATASETS.WAVE_LENS*2, 1, padding=0, bias=True), 562 | nn.ReLU(inplace=True), 563 | nn.Conv2d(self.cfg.DATASETS.WAVE_LENS*2, 2, 1, padding=0, bias=True), 564 | nn.Softplus()) 565 | self.relu = nn.ReLU(inplace=True) 566 | 567 | 568 | def forward(self, y, phi): 569 | 570 | inp = torch.cat([phi, y], dim=1) 571 | phi_r = self.DL(inp) 572 | 573 | phi = phi + phi_r 574 | 575 | x = self.down_sample(self.relu(phi_r)) 576 | x = self.avg_pool(x) 577 | x = self.mlp(x) + 1e-6 578 | mu = x[:, 0, :, :] 579 | noise_level = x[:, 1, :, :] 580 | 581 | return phi, mu, noise_level[:, None, :, :] 582 | 583 | 584 | class DERNN_LNLT(nn.Module): 585 | """ 586 | 
The DERNN-LNLT unfolds the HQS algorithm within the MAP framework and transfors the DUN into an RNN by sharing parameters across stages. 587 | 588 | Then, the DERNN-LNLT integrate the Degradation Estimation Network into the RNN, which estimates the degradation matrix for the data subproblem and the noise level for the prior subproblem by residual learning with reference to the sensing matrix. 589 | 590 | Subsequently, the Local and Non-Local Transformer (LNLT) utilizes the Local and Non-Local Multi-head Self-Attention (MSA) to effectively exploit both local and non-local HSIs priors. 591 | 592 | Finally, incorporating the LNLT into the DERNN as the denoiser for the prior subproblem leads to the proposed DERNN-LNLT. 593 | """ 594 | def __init__(self, cfg): 595 | super().__init__() 596 | self.cfg = cfg 597 | 598 | self.fusion = nn.Conv2d(cfg.DATASETS.WAVE_LENS*2, cfg.DATASETS.WAVE_LENS, 1, padding=0, bias=True) 599 | 600 | self.DP = nn.ModuleList([ 601 | DegradationEstimation(cfg) for _ in range(cfg.MODEL.DENOISER.DERNN_LNLT.STAGE) 602 | ]) if not cfg.MODEL.DENOISER.DERNN_LNLT.SHARE_PARAMS else DegradationEstimation(cfg) 603 | self.PP = nn.ModuleList([ 604 | LNLT(cfg) for _ in range(cfg.MODEL.DENOISER.DERNN_LNLT.STAGE) 605 | ]) if not cfg.MODEL.DENOISER.DERNN_LNLT.SHARE_PARAMS else LNLT(cfg) 606 | 607 | 608 | self.apply(self._init_weights) 609 | 610 | def _init_weights(self, m): 611 | if isinstance(m, nn.Linear): 612 | trunc_normal_(m.weight, std=.02) 613 | if isinstance(m, nn.Linear) and m.bias is not None: 614 | nn.init.constant_(m.bias, 0) 615 | elif isinstance(m, nn.Conv2d): 616 | trunc_normal_(m.weight, std=.02) 617 | if isinstance(m, nn.Conv2d) and m.bias is not None: 618 | nn.init.constant_(m.bias, 0) 619 | elif isinstance(m, nn.LayerNorm): 620 | nn.init.constant_(m.bias, 0) 621 | nn.init.constant_(m.weight, 1.0) 622 | 623 | def initial(self, y, Phi): 624 | """ 625 | :param y: [b,256,310] 626 | :param Phi: [b,28,256,310] 627 | :return: temp: [b,28,256,310]; 
alpha: [b, num_iterations]; beta: [b, num_iterations] 628 | """ 629 | nC = self.cfg.DATASETS.WAVE_LENS 630 | step = self.cfg.DATASETS.STEP 631 | bs, nC, row, col = Phi.shape 632 | y_shift = torch.zeros(bs, nC, row, col).to(y.device).float() 633 | for i in range(nC): 634 | y_shift[:, i, :, step * i:step * i + col - (nC - 1) * step] = y[:, :, step * i:step * i + col - (nC - 1) * step] 635 | z = self.fusion(torch.cat([y_shift, Phi], dim=1)) 636 | return z 637 | 638 | def prepare_input(self, data): 639 | hsi = data['hsi'] 640 | mask = data['mask'] 641 | 642 | YH = gen_meas_torch_batch(hsi, mask, step=self.cfg.DATASETS.STEP, wave_len=self.cfg.DATASETS.WAVE_LENS, mask_type=self.cfg.DATASETS.MASK_TYPE, with_noise=self.cfg.DATASETS.TRAIN.WITH_NOISE) 643 | 644 | data['Y'] = YH['Y'] 645 | data['H'] = YH['H'] 646 | 647 | return data 648 | 649 | 650 | def forward_train(self, data): 651 | y = data['Y'] 652 | phi = data['mask'] 653 | x0 = data['H'] 654 | 655 | z = self.initial(y, phi) 656 | 657 | 658 | B, C, H, W = phi.shape 659 | B, C, H_, W_ = x0.shape 660 | 661 | for i in range(self.cfg.MODEL.DENOISER.DERNN_LNLT.STAGE): 662 | Phi, mu, noise_level = self.DP[i](z, phi) if not self.cfg.MODEL.DENOISER.DERNN_LNLT.SHARE_PARAMS else self.DP(z, phi) 663 | 664 | if not self.cfg.MODEL.DENOISER.DERNN_LNLT.WITH_DL: 665 | Phi = phi 666 | if not self.cfg.MODEL.DENOISER.DERNN_LNLT.WITH_MU: 667 | mu = torch.FloatTensor([1e-6]).to(y.device) 668 | 669 | Phi_s = torch.sum(Phi**2,1) 670 | Phi_s[Phi_s==0] = 1 671 | Phi_z = A(z, Phi) 672 | x = z + At(torch.div(y-Phi_z,mu+Phi_s), Phi) 673 | x = shift_back_3d(x)[:, :, :, :W_] 674 | noise_level_repeat = noise_level.repeat(1,1,x.shape[2], x.shape[3]) 675 | if not self.cfg.MODEL.DENOISER.DERNN_LNLT.WITH_NOISE_LEVEL: 676 | z = self.PP[i](x) if not self.cfg.MODEL.DENOISER.DERNN_LNLT.SHARE_PARAMS else self.PP(x) 677 | else: 678 | z = self.PP[i](torch.cat([noise_level_repeat, x],dim=1)) if not self.cfg.MODEL.DENOISER.DERNN_LNLT.SHARE_PARAMS else 
self.PP(torch.cat([noise_level_repeat, x],dim=1)) 679 | z = shift_3d(z) 680 | 681 | 682 | z = shift_back_3d(z)[:, :, :, :W_] 683 | 684 | out = z[:, :, :, :W_] 685 | 686 | return out 687 | 688 | def forward_test(self, data): 689 | y = data['Y'] 690 | phi = data['mask'] 691 | x0 = data['H'] 692 | 693 | z = self.initial(y, phi) 694 | 695 | 696 | B, C, H, W = phi.shape 697 | B, C, H_, W_ = x0.shape 698 | 699 | for i in range(self.cfg.MODEL.DENOISER.DERNN_LNLT.STAGE): 700 | Phi, mu, noise_level = self.DP[i](z, phi) if not self.cfg.MODEL.DENOISER.DERNN_LNLT.SHARE_PARAMS else self.DP(z, phi) 701 | 702 | if not self.cfg.MODEL.DENOISER.DERNN_LNLT.WITH_DL: 703 | Phi = phi 704 | if not self.cfg.MODEL.DENOISER.DERNN_LNLT.WITH_MU: 705 | mu = torch.FloatTensor([1e-6]).to(y.device) 706 | 707 | Phi_s = torch.sum(Phi**2,1) 708 | Phi_s[Phi_s==0] = 1 709 | Phi_z = A(z, Phi) 710 | x = z + At(torch.div(y-Phi_z,mu+Phi_s), Phi) 711 | x = shift_back_3d(x)[:, :, :, :W_] 712 | noise_level_repeat = noise_level.repeat(1,1,x.shape[2], x.shape[3]) 713 | if not self.cfg.MODEL.DENOISER.DERNN_LNLT.WITH_NOISE_LEVEL: 714 | z = self.PP[i](x) if not self.cfg.MODEL.DENOISER.DERNN_LNLT.SHARE_PARAMS else self.PP(x) 715 | else: 716 | z = self.PP[i](torch.cat([noise_level_repeat, x],dim=1)) if not self.cfg.MODEL.DENOISER.DERNN_LNLT.SHARE_PARAMS else self.PP(torch.cat([noise_level_repeat, x],dim=1)) 717 | z = shift_3d(z) 718 | 719 | 720 | z = shift_back_3d(z)[:, :, :, :W_] 721 | 722 | out = z[:, :, :, :W_] 723 | 724 | return out 725 | 726 | def forward(self, data): 727 | if self.training: 728 | data = self.prepare_input(data) 729 | x = self.forward_train(data) 730 | 731 | else: 732 | x = self.forward_test(data) 733 | 734 | return x 735 | 736 | 737 | --------------------------------------------------------------------------------