├── README.md ├── figure ├── CRM.png ├── network.png └── poster.png ├── requirements.txt ├── test ├── DRCR.py ├── __pycache__ │ ├── AWAN.cpython-38.pyc │ ├── DRCR.cpython-38.pyc │ └── DRCR_Net.cpython-38.pyc ├── compute_mrae.py └── test.py └── train ├── DRCR.py ├── __pycache__ ├── DRCR.cpython-38.pyc ├── dataset.cpython-38.pyc └── utils.cpython-38.pyc ├── dataset.py ├── main.py ├── train_data_preprocess.py ├── utils.py └── valid_data_preprocess.py /README.md: -------------------------------------------------------------------------------- 1 | # DRCR Net: Dense Residual Channel Re-calibration Network with Non-local Purification for Spectral Super Resolution (CVPRW 2022) 2 | [![3rd](https://img.shields.io/badge/3rd%20place-NTIRE__2022__Challenge__on__Spectral__Reconstruction__from__RGB-orange)](https://codalab.lisn.upsaclay.fr/competitions/721#learn_the_details) 3 | [![cvprw](https://img.shields.io/badge/CVPRW-DRCR%20Net-green)](https://openaccess.thecvf.com/content/CVPR2022W/NTIRE/papers/Li_DRCR_Net_Dense_Residual_Channel_Re-Calibration_Network_With_Non-Local_Purification_CVPRW_2022_paper.pdf) 4 | ![visitors](https://visitor-badge.glitch.me/badge?page_id=dusongcheng/DRCR-Net) 5 | 6 | [Jiaojiao Li](https://scholar.google.com/citations?user=Ccu3-acAAAAJ&hl=zh-CN&oi=ao), [Songcheng Du](https://github.com/dusongcheng), [Chaoxiong Wu](https://scholar.google.com/citations?user=PIsTkkEAAAAJ&hl=zh-CN&oi=ao), [Yihong Leng](), [Rui Song](https://scholar.google.com/citations?user=_SKooBYAAAAJ&hl=zh-CN&oi=sra) and [Yunsong Li]() 7 | 8 |
9 | 10 | > **Abstract:** Spectral super resolution (SSR) aims to reconstruct the 3D hyperspectral signal from a 2D RGB image, a task that has flourished with the proliferation of Convolutional Neural Networks (CNNs) and increased access to RGB/hyperspectral datasets. Nevertheless, most CNN-based spectral reconstruction (SR) algorithms achieve high reconstruction accuracy only when the input RGB image is relatively 'clean' with known spectral response functions. Unfortunately, real-world images are contaminated by mixed noise, bad illumination conditions, compression artifacts, etc., and the existing state-of-the-art (SOTA) methods no longer work well on them. To overcome these drawbacks, we propose a novel dense residual channel re-calibration network (DRCR Net) with non-local purification for robust SSR, which first removes interference through a non-local purification module (NPM) to refine the RGB inputs. Specifically, as the main component of the backbone, the dense residual channel re-calibration (DRCR) block is cascaded in an encoder-decoder paradigm through several cross-layer dense residual connections to capture deep spatial-spectral interactions, which effectively improves the generalization ability of the network. Furthermore, we customize dual channel re-calibration modules (CRMs), embedded in each DRCR block, to adaptively re-calibrate channel-wise feature responses for high-fidelity spectral recovery. Our entry ranked 3rd in the NTIRE 2022 Spectral Reconstruction Challenge. 11 |
12 | 13 | 14 | 15 | ## DRCR Net Framework 16 | 17 | ![network](./figure/network.png) 18 | 19 | 20 | ## CRM 21 | 22 | ![CRM](./figure/CRM.png) 23 | 24 | 25 | ## Train 26 | 1. #### Download the dataset. 27 | 28 | - Download the training spectral images ([Google Drive](https://drive.google.com/file/d/1FQBfDd248dCKClR-BpX5V2drSbeyhKcq/view)) 29 | - Download the training RGB images ([Google Drive](https://drive.google.com/file/d/1A4GUXhVc5k5d_79gNvokEtVPG290qVkd/view)) 30 | - Download the validation spectral images ([Google Drive](https://drive.google.com/file/d/12QY8LHab3gzljZc3V6UyHgBee48wh9un/view)) 31 | - Download the validation RGB images ([Google Drive](https://drive.google.com/file/d/19vBR_8Il1qcaEZsK42aGfvg5lCuvLh1A/view)) 32 | 33 | Put all downloaded files into `/DRCR-Net-master/Dataset/`, so that the repository is organized as follows: 34 | ```shell 35 | |--DRCR-Net-master 36 | |--figure 37 | |--test 38 | |--train 39 | |--Dataset 40 | |--Train_spectral 41 | |--ARAD_1K_0001.mat 42 | |--ARAD_1K_0002.mat 43 | : 44 | |--ARAD_1K_0900.mat 45 | |--Train_RGB 46 | |--ARAD_1K_0001.jpg 47 | |--ARAD_1K_0002.jpg 48 | : 49 | |--ARAD_1K_0900.jpg 50 | |--Valid_spectral 51 | |--ARAD_1K_0901.mat 52 | |--ARAD_1K_0902.mat 53 | : 54 | |--ARAD_1K_0950.mat 55 | |--Valid_RGB 56 | |--ARAD_1K_0901.jpg 57 | |--ARAD_1K_0902.jpg 58 | : 59 | |--ARAD_1K_0950.jpg 60 | ``` 61 | 2. #### Data Preprocess. 62 | ```shell 63 | cd /DRCR-Net-master/train/ 64 | 65 | # Generate the prepared training data by running: 66 | python train_data_preprocess.py --data_path '../Dataset' --patch_size 128 --stride 64 --train_data_path './dataset/Train' 67 | 68 | # Generate the prepared validation data by running: 69 | python valid_data_preprocess.py --data_path '../Dataset' --valid_data_path './dataset/Valid' 70 | ``` 71 | 3. #### Training. 72 | ```shell 73 | python main.py 74 | ``` 75 | The data generated during training will be recorded in `/RealWorldResults/`.
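Before training, you can optionally sanity-check one generated patch. The snippet below is a minimal sketch (it assumes the default `--train_data_path './dataset/Train'` used above; `train_data_preprocess.py` numbers its outputs `train1.mat`, `train2.mat`, ...). The `'rad'`/`'rgb'` keys and the axis-reversing transpose mirror `train/dataset.py`, since h5py reads MATLAB v7.3 files with reversed axes:
```python
import h5py
import numpy as np

# Inspect the first preprocessed training patch.
with h5py.File('./dataset/Train/train1.mat', 'r') as mat:
    hyper = np.transpose(np.float32(np.array(mat['rad'])), [2, 1, 0])
    rgb = np.transpose(np.float32(np.array(mat['rgb'])), [2, 1, 0])

print('hyper patch:', hyper.shape)  # expected: (31, 128, 128)
print('rgb patch:', rgb.shape)      # expected: (3, 128, 128)
```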
76 | ## Test 77 | ```shell 78 | cd /DRCR-Net-master/test/ 79 | python test.py --RGB_dir '../Dataset/Valid_RGB' --model_dir './model/model.pth' --result_dir './test_results' 80 | 81 | # The MRAE and RMSE indicators can be obtained by running: 82 | python compute_mrae.py --path_rec './test_results' --path_gt '../Dataset/Valid_spectral' 83 | ``` 84 | - Download the model ([Google Drive](https://drive.google.com/file/d/1UJfP6cw9b1EWCGHPnsGEV8AYlr9JTVYC/view?usp=sharing) / [Baidu Disk](https://pan.baidu.com/s/1rHc80ZRg7m893_hCObYAlQ), code: `drcr`) 85 | - Download the reconstructed valid spectral images ([Google Drive](https://drive.google.com/file/d/1gdF-W4OkKN7Z345ayWzsOuaBMh0p5lcm/view?usp=sharing) / [Baidu Disk](https://pan.baidu.com/s/1Wd3NQfVp4bA_IBMT5dayhg), code: `drcr`) 86 | 87 | ## Citation 88 | If you find this code helpful, please kindly cite: 89 | ```shell 90 | # DRCR Net 91 | @InProceedings{Li_2022_CVPR, 92 | author = {Li, Jiaojiao and Du, Songcheng and Wu, Chaoxiong and Leng, Yihong and Song, Rui and Li, Yunsong}, 93 | title = {DRCR Net: Dense Residual Channel Re-Calibration Network With Non-Local Purification for Spectral Super Resolution}, 94 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, 95 | month = {June}, 96 | year = {2022}, 97 | pages = {1259-1268} 98 | } 99 | ``` 100 | ## CVPRW poster 101 | ![poster](./figure/poster.png) 102 | -------------------------------------------------------------------------------- /figure/CRM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jojolee6513/DRCR-net/8abdff676a3fe1a28c6d60be50ca34331c39e164/figure/CRM.png -------------------------------------------------------------------------------- /figure/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jojolee6513/DRCR-net/8abdff676a3fe1a28c6d60be50ca34331c39e164/figure/network.png -------------------------------------------------------------------------------- /figure/poster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jojolee6513/DRCR-net/8abdff676a3fe1a28c6d60be50ca34331c39e164/figure/poster.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Environment requirement 2 | - Anaconda3 3 | - pytorch 1.9.0 4 | - hdf5storage 5 | - opencv-python 6 | - h5py 7 | - tqdm -------------------------------------------------------------------------------- /test/DRCR.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import warnings 5 | warnings.filterwarnings("ignore") 6 | 7 | 8 | class Conv3x3(nn.Module): 9 | def __init__(self, in_dim, out_dim, kernel_size, stride, dilation=1): 10 | super(Conv3x3, self).__init__() 11 | reflect_padding = int(dilation * (kernel_size - 1) / 2) 12 | self.reflection_pad = nn.ReflectionPad2d(reflect_padding) 13 | self.conv2d = nn.Conv2d(in_dim, out_dim, kernel_size, stride, dilation=dilation, bias=False) 14 | 15 | def forward(self, x): 16 | out = self.reflection_pad(x) 17 | out = self.conv2d(out) 18 | return out 19 | 20 | class Conv2D(nn.Module): 21 | def __init__(self, in_channel=256, out_channel=8): 22 | super(Conv2D, self).__init__() 23 |
self.guide_conv2D = nn.Conv2d(in_channel, out_channel, 3, 1, 1) 24 | 25 | def forward(self, x): 26 | spatial_guidance = self.guide_conv2D(x) 27 | return spatial_guidance 28 | 29 | class Conv2dLayer(nn.Module): 30 | def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation = 1, pad_type = 'zero', activation = 'lrelu', norm = 'none', sn = False): 31 | super(Conv2dLayer, self).__init__() 32 | # Initialize the padding scheme 33 | if pad_type == 'reflect': 34 | self.pad = nn.ReflectionPad2d(padding) 35 | elif pad_type == 'replicate': 36 | self.pad = nn.ReplicationPad2d(padding) 37 | elif pad_type == 'zero': 38 | self.pad = nn.ZeroPad2d(padding) 39 | else: 40 | assert 0, "Unsupported padding type: {}".format(pad_type) 41 | 42 | # Initialize the normalization type 43 | if norm == 'bn': 44 | self.norm = nn.BatchNorm2d(out_channels) 45 | elif norm == 'in': 46 | self.norm = nn.InstanceNorm2d(out_channels) 47 | elif norm == 'none': 48 | self.norm = None 49 | else: 50 | assert 0, "Unsupported normalization: {}".format(norm) 51 | 52 | # Initialize the activation function 53 | if activation == 'relu': 54 | self.activation = nn.ReLU(inplace = True) 55 | elif activation == 'lrelu': 56 | self.activation = nn.LeakyReLU(0.2, inplace = True) 57 | elif activation == 'prelu': 58 | self.activation = nn.PReLU() 59 | elif activation == 'selu': 60 | self.activation = nn.SELU(inplace = True) 61 | elif activation == 'tanh': 62 | self.activation = nn.Tanh() 63 | elif activation == 'sigmoid': 64 | self.activation = nn.Sigmoid() 65 | elif activation == 'none': 66 | self.activation = None 67 | else: 68 | assert 0, "Unsupported activation: {}".format(activation) 69 | 70 | # Initialize the convolution layers 71 | if sn: 72 | assert 0, "Spectral normalization is not supported in this release" 73 | else: 74 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation) 75 | def forward(self, x): 76 | x = self.pad(x) 77 | x = self.conv2d(x) 78 | if self.norm: 79 | x = self.norm(x) 80 | if self.activation: 81 | x = self.activation(x) 82 | return x 83 | 84 | class NPM(nn.Module): 85 | def __init__(self, in_channel): 86 | super(NPM, self).__init__() 87 | self.in_channel = in_channel 88 | self.activation = nn.LeakyReLU(0.2, inplace = True) 89 | self.conv0_33 = nn.Conv2d(in_channel, in_channel, 3, 1, 1) 90 | self.conv0_11 = nn.Conv2d(in_channel, in_channel, 1, 1, 0) 91 | self.conv_0_cat = nn.Conv2d(in_channel*2, in_channel, 3, 1, 1) 92 | self.conv2_33 = nn.Conv2d(in_channel, in_channel, 3, 1, 1) 93 | self.conv2_11 = nn.Conv2d(in_channel, in_channel, 1, 1, 0) 94 | self.conv_2_cat = nn.Conv2d(in_channel*2, in_channel, 3, 1, 1) 95 | self.conv4_33 = nn.Conv2d(in_channel, in_channel, 3, 1, 1) 96 | self.conv4_11 = nn.Conv2d(in_channel, in_channel, 1, 1, 0) 97 | self.conv_4_cat = nn.Conv2d(in_channel*2, in_channel, 3, 1, 1) 98 | 99 | self.conv_cat = nn.Conv2d(in_channel*3, in_channel, 3, 1, 1) 100 | 101 | def forward(self, x): 102 | 103 | x_0 = x 104 | x_2 = F.avg_pool2d(x, 2, 2) 105 | x_4 = F.avg_pool2d(x_2, 2, 2) 106 | 107 | x_0 = torch.cat([self.conv0_33(x_0), self.conv0_11(x_0)], 1) 108 | x_0 = self.activation(self.conv_0_cat(x_0)) 109 | 110 | x_2 = torch.cat([self.conv2_33(x_2), self.conv2_11(x_2)], 1) 111 | x_2 = F.interpolate(self.activation(self.conv_2_cat(x_2)), scale_factor=2, mode='bilinear') 112 | # the 1/4-scale branch below reuses conv2_33/conv2_11; conv4_33/conv4_11 are defined above but left unused, matching the released weights 113 | x_4 = torch.cat([self.conv2_33(x_4), self.conv2_11(x_4)], 1) 114 | x_4 = F.interpolate(self.activation(self.conv_4_cat(x_4)), scale_factor=4, mode='bilinear') 115 | 116 | x = x +
self.activation(self.conv_cat(torch.cat([x_0, x_2, x_4], 1))) 117 | return x 118 | 119 | class CRM(nn.Module): 120 | def __init__(self, channel, reduction = 8): 121 | super(CRM, self).__init__() 122 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 123 | self.fc = nn.Sequential( 124 | nn.Linear(channel, channel // reduction, bias = False), 125 | nn.ReLU(inplace = True), 126 | nn.Linear(channel // reduction, channel // reduction, bias = False), 127 | nn.ReLU(inplace = True), 128 | nn.Linear(channel // reduction, channel, bias = False), 129 | nn.Sigmoid() 130 | ) 131 | 132 | def forward(self, x): 133 | b, c, _, _ = x.size() 134 | y = self.avg_pool(x).view(b, c) 135 | y = self.fc(y).view(b, c, 1, 1) 136 | return x * y.expand_as(x) 137 | 138 | class DRCR_Block(nn.Module): 139 | def __init__(self, in_channels, latent_channels, kernel_size = 3, stride = 1, padding = 1, dilation = 1, pad_type = 'zero', activation = 'lrelu', norm = 'none', sn = False): 140 | super(DRCR_Block, self).__init__() 141 | # dense convolutions 142 | self.conv1 = Conv2dLayer(in_channels, latent_channels, kernel_size, stride, padding, dilation, pad_type, 143 | activation, norm, sn) 144 | self.conv2 = Conv2dLayer(in_channels, latent_channels, kernel_size, stride, padding, dilation, pad_type, 145 | activation, norm, sn) 146 | self.conv3 = Conv2dLayer(in_channels, latent_channels, kernel_size, stride, padding, dilation, pad_type, 147 | activation, norm, sn) 148 | self.conv4 = Conv2dLayer(in_channels * 2, in_channels, kernel_size, stride, padding, dilation, pad_type, 149 | activation, norm, sn) 150 | self.conv5 = Conv2dLayer(in_channels * 2, in_channels, kernel_size, stride, padding, dilation, pad_type, 151 | activation, norm, sn) 152 | self.conv6 = Conv2dLayer(in_channels * 2, in_channels, kernel_size, stride, padding, dilation, pad_type, 153 | activation, norm, sn) 154 | # self.cspn2_guide = GMLayer(in_channels) 155 | # self.cspn2 = Affinity_Propagate_Channel() 156 | self.se1 = CRM(in_channels) 157 | self.se2 = CRM(in_channels) 158 | 159 | def forward(self, x): 160 | x1 = self.conv1(x) 161 | x2 = self.conv2(x1) 162 | x3 = self.conv3(x2) 163 | # guidance2 = self.cspn2_guide(x3) 164 | # x3_2 = self.cspn2(guidance2, x3) 165 | x3_2 = self.se1(x) 166 | x4 = self.conv4(torch.cat((x3, x3_2), 1)) 167 | x5 = self.conv5(torch.cat((x2, x4), 1)) 168 | x6 = self.conv6(torch.cat((x1, x5), 1))+self.se2(x3_2) 169 | return x6 170 | 171 | class DRCR(nn.Module): 172 | def __init__(self, inplanes=3, planes=31, channels=200, n_DRBs=8): 173 | super(DRCR, self).__init__() 174 | self.input_conv2D = Conv3x3(inplanes, channels, 3, 1) 175 | self.input_prelu2D = nn.PReLU() 176 | self.head_conv2D = Conv3x3(channels, channels, 3, 1) 177 | self.denosing = NPM(channels) 178 | self.backbone = nn.ModuleList( 179 | [DRCR_Block(channels, channels) for _ in range(n_DRBs)]) 180 | self.tail_conv2D = Conv3x3(channels, channels, 3, 1) 181 | self.output_prelu2D = nn.PReLU() 182 | self.output_conv2D = Conv3x3(channels, planes, 3, 1) 183 | 184 | def forward(self, x): 185 | out = self.DRN2D(x) 186 | return out 187 | 188 | def DRN2D(self, x): 189 | out = self.input_prelu2D(self.input_conv2D(x)) 190 | out = self.head_conv2D(out) 191 | out = self.denosing(out) 192 | 193 | for i, block in enumerate(self.backbone): 194 | out = block(out) 195 | 196 | out = self.tail_conv2D(out) 197 | out = self.output_conv2D(self.output_prelu2D(out)) 198 | return out 199 | 200 | 201 | 202 | 203 | 204 | 205 | if __name__ == "__main__": 206 | # import os 207 | # os.environ["CUDA_DEVICE_ORDER"] = 
'PCI_BUS_ID' 208 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0" 209 | input_tensor = torch.rand(1, 3, 128, 128) 210 | model = DRCR(3, 31, 100, 10) 211 | # model = nn.DataParallel(model).cuda() 212 | with torch.no_grad(): 213 | output_tensor = model(input_tensor) 214 | print(output_tensor.size()) 215 | print('Parameters number is ', sum(param.numel() for param in model.parameters())) 216 | print(torch.__version__) 217 | 218 | 219 | 220 | 221 | 222 | -------------------------------------------------------------------------------- /test/__pycache__/AWAN.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jojolee6513/DRCR-net/8abdff676a3fe1a28c6d60be50ca34331c39e164/test/__pycache__/AWAN.cpython-38.pyc -------------------------------------------------------------------------------- /test/__pycache__/DRCR.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jojolee6513/DRCR-net/8abdff676a3fe1a28c6d60be50ca34331c39e164/test/__pycache__/DRCR.cpython-38.pyc -------------------------------------------------------------------------------- /test/__pycache__/DRCR_Net.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jojolee6513/DRCR-net/8abdff676a3fe1a28c6d60be50ca34331c39e164/test/__pycache__/DRCR_Net.cpython-38.pyc -------------------------------------------------------------------------------- /test/compute_mrae.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import hdf5storage as hdf5 4 | import numpy as np 5 | import argparse 6 | 7 | 8 | def compute_MRAE(gt, rec): 9 | gt_hyper = gt 10 | rec_hyper = rec 11 | error = np.abs(rec_hyper - gt_hyper) / gt_hyper 12 | mrae = np.mean(error.reshape(-1)) 13 | return mrae 14 | 15 | def compute_RMSE(gt, rec): 16 | error = np.power(gt - rec, 2) 17 | rmse = np.sqrt(np.mean(error)) 18 | return rmse 19 | 20 | def main(): 21 | path_rec = opt.path_rec 22 | path_gt = opt.path_gt 23 | 24 | name_rec_list = glob.glob(os.path.join(path_rec, '*.mat')) 25 | name_gt_list = glob.glob(os.path.join(path_gt, '*.mat')) 26 | name_rec_list.sort() 27 | name_gt_list.sort() 28 | 29 | mrae_all = [] 30 | rmse_all = [] 31 | 32 | for i in range(len(name_gt_list)): 33 | hyper_rec = hdf5.loadmat(name_rec_list[i])['cube'] 34 | hyper_gt = hdf5.loadmat(name_gt_list[i])['cube'] 35 | if hyper_gt.min() <= 0.: 36 | print(os.path.basename(name_gt_list[i]), end=' ') 37 | print('This file is not suitable for computing the MRAE metric.') 38 | continue 39 | hyper_rec = np.clip(hyper_rec, 0, 1) 40 | mrae = compute_MRAE(hyper_gt, hyper_rec) 41 | rmse = compute_RMSE(hyper_gt, hyper_rec) 42 | print(os.path.basename(name_gt_list[i]), end=' ') 43 | print('mrae: '+str(mrae)+', rmse: '+str(rmse)) 44 | mrae_all.append(mrae) 45 | rmse_all.append(rmse) 46 | print('The average mrae is: '+str(sum(mrae_all)/len(mrae_all))) 47 | print('The average rmse is: '+str(sum(rmse_all)/len(rmse_all))) 48 | 49 | if __name__ == '__main__': 50 | parser = argparse.ArgumentParser(description="SSR_test") 51 | parser.add_argument("--path_rec", type=str, default='./test_results', help="The path of the reconstructed valid spectral data.") 52 | parser.add_argument("--path_gt", type=str, default='../Dataset/Valid_spectral', help="The path of the ground truth valid spectral data.") 53 | opt = parser.parse_args() 54 | main()
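# A small self-check sketch (not part of the original script; the helper name
# is hypothetical): with a uniform 10% over-estimate of a flat 0.5 cube,
# MRAE = mean(|rec - gt| / gt) = 0.1 and RMSE = sqrt(mean((gt - rec)^2)) = 0.05.
def _sanity_check_metrics():
    gt = np.full((4, 4, 31), 0.5, dtype=np.float32)
    rec = gt * 1.1
    assert abs(compute_MRAE(gt, rec) - 0.1) < 1e-6
    assert abs(compute_RMSE(gt, rec) - 0.05) < 1e-6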
-------------------------------------------------------------------------------- /test/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import cv2 5 | from DRCR import DRCR 6 | import glob 7 | import hdf5storage as hdf5 8 | import time 9 | import argparse 10 | 11 | 12 | def get_reconstruction_gpu(input, model): 13 | """Split the input into two overlapping height crops to fit limited GPU memory, then stitch the outputs back together.""" 14 | model.eval() 15 | var_input = input.cuda() 16 | with torch.no_grad(): 17 | start_time = time.time() 18 | var_output1 = model(var_input[:,:,:-2,:]) 19 | var_output2 = model(var_input[:,:,2:,:]) 20 | var_output = torch.cat([var_output1, var_output2[:,:,-2:,:]], 2)  # rows 0..H-3 from the first crop, final 2 rows from the second 21 | end_time = time.time() 22 | 23 | return end_time-start_time, var_output.cpu() 24 | 25 | os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID' 26 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 27 | 28 | parser = argparse.ArgumentParser(description="SSR_test") 29 | parser.add_argument("--RGB_dir", type=str, default='../Dataset/Valid_RGB', help="absolute Input_RGB_path") 30 | parser.add_argument("--model_dir", type=str, default='./model/model.pth', help="absolute Model_path") 31 | parser.add_argument("--result_dir", type=str, default='./test_results', help="absolute Save_Result_path") 32 | opt = parser.parse_args() 33 | 34 | img_path = opt.RGB_dir 35 | model_path = opt.model_dir 36 | result_path = opt.result_dir 37 | 38 | var_name = 'cube' 39 | # save results 40 | if not os.path.exists(result_path): 41 | os.makedirs(result_path) 42 | model = DRCR(3, 31, 100, 10) 43 | save_point = torch.load(model_path) 44 | model_param = save_point['state_dict'] 45 | model_dict = {} 46 | for k1, k2 in zip(model.state_dict(), model_param):  # remap keys in order (handles e.g. the 'module.' prefix from DataParallel checkpoints) 47 | model_dict[k1] = model_param[k2] 48 | model.load_state_dict(model_dict) 49 | model = model.cuda() 50 | 51 | img_path_name = glob.glob(os.path.join(img_path, '*.jpg')) 52 | img_path_name.sort() 53 | 54 | for i in range(len(img_path_name)): 55 | # load rgb images 56 | print(img_path_name[i].split('/')[-1]) 57 | rgb = cv2.imread(img_path_name[i]) 58 | rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB) 59 | rgb = np.float32(rgb) 60 | rgb = rgb / rgb.max() 61 | rgb = np.expand_dims(np.transpose(rgb, [2, 0, 1]), axis=0).copy() 62 | rgb = torch.from_numpy(rgb).float() 63 | use_time, temp_hyper = get_reconstruction_gpu(rgb, model) 64 | img_res = temp_hyper.numpy() * 1.0 65 | img_res = np.transpose(np.squeeze(img_res), [1, 2, 0]) 66 | img_res_limits = np.minimum(img_res, 1.0) 67 | img_res_limits = np.maximum(img_res_limits, 0) 68 | 69 | mat_name = img_path_name[i].split('/')[-1][:-4] + '.mat' 70 | mat_dir = os.path.join(result_path, mat_name) 71 | hdf5.savemat(mat_dir, {var_name: img_res_limits}, format='7.3', store_python_metadata=True)  # save the cube clipped to [0, 1] -------------------------------------------------------------------------------- /train/DRCR.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import warnings 5 | warnings.filterwarnings("ignore") 6 | 7 | 8 | class Conv3x3(nn.Module): 9 | def __init__(self, in_dim, out_dim, kernel_size, stride, dilation=1): 10 | super(Conv3x3, self).__init__() 11 | reflect_padding = int(dilation * (kernel_size - 1) / 2) 12 | self.reflection_pad = nn.ReflectionPad2d(reflect_padding) 13 | self.conv2d = nn.Conv2d(in_dim, out_dim, kernel_size, stride, dilation=dilation, bias=False) 14 | 15 | def forward(self, x): 16 | out = self.reflection_pad(x)
17 | out = self.conv2d(out) 18 | return out 19 | 20 | class Conv2D(nn.Module): 21 | def __init__(self, in_channel=256, out_channel=8): 22 | super(Conv2D, self).__init__() 23 | self.guide_conv2D = nn.Conv2d(in_channel, out_channel, 3, 1, 1) 24 | 25 | def forward(self, x): 26 | spatial_guidance = self.guide_conv2D(x) 27 | return spatial_guidance 28 | 29 | class Conv2dLayer(nn.Module): 30 | def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation = 1, pad_type = 'zero', activation = 'lrelu', norm = 'none', sn = False): 31 | super(Conv2dLayer, self).__init__() 32 | # Initialize the padding scheme 33 | if pad_type == 'reflect': 34 | self.pad = nn.ReflectionPad2d(padding) 35 | elif pad_type == 'replicate': 36 | self.pad = nn.ReplicationPad2d(padding) 37 | elif pad_type == 'zero': 38 | self.pad = nn.ZeroPad2d(padding) 39 | else: 40 | assert 0, "Unsupported padding type: {}".format(pad_type) 41 | 42 | # Initialize the normalization type 43 | if norm == 'bn': 44 | self.norm = nn.BatchNorm2d(out_channels) 45 | elif norm == 'in': 46 | self.norm = nn.InstanceNorm2d(out_channels) 47 | elif norm == 'none': 48 | self.norm = None 49 | else: 50 | assert 0, "Unsupported normalization: {}".format(norm) 51 | 52 | # Initialize the activation function 53 | if activation == 'relu': 54 | self.activation = nn.ReLU(inplace = True) 55 | elif activation == 'lrelu': 56 | self.activation = nn.LeakyReLU(0.2, inplace = True) 57 | elif activation == 'prelu': 58 | self.activation = nn.PReLU() 59 | elif activation == 'selu': 60 | self.activation = nn.SELU(inplace = True) 61 | elif activation == 'tanh': 62 | self.activation = nn.Tanh() 63 | elif activation == 'sigmoid': 64 | self.activation = nn.Sigmoid() 65 | elif activation == 'none': 66 | self.activation = None 67 | else: 68 | assert 0, "Unsupported activation: {}".format(activation) 69 | 70 | # Initialize the convolution layers 71 | if sn: 72 | assert 0, "Spectral normalization is not supported in this release" 73 | else: 74 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation) 75 | def forward(self, x): 76 | x = self.pad(x) 77 | x = self.conv2d(x) 78 | if self.norm: 79 | x = self.norm(x) 80 | if self.activation: 81 | x = self.activation(x) 82 | return x 83 | 84 | class NPM(nn.Module): 85 | def __init__(self, in_channel): 86 | super(NPM, self).__init__() 87 | self.in_channel = in_channel 88 | self.activation = nn.LeakyReLU(0.2, inplace = True) 89 | self.conv0_33 = nn.Conv2d(in_channel, in_channel, 3, 1, 1) 90 | self.conv0_11 = nn.Conv2d(in_channel, in_channel, 1, 1, 0) 91 | self.conv_0_cat = nn.Conv2d(in_channel*2, in_channel, 3, 1, 1) 92 | self.conv2_33 = nn.Conv2d(in_channel, in_channel, 3, 1, 1) 93 | self.conv2_11 = nn.Conv2d(in_channel, in_channel, 1, 1, 0) 94 | self.conv_2_cat = nn.Conv2d(in_channel*2, in_channel, 3, 1, 1) 95 | self.conv4_33 = nn.Conv2d(in_channel, in_channel, 3, 1, 1) 96 | self.conv4_11 = nn.Conv2d(in_channel, in_channel, 1, 1, 0) 97 | self.conv_4_cat = nn.Conv2d(in_channel*2, in_channel, 3, 1, 1) 98 | 99 | self.conv_cat = nn.Conv2d(in_channel*3, in_channel, 3, 1, 1) 100 | 101 | def forward(self, x): 102 | 103 | x_0 = x 104 | x_2 = F.avg_pool2d(x, 2, 2) 105 | x_4 = F.avg_pool2d(x_2, 2, 2) 106 | 107 | x_0 = torch.cat([self.conv0_33(x_0), self.conv0_11(x_0)], 1) 108 | x_0 = self.activation(self.conv_0_cat(x_0)) 109 | 110 | x_2 = torch.cat([self.conv2_33(x_2), self.conv2_11(x_2)], 1) 111 | x_2 = F.interpolate(self.activation(self.conv_2_cat(x_2)), scale_factor=2, mode='bilinear') 112 | # the 1/4-scale branch below reuses conv2_33/conv2_11; conv4_33/conv4_11 are defined above but left unused, matching the released weights 113 | x_4 =
torch.cat([self.conv2_33(x_4), self.conv2_11(x_4)], 1) 114 | x_4 = F.interpolate(self.activation(self.conv_4_cat(x_4)), scale_factor=4, mode='bilinear') 115 | 116 | x = x + self.activation(self.conv_cat(torch.cat([x_0, x_2, x_4], 1))) 117 | return x 118 | 119 | class CRM(nn.Module): 120 | def __init__(self, channel, reduction = 8): 121 | super(CRM, self).__init__() 122 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 123 | self.fc = nn.Sequential( 124 | nn.Linear(channel, channel // reduction, bias = False), 125 | nn.ReLU(inplace = True), 126 | nn.Linear(channel // reduction, channel // reduction, bias = False), 127 | nn.ReLU(inplace = True), 128 | nn.Linear(channel // reduction, channel, bias = False), 129 | nn.Sigmoid() 130 | ) 131 | 132 | def forward(self, x): 133 | b, c, _, _ = x.size() 134 | y = self.avg_pool(x).view(b, c) 135 | y = self.fc(y).view(b, c, 1, 1) 136 | return x * y.expand_as(x) 137 | 138 | class DRCR_Block(nn.Module): 139 | def __init__(self, in_channels, latent_channels, kernel_size = 3, stride = 1, padding = 1, dilation = 1, pad_type = 'zero', activation = 'lrelu', norm = 'none', sn = False): 140 | super(DRCR_Block, self).__init__() 141 | # dense convolutions 142 | self.conv1 = Conv2dLayer(in_channels, latent_channels, kernel_size, stride, padding, dilation, pad_type, 143 | activation, norm, sn) 144 | self.conv2 = Conv2dLayer(in_channels, latent_channels, kernel_size, stride, padding, dilation, pad_type, 145 | activation, norm, sn) 146 | self.conv3 = Conv2dLayer(in_channels, latent_channels, kernel_size, stride, padding, dilation, pad_type, 147 | activation, norm, sn) 148 | self.conv4 = Conv2dLayer(in_channels * 2, in_channels, kernel_size, stride, padding, dilation, pad_type, 149 | activation, norm, sn) 150 | self.conv5 = Conv2dLayer(in_channels * 2, in_channels, kernel_size, stride, padding, dilation, pad_type, 151 | activation, norm, sn) 152 | self.conv6 = Conv2dLayer(in_channels * 2, in_channels, kernel_size, stride, padding, dilation, pad_type, 153 | activation, norm, sn) 154 | # self.cspn2_guide = GMLayer(in_channels) 155 | # self.cspn2 = Affinity_Propagate_Channel() 156 | self.se1 = CRM(in_channels) 157 | self.se2 = CRM(in_channels) 158 | 159 | def forward(self, x): 160 | x1 = self.conv1(x) 161 | x2 = self.conv2(x1) 162 | x3 = self.conv3(x2) 163 | # guidance2 = self.cspn2_guide(x3) 164 | # x3_2 = self.cspn2(guidance2, x3) 165 | x3_2 = self.se1(x) 166 | x4 = self.conv4(torch.cat((x3, x3_2), 1)) 167 | x5 = self.conv5(torch.cat((x2, x4), 1)) 168 | x6 = self.conv6(torch.cat((x1, x5), 1))+self.se2(x3_2) 169 | return x6 170 | 171 | class DRCR(nn.Module): 172 | def __init__(self, inplanes=3, planes=31, channels=200, n_DRBs=8): 173 | super(DRCR, self).__init__() 174 | self.input_conv2D = Conv3x3(inplanes, channels, 3, 1) 175 | self.input_prelu2D = nn.PReLU() 176 | self.head_conv2D = Conv3x3(channels, channels, 3, 1) 177 | self.denosing = NPM(channels) 178 | self.backbone = nn.ModuleList( 179 | [DRCR_Block(channels, channels) for _ in range(n_DRBs)]) 180 | self.tail_conv2D = Conv3x3(channels, channels, 3, 1) 181 | self.output_prelu2D = nn.PReLU() 182 | self.output_conv2D = Conv3x3(channels, planes, 3, 1) 183 | 184 | def forward(self, x): 185 | out = self.DRN2D(x) 186 | return out 187 | 188 | def DRN2D(self, x): 189 | out = self.input_prelu2D(self.input_conv2D(x)) 190 | out = self.head_conv2D(out) 191 | out = self.denosing(out) 192 | 193 | for i, block in enumerate(self.backbone): 194 | out = block(out) 195 | 196 | out = self.tail_conv2D(out) 197 | out = 
self.output_conv2D(self.output_prelu2D(out)) 198 | return out 199 | 200 | 201 | 202 | 203 | 204 | 205 | if __name__ == "__main__": 206 | # import os 207 | # os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID' 208 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0" 209 | input_tensor = torch.rand(1, 3, 128, 128) 210 | model = DRCR(3, 31, 100, 10) 211 | # model = nn.DataParallel(model).cuda() 212 | with torch.no_grad(): 213 | output_tensor = model(input_tensor) 214 | print(output_tensor.size()) 215 | print('Parameters number is ', sum(param.numel() for param in model.parameters())) 216 | print(torch.__version__) 217 | 218 | 219 | 220 | 221 | 222 | -------------------------------------------------------------------------------- /train/__pycache__/DRCR.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jojolee6513/DRCR-net/8abdff676a3fe1a28c6d60be50ca34331c39e164/train/__pycache__/DRCR.cpython-38.pyc -------------------------------------------------------------------------------- /train/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jojolee6513/DRCR-net/8abdff676a3fe1a28c6d60be50ca34331c39e164/train/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /train/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jojolee6513/DRCR-net/8abdff676a3fe1a28c6d60be50ca34331c39e164/train/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /train/dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import h5py 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as udata 6 | import glob 7 | import os 8 | 9 | class HyperDatasetValid(udata.Dataset): 10 | def __init__(self, mode='valid'): 11 | if mode != 'valid': 12 | raise Exception("Invalid mode!", mode) 13 | data_path = './dataset/Valid' 14 | data_names = glob.glob(os.path.join(data_path, '*.mat')) 15 | self.keys = data_names 16 | self.keys.sort() 17 | 18 | def __len__(self): 19 | return len(self.keys) 20 | 21 | def __getitem__(self, index): 22 | mat = h5py.File(self.keys[index], 'r') 23 | hyper = np.float32(np.array(mat['rad'])) 24 | hyper = np.transpose(hyper, [2, 1, 0]) 25 | hyper = torch.Tensor(hyper)[:,:-2,:] 26 | rgb = np.float32(np.array(mat['rgb'])) 27 | rgb = np.transpose(rgb, [2, 1, 0]) 28 | rgb = torch.Tensor(rgb)[:,:-2,:] 29 | mat.close() 30 | return rgb, hyper 31 | 32 | 33 | class HyperDatasetTrain(udata.Dataset): 34 | def __init__(self, mode='train'): 35 | if mode != 'train': 36 | raise Exception("Invalid mode!", mode) 37 | data_path = './dataset/Train' 38 | data_names1 = glob.glob(os.path.join(data_path, '*.mat')) 39 | 40 | self.keys = data_names1 41 | random.shuffle(self.keys) 42 | # self.keys.sort() 43 | 44 | def __len__(self): 45 | return len(self.keys) 46 | 47 | def __getitem__(self, index): 48 | mat = h5py.File(self.keys[index], 'r') 49 | hyper = np.float32(np.array(mat['rad'])) 50 | hyper = np.transpose(hyper, [2, 1, 0]) 51 | hyper = torch.Tensor(hyper) 52 | rgb = np.float32(np.array(mat['rgb'])) 53 | rgb = np.transpose(rgb, [2, 1, 0]) 54 | rgb = torch.Tensor(rgb) 55 | mat.close() 56 | return rgb, hyper 57 | 58 | 59 | 60 | 61 | 
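# Minimal usage sketch (not part of the original pipeline): confirms the tensor
# shapes the loader feeds to DRCR, assuming ./dataset/Train has been populated
# by train_data_preprocess.py with the default 128x128 patches.
if __name__ == '__main__':
    loader = udata.DataLoader(dataset=HyperDatasetTrain(mode='train'), batch_size=2, shuffle=True)
    rgb, hyper = next(iter(loader))
    print(rgb.shape)    # expected: torch.Size([2, 3, 128, 128])
    print(hyper.shape)  # expected: torch.Size([2, 31, 128, 128])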
-------------------------------------------------------------------------------- /train/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import argparse 4 | import torch.optim as optim 5 | import torch.backends.cudnn as cudnn 6 | from torch.utils.data import DataLoader 7 | from torch.autograd import Variable 8 | import os 9 | import time 10 | import random 11 | from dataset import HyperDatasetValid, HyperDatasetTrain 12 | from DRCR import DRCR 13 | from utils import AverageMeter, initialize_logger, save_checkpoint, record_loss, Loss_train, Loss_valid 14 | from tqdm import tqdm 15 | import warnings 16 | warnings.filterwarnings('ignore') 17 | 18 | os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID' 19 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 20 | 21 | parser = argparse.ArgumentParser(description="SSR") 22 | parser.add_argument("--batchSize", type=int, default=16, help="batch size") 23 | parser.add_argument("--end_epoch", type=int, default=130+1, help="number of epochs") 24 | parser.add_argument("--init_lr", type=float, default=1e-4, help="initial learning rate") 25 | parser.add_argument("--decay_power", type=float, default=1.5, help="decay power") 26 | parser.add_argument("--max_iter", type=float, default=300000, help="max_iter") # adjust together with batchSize and the number of training samples 27 | parser.add_argument("--outf", type=str, default="RealWorldResults", help='path to log files') 28 | opt = parser.parse_args() 29 | 30 | 31 | def main(): 32 | cudnn.benchmark = True 33 | # Load dataset 34 | print("\nloading dataset ...") 35 | train_data = HyperDatasetTrain(mode='train') 36 | print("Train set samples: ", len(train_data)) 37 | val_data = HyperDatasetValid(mode='valid') 38 | print("Validation set samples: ", len(val_data)) 39 | # Data Loader (Input Pipeline) 40 | train_loader1 = DataLoader(dataset=train_data, batch_size=opt.batchSize, shuffle=True, num_workers=10, pin_memory=False, drop_last=True) 41 | train_loader = [train_loader1] 42 | val_loader = DataLoader(dataset=val_data, batch_size=1, shuffle=False, num_workers=2, pin_memory=False) 43 | 44 | # Model 45 | print("\nbuilding models_baseline ...") 46 | model = DRCR(3, 31, 100, 10) 47 | print('Parameters number is ', sum(param.numel() for param in model.parameters())) 48 | criterion_train = Loss_train() 49 | criterion_valid = Loss_valid() 50 | if torch.cuda.device_count() > 1: 51 | model = nn.DataParallel(model) # batchSize should be an integer multiple of the GPU count 52 | if torch.cuda.is_available(): 53 | model.cuda() 54 | criterion_train.cuda() 55 | criterion_valid.cuda() 56 | 57 | # Parameters, Loss and Optimizer 58 | start_epoch = 0 59 | iteration = 0 60 | record_val_loss = 1000 61 | optimizer = optim.Adam(model.parameters(), lr=opt.init_lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) 62 | 63 | # Record 64 | if not os.path.exists(opt.outf): 65 | os.makedirs(opt.outf) 66 | loss_csv = open(os.path.join(opt.outf, 'loss.csv'), 'a+') 67 | log_dir = os.path.join(opt.outf, 'train.log') 68 | logger = initialize_logger(log_dir) 69 | 70 | # Resume 71 | resume_file = '' 72 | if resume_file: 73 | if os.path.isfile(resume_file): 74 | print("=> loading checkpoint '{}'".format(resume_file)) 75 | checkpoint = torch.load(resume_file) 76 | start_epoch = checkpoint['epoch'] 77 | iteration = checkpoint['iter'] 78 | model.load_state_dict(checkpoint['state_dict']) 79 | optimizer.load_state_dict(checkpoint['optimizer']) 80 | 81 | # Start epoch 82 | for epoch in range(start_epoch+1, 
opt.end_epoch): 83 | start_time = time.time() 84 | train_loss, iteration, lr = train(train_loader, model, criterion_train, optimizer, epoch, iteration, opt.init_lr, opt.decay_power) 85 | val_loss = validate(val_loader, model, criterion_valid) 86 | # Save model 87 | if torch.abs(val_loss - record_val_loss) < 0.0001 or val_loss < record_val_loss: # save when the validation loss is (nearly) the best so far 88 | save_checkpoint(opt.outf, epoch, iteration, model, optimizer) 89 | if val_loss < record_val_loss: 90 | record_val_loss = val_loss 91 | end_time = time.time() 92 | epoch_time = end_time - start_time 93 | print("Epoch [%02d], Iter[%06d], Time:%.9f, learning rate : %.9f, Train Loss: %.9f Test Loss: %.9f " 94 | % (epoch, iteration, epoch_time, lr, train_loss, val_loss)) 95 | # save loss 96 | record_loss(loss_csv, epoch, iteration, epoch_time, lr, train_loss, val_loss) 97 | logger.info("Epoch [%02d], Iter[%06d], Time:%.9f, learning rate : %.9f, Train Loss: %.9f Test Loss: %.9f " 98 | % (epoch, iteration, epoch_time, lr, train_loss, val_loss)) 99 | 100 | 101 | # Training 102 | def train(train_loader, model, criterion, optimizer, epoch, iteration, init_lr, decay_power): 103 | model.train() 104 | random.shuffle(train_loader) 105 | losses = AverageMeter() 106 | for k, train_data_loader in (enumerate(train_loader)): 107 | for i, (images, labels) in tqdm(enumerate(train_data_loader)): 108 | labels = labels.cuda() 109 | images = images.cuda() 110 | images = Variable(images) 111 | labels = Variable(labels) 112 | # Decaying Learning Rate 113 | lr = poly_lr_scheduler(optimizer, init_lr, iteration, max_iter=opt.max_iter, power=decay_power) 114 | iteration = iteration + 1 115 | # Forward + Backward + Optimize 116 | output = model(images) 117 | loss = criterion(output, labels) 118 | loss_all = loss 119 | optimizer.zero_grad() 120 | loss_all.backward() 121 | optimizer.step() 122 | losses.update(loss.data) 123 | # print('[Epoch:%02d],[Process:%d/%d],[iter:%d],lr=%.9f,train_losses.avg=%.9f' 124 | # % (epoch, k+1, len(train_loader), iteration, lr, losses.avg)) 125 | 126 | return losses.avg, iteration, lr 127 | 128 | 129 | # Validate 130 | def validate(val_loader, model, criterion): 131 | model.eval() 132 | losses = AverageMeter() 133 | for i, (input, target) in enumerate(val_loader): 134 | input = input.cuda() 135 | target = target.cuda() 136 | with torch.no_grad(): 137 | # compute output 138 | output = model(input) 139 | loss = criterion(output, target) 140 | # record loss 141 | losses.update(loss.data) 142 | 143 | return losses.avg 144 | 145 | 146 | # Learning rate 147 | def poly_lr_scheduler(optimizer, init_lr, iteration, lr_decay_iter=1, max_iter=100, power=0.9): 148 | """Polynomial decay of learning rate 149 | :param init_lr: base learning rate 150 | :param iteration: current iteration 151 | :param lr_decay_iter: how frequently decay occurs, default is 1 152 | :param max_iter: maximum number of iterations 153 | :param power: polynomial power 154 | 155 | """ 156 | if iteration % lr_decay_iter or iteration > max_iter: 157 | return optimizer.param_groups[0]['lr'] # off-schedule or past max_iter: keep the current lr 158 | 159 | lr = init_lr*(1 - iteration/max_iter)**power 160 | for param_group in optimizer.param_groups: 161 | param_group['lr'] = lr 162 | 163 | return lr 164 | 165 | 166 | if __name__ == '__main__': 167 | main() 168 | print(torch.__version__) 169 | -------------------------------------------------------------------------------- /train/train_data_preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | import numpy as np 5 | import 
argparse 6 | import hdf5storage as hdf5 7 | import tqdm 8 | 9 | parser = argparse.ArgumentParser(description="SpectralSR") 10 | parser.add_argument("--data_path", type=str, default='../Dataset', help="data path") 11 | parser.add_argument("--patch_size", type=int, default=128, help="data patch size") 12 | parser.add_argument("--stride", type=int, default=64, help="data patch stride") 13 | parser.add_argument("--train_data_path", type=str, default='./dataset/Train', help="preprocess_data_path") 14 | opt = parser.parse_args() 15 | 16 | 17 | def main(): 18 | if not os.path.exists(opt.train_data_path): 19 | os.makedirs(opt.train_data_path) 20 | 21 | process_data(patch_size=opt.patch_size, stride=opt.stride, mode='train') 22 | 23 | 24 | def normalize(data, max_val, min_val): 25 | return (data-min_val)/(np.float32(max_val-min_val)) 26 | 27 | 28 | def Im2Patch(img, win, stride=1): 29 | k = 0 30 | endc = img.shape[0] 31 | endw = img.shape[1] 32 | endh = img.shape[2] 33 | patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride] 34 | TotalPatNum = patch.shape[1] * patch.shape[2] 35 | Y = np.zeros([endc, win*win,TotalPatNum], np.float32) 36 | for i in range(win): 37 | for j in range(win): 38 | patch = img[:, i:endw-win+i+1:stride, j:endh-win+j+1:stride] 39 | Y[:,k,:] = np.array(patch[:]).reshape(endc, TotalPatNum) 40 | k = k + 1 41 | return Y.reshape([endc, win, win, TotalPatNum]) 42 | 43 | 44 | def process_data(patch_size, stride, mode): 45 | if mode == 'train': 46 | print("\nprocess training set ...\n") 47 | patch_num = 1 48 | filenames_hyper = glob.glob(os.path.join(opt.data_path, 'Train_spectral', '*.mat')) 49 | filenames_rgb = glob.glob(os.path.join(opt.data_path, 'Train_RGB', '*.jpg')) 50 | filenames_hyper.sort() 51 | filenames_rgb.sort() 52 | print(len(filenames_rgb)) 53 | # for k in range(1): # make small dataset 54 | for k in tqdm.tqdm(range(len(filenames_rgb))): 55 | print([filenames_rgb[k][-16:]]) 56 | 57 | mat = hdf5.loadmat(filenames_hyper[k]) 58 | hyper = np.float32(np.array(mat['cube'])) 59 | hyper = np.transpose(hyper, [2, 0, 1]) 60 | if hyper.min() <= 0: 61 | print('This file contains non-positive values and is not suitable for Training!') 62 | continue 63 | hyper = normalize(hyper, max_val=1., min_val=0.) 64 | # load rgb image 65 | rgb = cv2.imread(filenames_rgb[k]) 66 | rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB) 67 | rgb = np.transpose(rgb, [2, 0, 1]) 68 | rgb = normalize(np.float32(rgb), max_val=rgb.max(), min_val=0.) 
69 | # create patches 70 | patches_hyper = Im2Patch(hyper, win=patch_size, stride=stride) 71 | patches_rgb = Im2Patch(rgb, win=patch_size, stride=stride) 72 | for j in range(patches_rgb.shape[3]): 73 | print("generate training sample #%d" % patch_num) 74 | sub_hyper = patches_hyper[:, :, :, j] 75 | sub_rgb = patches_rgb[:, :, :, j] 76 | train_data_path = os.path.join(opt.train_data_path, 'train'+str(patch_num)+'.mat') 77 | print(train_data_path) 78 | hdf5.savemat(train_data_path, {'rad': sub_hyper}, format='7.3') 79 | hdf5.savemat(train_data_path, {'rgb': sub_rgb}, format='7.3') # appends to the existing file, so each patch file holds both 'rad' and 'rgb' 80 | patch_num += 1 81 | 82 | print("\ntraining set: # samples %d\n" % (patch_num-1)) 83 | 84 | 85 | if __name__ == '__main__': 86 | main() 87 | 88 | -------------------------------------------------------------------------------- /train/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | import torch.nn as nn 5 | import logging 6 | import numpy as np 7 | import os 8 | import hdf5storage 9 | 10 | 11 | class AverageMeter(object): 12 | def __init__(self): 13 | self.reset() 14 | 15 | def reset(self): 16 | self.val = 0 17 | self.avg = 0 18 | self.sum = 0 19 | self.count = 0 20 | 21 | def update(self, val, n=1): 22 | self.val = val 23 | self.sum += val * n 24 | self.count += n 25 | self.avg = self.sum / self.count 26 | 27 | 28 | def initialize_logger(file_dir): 29 | logger = logging.getLogger() 30 | fhandler = logging.FileHandler(filename=file_dir, mode='a') 31 | formatter = logging.Formatter('%(asctime)s - %(message)s',"%Y-%m-%d %H:%M:%S") 32 | fhandler.setFormatter(formatter) 33 | logger.addHandler(fhandler) 34 | logger.setLevel(logging.INFO) 35 | return logger 36 | 37 | 38 | def save_checkpoint(model_path, epoch, iteration, model, optimizer): 39 | state = { 40 | 'epoch': epoch, 41 | 'iter': iteration, 42 | 'state_dict': model.state_dict(), 43 | 'optimizer': optimizer.state_dict(), 44 | } 45 | 46 | torch.save(state, os.path.join(model_path, 'net_%depoch.pth' % epoch)) 47 | 48 | 49 | def save_matv73(mat_name, var_name, var): 50 | hdf5storage.savemat(mat_name, {var_name: var}, format='7.3', store_python_metadata=True) 51 | 52 | 53 | def record_loss(loss_csv, epoch, iteration, epoch_time, lr, train_loss, test_loss): 54 | """Record per-epoch results.""" 55 | loss_csv.write('{},{},{},{},{},{}\n'.format(epoch, iteration, epoch_time, lr, train_loss, test_loss)) 56 | loss_csv.flush() 57 | # keep loss_csv open across epochs; the handle is written to again on every call 58 | 59 | 60 | class Loss_train(nn.Module): 61 | def __init__(self): 62 | super(Loss_train, self).__init__() 63 | 64 | def forward(self, outputs, label): 65 | error = torch.abs(outputs - label) / label 66 | # error = torch.abs(outputs - label) 67 | mrae = torch.mean(error.view(-1)) 68 | return mrae 69 | 70 | 71 | class Loss_valid(nn.Module): 72 | def __init__(self): 73 | super(Loss_valid, self).__init__() 74 | 75 | def forward(self, outputs, label): 76 | error = torch.abs(outputs - label) / label 77 | # error = torch.abs(outputs - label) 78 | mrae = torch.mean(error.view(-1)) 79 | return mrae 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /train/valid_data_preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import cv2 4 | import glob 5 | import numpy as np 6 | import argparse 7 | import hdf5storage 8 | 9 | parser = argparse.ArgumentParser(description="SpectralSR") 10 | 
parser.add_argument("--data_path", type=str, default='../Dataset', help="data path") 11 | parser.add_argument("--valid_data_path", type=str, default='./dataset/Valid', help="preprocess_data_path") 12 | 13 | 14 | opt = parser.parse_args() 15 | 16 | 17 | def main(): 18 | if not os.path.exists(opt.valid_data_path): 19 | os.makedirs(opt.valid_data_path) 20 | 21 | process_data(mode='valid') 22 | 23 | 24 | def normalize(data, max_val, min_val): 25 | return (data-min_val)/(max_val-min_val) 26 | 27 | 28 | def Im2Patch(img, win, stride=1): 29 | k = 0 30 | endc = img.shape[0] 31 | endw = img.shape[1] 32 | endh = img.shape[2] 33 | patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride] 34 | TotalPatNum = patch.shape[1] * patch.shape[2] 35 | Y = np.zeros([endc, win*win, TotalPatNum], np.float32) 36 | for i in range(win): 37 | for j in range(win): 38 | patch = img[:, i:endw-win+i+1:stride, j:endh-win+j+1:stride] 39 | Y[:, k, :] = np.array(patch[:]).reshape(endc, TotalPatNum) 40 | k = k + 1 41 | return Y.reshape([endc, win, win, TotalPatNum]) 42 | 43 | 44 | def process_data(mode): 45 | if mode == 'valid': 46 | print("\nprocess valid set ...\n") 47 | patch_num = 1 48 | filenames_hyper = glob.glob(os.path.join(opt.data_path, 'Valid_spectral', '*.mat')) 49 | filenames_rgb = glob.glob(os.path.join(opt.data_path, 'Valid_RGB', '*.jpg')) 50 | filenames_hyper.sort() 51 | filenames_rgb.sort() 52 | # for k in range(1): # make small dataset 53 | for k in range(len(filenames_rgb)): 54 | print([filenames_rgb[k]]) 55 | # load hyperspectral image 56 | mat = hdf5storage.loadmat(filenames_hyper[k]) 57 | hyper = np.float32(np.array(mat['cube'])) 58 | hyper = np.transpose(hyper, [2, 0, 1]) 59 | if hyper.min() <= 0: 60 | print('This file contains non-positive values and is not suitable for Testing!') 61 | continue 62 | hyper = normalize(hyper, max_val=1., min_val=0.) 63 | # load rgb image 64 | rgb = cv2.imread(filenames_rgb[k]) 65 | rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB) 66 | rgb = np.transpose(rgb, [2, 0, 1]) 67 | rgb = normalize(np.float32(rgb), max_val=rgb.max(), min_val=0.) 68 | valid_data_path = os.path.join(opt.valid_data_path, 'valid' + str(patch_num) + '.mat') 69 | print(valid_data_path) 70 | hdf5storage.savemat(valid_data_path, {'rad': hyper}, format='7.3') 71 | hdf5storage.savemat(valid_data_path, {'rgb': rgb}, format='7.3') 72 | patch_num += 1 73 | 74 | print("\nvalid set: # samples %d\n" % (patch_num-1)) 75 | 76 | 77 | if __name__ == '__main__': 78 | main() 79 | 80 | --------------------------------------------------------------------------------