├── README.md
├── figure
│   ├── CRM.png
│   ├── network.png
│   └── poster.png
├── requirements.txt
├── test
│   ├── DRCR.py
│   ├── __pycache__
│   │   ├── AWAN.cpython-38.pyc
│   │   ├── DRCR.cpython-38.pyc
│   │   └── DRCR_Net.cpython-38.pyc
│   ├── compute_mrae.py
│   └── test.py
└── train
    ├── DRCR.py
    ├── __pycache__
    │   ├── DRCR.cpython-38.pyc
    │   ├── dataset.cpython-38.pyc
    │   └── utils.cpython-38.pyc
    ├── dataset.py
    ├── main.py
    ├── train_data_preprocess.py
    ├── utils.py
    └── valid_data_preprocess.py
/README.md:
--------------------------------------------------------------------------------
1 | # DRCR Net: Dense Residual Channel Re-calibration Network with Non-local Purification for Spectral Super Resolution (CVPRW 2022)
2 | [[NTIRE 2022 Challenge]](https://codalab.lisn.upsaclay.fr/competitions/721#learn_the_details)
3 | [[CVPRW 2022 Paper]](https://openaccess.thecvf.com/content/CVPR2022W/NTIRE/papers/Li_DRCR_Net_Dense_Residual_Channel_Re-Calibration_Network_With_Non-Local_Purification_CVPRW_2022_paper.pdf)
4 | 
5 |
6 | [Jiaojiao Li](https://scholar.google.com/citations?user=Ccu3-acAAAAJ&hl=zh-CN&oi=ao), [Songcheng Du](https://github.com/dusongcheng), [Chaoxiong Wu](https://scholar.google.com/citations?user=PIsTkkEAAAAJ&hl=zh-CN&oi=ao), [Yihong Leng](), [Rui Song](https://scholar.google.com/citations?user=_SKooBYAAAAJ&hl=zh-CN&oi=sra) and [Yunsong Li]()
7 |
8 |
9 |
10 | > **Abstract:** Spectral super resolution (SSR) aims to reconstruct the 3D hyperspectral signal from a 2D RGB image, and has prospered with the proliferation of Convolutional Neural Networks (CNNs) and increased access to RGB/hyperspectral datasets. Nevertheless, most CNN-based spectral reconstruction (SR) algorithms can only achieve high reconstruction accuracy when the input RGB image is relatively 'clean' with foregone spectral response functions. Unfortunately, in the real world, images are contaminated by mixed noise, bad illumination conditions, compression artifacts, etc., and the existing state-of-the-art (SOTA) methods no longer work well. To conquer these drawbacks, we propose a novel dense residual channel re-calibration network (DRCR Net) with non-local purification for achieving robust SSR results, which first performs interference removal through a non-local purification module (NPM) to refine the RGB inputs. To be specific, as the main component of the backbone, the dense residual channel re-calibration (DRCR) block is cascaded with an encoder-decoder paradigm through several cross-layer dense residual connections to capture deep spatial-spectral interactions, which further improves the generalization ability of the network. Furthermore, we customize dual channel re-calibration modules (CRMs), embedded in each DRCR block, to adaptively re-calibrate the channel-wise feature responses for high-fidelity spectral recovery. In the NTIRE 2022 Spectral Reconstruction Challenge, our entry ranked 3rd.
11 |
12 |
13 |
14 |
15 | ## DRCR Net Framework
16 |
17 | ![DRCR Net framework](./figure/network.png)
18 |
19 |
20 | ## CRM
21 |
22 | ![Channel re-calibration module (CRM)](./figure/CRM.png)
23 |
24 |
25 | ## Train
26 | 1. #### Download the dataset.
27 |
28 | - Download the training spectral images ([Google Drive](https://drive.google.com/file/d/1FQBfDd248dCKClR-BpX5V2drSbeyhKcq/view))
29 | - Download the training RGB images ([Google Drive](https://drive.google.com/file/d/1A4GUXhVc5k5d_79gNvokEtVPG290qVkd/view))
30 | - Download the validation spectral images ([Google Drive](https://drive.google.com/file/d/12QY8LHab3gzljZc3V6UyHgBee48wh9un/view))
31 | - Download the validation RGB images ([Google Drive](https://drive.google.com/file/d/19vBR_8Il1qcaEZsK42aGfvg5lCuvLh1A/view))
32 |
33 | Put all downloaded files in `/DRCR-Net-master/Dataset/`, so that the repo is organized as follows:
34 | ```shell
35 | |--DRCR-Net-master
36 | |--figure
37 | |--test
38 | |--train
39 | |--Dataset
40 | |--Train_spectral
41 | |--ARAD_1K_0001.mat
42 | |--ARAD_1K_0002.mat
43 | :
44 | |--ARAD_1K_0900.mat
45 | |--Train_RGB
46 | |--ARAD_1K_0001.jpg
47 | |--ARAD_1K_0002.jpg
48 | :
49 | |--ARAD_1K_0900.jpg
50 | |--Valid_spectral
51 | |--ARAD_1K_0901.mat
52 | |--ARAD_1K_0902.mat
53 | :
54 | |--ARAD_1K_0950.mat
55 | |--Valid_RGB
56 | |--ARAD_1K_0901.jpg
57 | |--ARAD_1K_0902.jpg
58 | :
59 | |--ARAD_1K_0950.jpg
60 | ```
61 | 2. #### Data Preprocessing.
62 | ```shell
63 | cd /DRCR-Net-master/train/
64 |
65 | # Generate the prepared training data by running:
66 | python train_data_preprocess.py --data_path '../Dataset' --patch_size 128 --stride 64 --train_data_path './dataset/Train'
67 |
68 | # Generate the prepared validation data by running:
69 | python valid_data_preprocess.py --data_path '../Dataset' --valid_data_path './dataset/Valid'
70 | ```
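Each generated `.mat` file stores a hyperspectral patch under the key `rad` and the matching RGB patch under `rgb`. A minimal sketch for inspecting one sample (assuming preprocessing has been run; `train1.mat` is the first file that `train_data_preprocess.py` writes):
```python
import h5py

# MATLAB v7.3 files are HDF5 containers; h5py returns the arrays with reversed
# axes, which is why dataset.py transposes them back with [2, 1, 0].
with h5py.File('./dataset/Train/train1.mat', 'r') as mat:
    rad = mat['rad'][()]  # hyperspectral patch (31 bands)
    rgb = mat['rgb'][()]  # matching RGB patch
    print(rad.shape, rgb.shape)
```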
71 | 3. #### Training.
72 | ```shell
73 | python main.py
74 | ```
75 | The logs, loss records and model checkpoints generated during training are written to `/DRCR-Net-master/train/RealWorldResults/`.
76 | ## Test
77 | ```shell
78 | cd /DRCR-Net-master/test/
79 | python test.py --RGB_dir '../Dataset/Valid_RGB' --model_dir './model/model.pth' --result_dir './test_results'
80 |
81 | # The MRAE and RMSE metrics can be obtained by running:
82 | python compute_mrae.py --path_rec './test_results' --path_gt '../Dataset/Valid_spectral'
83 | ```
84 | - Download the model ([Google Drive](https://drive.google.com/file/d/1UJfP6cw9b1EWCGHPnsGEV8AYlr9JTVYC/view?usp=sharing) / [Baidu Disk](https://pan.baidu.com/s/1rHc80ZRg7m893_hCObYAlQ), code: `drcr`)
85 | - Download the reconstructed valid spectral images ([Google Drive](https://drive.google.com/file/d/1gdF-W4OkKN7Z345ayWzsOuaBMh0p5lcm/view?usp=sharing) / [Baidu Disk](https://pan.baidu.com/s/1Wd3NQfVp4bA_IBMT5dayhg), code: `drcr`)
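For a quick sanity check of the network definition, here is a minimal sketch mirroring the `__main__` block of `test/DRCR.py` (`DRCR(3, 31, 100, 10)` is the configuration `test.py` uses):
```python
import torch
from DRCR import DRCR  # run from inside DRCR-Net-master/test/

model = DRCR(inplanes=3, planes=31, channels=100, n_DRBs=10)
with torch.no_grad():
    out = model(torch.rand(1, 3, 128, 128))  # spatial dims must be divisible by 4 (the NPM pools twice)
print(out.size())  # expected: torch.Size([1, 31, 128, 128])
print('Parameters:', sum(p.numel() for p in model.parameters()))
```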
86 |
87 | ## Citation
88 | If you find this code helpful, please kindly cite:
89 | ```bibtex
91 | @InProceedings{Li_2022_CVPR,
92 | author = {Li, Jiaojiao and Du, Songcheng and Wu, Chaoxiong and Leng, Yihong and Song, Rui and Li, Yunsong},
93 | title = {DRCR Net: Dense Residual Channel Re-Calibration Network With Non-Local Purification for Spectral Super Resolution},
94 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
95 | month = {June},
96 | year = {2022},
97 | pages = {1259-1268}
98 | }
99 | ```
100 | ## CVPRW poster
101 | ![CVPRW 2022 poster](./figure/poster.png)
102 |
--------------------------------------------------------------------------------
/figure/CRM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jojolee6513/DRCR-net/8abdff676a3fe1a28c6d60be50ca34331c39e164/figure/CRM.png
--------------------------------------------------------------------------------
/figure/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jojolee6513/DRCR-net/8abdff676a3fe1a28c6d60be50ca34331c39e164/figure/network.png
--------------------------------------------------------------------------------
/figure/poster.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jojolee6513/DRCR-net/8abdff676a3fe1a28c6d60be50ca34331c39e164/figure/poster.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Environment requirements
2 | - Anaconda3
3 | - pytorch 1.9.0
4 | - hdf5storage
5 | - opencv-python
6 | - h5py
7 | - tqdm
--------------------------------------------------------------------------------
/test/DRCR.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | import warnings
5 | warnings.filterwarnings("ignore")
6 |
7 |
8 | class Conv3x3(nn.Module):
9 | def __init__(self, in_dim, out_dim, kernel_size, stride, dilation=1):
10 | super(Conv3x3, self).__init__()
11 | reflect_padding = int(dilation * (kernel_size - 1) / 2)
12 | self.reflection_pad = nn.ReflectionPad2d(reflect_padding)
13 | self.conv2d = nn.Conv2d(in_dim, out_dim, kernel_size, stride, dilation=dilation, bias=False)
14 |
15 | def forward(self, x):
16 | out = self.reflection_pad(x)
17 | out = self.conv2d(out)
18 | return out
19 |
20 | class Conv2D(nn.Module):
21 | def __init__(self, in_channel=256, out_channel=8):
22 | super(Conv2D, self).__init__()
23 | self.guide_conv2D = nn.Conv2d(in_channel, out_channel, 3, 1, 1)
24 |
25 | def forward(self, x):
26 | spatial_guidance = self.guide_conv2D(x)
27 | return spatial_guidance
28 |
29 | class Conv2dLayer(nn.Module):
30 | def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation = 1, pad_type = 'zero', activation = 'lrelu', norm = 'none', sn = False):
31 | super(Conv2dLayer, self).__init__()
32 | # Initialize the padding scheme
33 | if pad_type == 'reflect':
34 | self.pad = nn.ReflectionPad2d(padding)
35 | elif pad_type == 'replicate':
36 | self.pad = nn.ReplicationPad2d(padding)
37 | elif pad_type == 'zero':
38 | self.pad = nn.ZeroPad2d(padding)
39 | else:
40 | assert 0, "Unsupported padding type: {}".format(pad_type)
41 |
42 | # Initialize the normalization type
43 | if norm == 'bn':
44 | self.norm = nn.BatchNorm2d(out_channels)
45 | elif norm == 'in':
46 | self.norm = nn.InstanceNorm2d(out_channels)
47 | elif norm == 'none':
48 | self.norm = None
49 | else:
50 | assert 0, "Unsupported normalization: {}".format(norm)
51 |
52 |         # Initialize the activation function
53 | if activation == 'relu':
54 | self.activation = nn.ReLU(inplace = True)
55 | elif activation == 'lrelu':
56 | self.activation = nn.LeakyReLU(0.2, inplace = True)
57 | elif activation == 'prelu':
58 | self.activation = nn.PReLU()
59 | elif activation == 'selu':
60 | self.activation = nn.SELU(inplace = True)
61 | elif activation == 'tanh':
62 | self.activation = nn.Tanh()
63 | elif activation == 'sigmoid':
64 | self.activation = nn.Sigmoid()
65 | elif activation == 'none':
66 | self.activation = None
67 | else:
68 | assert 0, "Unsupported activation: {}".format(activation)
69 |
70 | # Initialize the convolution layers
71 |         # Initialize the convolution layer (spectral normalization is not implemented here)
72 |         if sn:
73 |             raise NotImplementedError('Spectral normalization (sn=True) is not supported in this release.')
74 |         self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation)
75 | def forward(self, x):
76 | x = self.pad(x)
77 | x = self.conv2d(x)
78 | if self.norm:
79 | x = self.norm(x)
80 | if self.activation:
81 | x = self.activation(x)
82 | return x
83 |
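# Non-local Purification Module (NPM): filters the features at three scales
# (full, 1/2 and 1/4 resolution), fuses the three purified branches, and adds
# the result back to the input as a residual.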
84 | class NPM(nn.Module):
85 | def __init__(self, in_channel):
86 | super(NPM, self).__init__()
87 | self.in_channel = in_channel
88 | self.activation = nn.LeakyReLU(0.2, inplace = True)
89 | self.conv0_33 = nn.Conv2d(in_channel, in_channel, 3, 1, 1)
90 | self.conv0_11 = nn.Conv2d(in_channel, in_channel, 1, 1, 0)
91 | self.conv_0_cat = nn.Conv2d(in_channel*2, in_channel, 3, 1, 1)
92 | self.conv2_33 = nn.Conv2d(in_channel, in_channel, 3, 1, 1)
93 | self.conv2_11 = nn.Conv2d(in_channel, in_channel, 1, 1, 0)
94 | self.conv_2_cat = nn.Conv2d(in_channel*2, in_channel, 3, 1, 1)
95 | self.conv4_33 = nn.Conv2d(in_channel, in_channel, 3, 1, 1)
96 | self.conv4_11 = nn.Conv2d(in_channel, in_channel, 1, 1, 0)
97 | self.conv_4_cat = nn.Conv2d(in_channel*2, in_channel, 3, 1, 1)
98 |
99 | self.conv_cat = nn.Conv2d(in_channel*3, in_channel, 3, 1, 1)
100 |
101 | def forward(self, x):
102 |
103 | x_0 = x
104 | x_2 = F.avg_pool2d(x, 2, 2)
105 | x_4 = F.avg_pool2d(x_2, 2, 2)
106 |
107 | x_0 = torch.cat([self.conv0_33(x_0), self.conv0_11(x_0)], 1)
108 | x_0 = self.activation(self.conv_0_cat(x_0))
109 |
110 | x_2 = torch.cat([self.conv2_33(x_2), self.conv2_11(x_2)], 1)
111 | x_2 = F.interpolate(self.activation(self.conv_2_cat(x_2)), scale_factor=2, mode='bilinear')
112 |
113 |         x_4 = torch.cat([self.conv4_33(x_4), self.conv4_11(x_4)], 1)  # 1/4-scale branch uses its own conv4_* layers
114 | x_4 = F.interpolate(self.activation(self.conv_4_cat(x_4)), scale_factor=4, mode='bilinear')
115 |
116 | x = x + self.activation(self.conv_cat(torch.cat([x_0, x_2, x_4], 1)))
117 | return x
118 |
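# Channel Re-calibration Module (CRM): squeeze-and-excitation style channel
# attention (global average pool -> three-layer FC bottleneck -> sigmoid gate).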
119 | class CRM(nn.Module):
120 | def __init__(self, channel, reduction = 8):
121 | super(CRM, self).__init__()
122 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
123 | self.fc = nn.Sequential(
124 | nn.Linear(channel, channel // reduction, bias = False),
125 | nn.ReLU(inplace = True),
126 | nn.Linear(channel // reduction, channel // reduction, bias = False),
127 | nn.ReLU(inplace = True),
128 | nn.Linear(channel // reduction, channel, bias = False),
129 | nn.Sigmoid()
130 | )
131 |
132 | def forward(self, x):
133 | b, c, _, _ = x.size()
134 | y = self.avg_pool(x).view(b, c)
135 | y = self.fc(y).view(b, c, 1, 1)
136 | return x * y.expand_as(x)
137 |
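# DRCR block: three densely connected convolutions, a CRM-recalibrated copy of
# the block input (se1), cross-layer fusions (conv4-conv6), and a second CRM
# (se2) on the residual branch added to the output.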
138 | class DRCR_Block(nn.Module):
139 | def __init__(self, in_channels, latent_channels, kernel_size = 3, stride = 1, padding = 1, dilation = 1, pad_type = 'zero', activation = 'lrelu', norm = 'none', sn = False):
140 | super(DRCR_Block, self).__init__()
141 | # dense convolutions
142 | self.conv1 = Conv2dLayer(in_channels, latent_channels, kernel_size, stride, padding, dilation, pad_type,
143 | activation, norm, sn)
144 | self.conv2 = Conv2dLayer(in_channels, latent_channels, kernel_size, stride, padding, dilation, pad_type,
145 | activation, norm, sn)
146 | self.conv3 = Conv2dLayer(in_channels, latent_channels, kernel_size, stride, padding, dilation, pad_type,
147 | activation, norm, sn)
148 | self.conv4 = Conv2dLayer(in_channels * 2, in_channels, kernel_size, stride, padding, dilation, pad_type,
149 | activation, norm, sn)
150 | self.conv5 = Conv2dLayer(in_channels * 2, in_channels, kernel_size, stride, padding, dilation, pad_type,
151 | activation, norm, sn)
152 | self.conv6 = Conv2dLayer(in_channels * 2, in_channels, kernel_size, stride, padding, dilation, pad_type,
153 | activation, norm, sn)
154 | # self.cspn2_guide = GMLayer(in_channels)
155 | # self.cspn2 = Affinity_Propagate_Channel()
156 | self.se1 = CRM(in_channels)
157 | self.se2 = CRM(in_channels)
158 |
159 | def forward(self, x):
160 | x1 = self.conv1(x)
161 | x2 = self.conv2(x1)
162 | x3 = self.conv3(x2)
163 | # guidance2 = self.cspn2_guide(x3)
164 | # x3_2 = self.cspn2(guidance2, x3)
165 | x3_2 = self.se1(x)
166 | x4 = self.conv4(torch.cat((x3, x3_2), 1))
167 | x5 = self.conv5(torch.cat((x2, x4), 1))
168 | x6 = self.conv6(torch.cat((x1, x5), 1))+self.se2(x3_2)
169 | return x6
170 |
171 | class DRCR(nn.Module):
172 | def __init__(self, inplanes=3, planes=31, channels=200, n_DRBs=8):
173 | super(DRCR, self).__init__()
174 | self.input_conv2D = Conv3x3(inplanes, channels, 3, 1)
175 | self.input_prelu2D = nn.PReLU()
176 | self.head_conv2D = Conv3x3(channels, channels, 3, 1)
177 | self.denosing = NPM(channels)
178 | self.backbone = nn.ModuleList(
179 | [DRCR_Block(channels, channels) for _ in range(n_DRBs)])
180 | self.tail_conv2D = Conv3x3(channels, channels, 3, 1)
181 | self.output_prelu2D = nn.PReLU()
182 | self.output_conv2D = Conv3x3(channels, planes, 3, 1)
183 |
184 | def forward(self, x):
185 | out = self.DRN2D(x)
186 | return out
187 |
188 | def DRN2D(self, x):
189 | out = self.input_prelu2D(self.input_conv2D(x))
190 | out = self.head_conv2D(out)
191 | out = self.denosing(out)
192 |
193 | for i, block in enumerate(self.backbone):
194 | out = block(out)
195 |
196 | out = self.tail_conv2D(out)
197 | out = self.output_conv2D(self.output_prelu2D(out))
198 | return out
199 |
200 |
201 |
202 |
203 |
204 |
205 | if __name__ == "__main__":
206 | # import os
207 | # os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
208 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
209 | input_tensor = torch.rand(1, 3, 128, 128)
210 | model = DRCR(3, 31, 100, 10)
211 | # model = nn.DataParallel(model).cuda()
212 | with torch.no_grad():
213 | output_tensor = model(input_tensor)
214 | print(output_tensor.size())
215 | print('Parameters number is ', sum(param.numel() for param in model.parameters()))
216 | print(torch.__version__)
217 |
218 |
219 |
220 |
221 |
222 |
--------------------------------------------------------------------------------
/test/__pycache__/AWAN.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jojolee6513/DRCR-net/8abdff676a3fe1a28c6d60be50ca34331c39e164/test/__pycache__/AWAN.cpython-38.pyc
--------------------------------------------------------------------------------
/test/__pycache__/DRCR.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jojolee6513/DRCR-net/8abdff676a3fe1a28c6d60be50ca34331c39e164/test/__pycache__/DRCR.cpython-38.pyc
--------------------------------------------------------------------------------
/test/__pycache__/DRCR_Net.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jojolee6513/DRCR-net/8abdff676a3fe1a28c6d60be50ca34331c39e164/test/__pycache__/DRCR_Net.cpython-38.pyc
--------------------------------------------------------------------------------
/test/compute_mrae.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import hdf5storage as hdf5
4 | import numpy as np
5 | import argparse
6 |
7 |
8 | def compute_MRAE(gt, rec):
9 |     # Mean Relative Absolute Error:
10 |     # mean(|rec - gt| / gt) over all pixels and bands
11 |     error = np.abs(rec - gt) / gt
12 |     mrae = np.mean(error)
13 |     return mrae
14 |
15 | def compute_RMSE(gt, rec):
16 |     error = np.power(gt - rec, 2)  # RMSE: sqrt(mean((gt - rec)^2))
17 |     rmse = np.sqrt(np.mean(error))
18 |     return rmse
19 |
20 | def main():
21 | path_rec = opt.path_rec
22 | path_gt = opt.path_gt
23 |
24 | name_rec_list = glob.glob(os.path.join(path_rec, '*.mat'))
25 | name_gt_list = glob.glob(os.path.join(path_gt, '*.mat'))
26 | name_rec_list.sort()
27 | name_gt_list.sort()
28 |
29 | mrae_all = []
30 | rmse_all = []
31 |
32 | for i in range(len(name_gt_list)):
33 | hyper_rec = hdf5.loadmat(name_rec_list[i])['cube']
34 | hyper_gt = hdf5.loadmat(name_gt_list[i])['cube']
35 |         if hyper_gt.min() <= 0.:
36 |             print(os.path.basename(name_gt_list[i]), end=' ')
37 |             print('This file contains non-positive values and is not suitable for computing the MRAE metric.')
38 |             continue
39 |         hyper_rec = np.clip(hyper_rec, 0, 1)
40 | mrae = compute_MRAE(hyper_gt, hyper_rec)
41 | rmse = compute_RMSE(hyper_gt, hyper_rec)
42 | print(os.path.basename(name_gt_list[i]), end=' ')
43 | print('mrae: '+str(mrae)+', rmse: '+str(rmse))
44 | mrae_all.append(mrae)
45 | rmse_all.append(rmse)
46 | print('The average mrae is: '+str(sum(mrae_all)/len(mrae_all)))
47 | print('The average rmse is: '+str(sum(rmse_all)/len(rmse_all)))
48 |
49 | if __name__ == '__main__':
50 | parser = argparse.ArgumentParser(description="SSR_test")
51 | parser.add_argument("--path_rec", type=str, default='./test_results', help="The path of the reconstructed valid spectral data.")
52 | parser.add_argument("--path_gt", type=str, default='../Dataset/Valid_spectral', help="The path of the ground truth valid spectral data.")
53 | opt = parser.parse_args()
54 | main()
--------------------------------------------------------------------------------
/test/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | import cv2
5 | from DRCR import DRCR
6 | import glob
7 | import hdf5storage as hdf5
8 | import time
9 | import argparse
10 |
11 |
12 | def get_reconstruction_gpu(input, model):
13 |     """Run the model on two overlapping height crops (to fit limited GPU memory) and stitch the outputs."""
14 |     model.eval()
15 |     var_input = input.cuda()
16 |     with torch.no_grad():
17 |         start_time = time.time()
18 |         var_output1 = model(var_input[:, :, :-2, :])  # all rows except the last two
19 |         var_output2 = model(var_input[:, :, 2:, :])   # all rows except the first two
20 |         var_output = torch.cat([var_output1, var_output2[:, :, -2:, :]], 2)  # re-attach the last two rows
21 |         end_time = time.time()
22 | 
23 |     return end_time - start_time, var_output.cpu()
24 |
25 | os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
26 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
27 |
28 | parser = argparse.ArgumentParser(description="SSR_test")
29 | parser.add_argument("--RGB_dir", type=str, default='../Dataset/Valid_RGB', help="absolute Input_RGB_path")
30 | parser.add_argument("--model_dir", type=str, default='./model/model.pth', help="absolute Model_path")
31 | parser.add_argument("--result_dir", type=str, default='./test_results', help="absolute Save_Result_path")
32 | opt = parser.parse_args()
33 |
34 | img_path = opt.RGB_dir
35 | model_path = opt.model_dir
36 | result_path = opt.result_dir
37 |
38 | var_name = 'cube'
39 | # save results
40 | if not os.path.exists(result_path):
41 | os.makedirs(result_path)
42 | model = DRCR(3, 31, 100, 10)
43 | save_point = torch.load(model_path)
44 | model_param = save_point['state_dict']
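# Remap the checkpoint parameters onto this model by key order, so checkpoints
# whose keys carry a different prefix (e.g. 'module.' from nn.DataParallel)
# still load.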
45 | model_dict = {}
46 | for k1, k2 in zip(model.state_dict(), model_param):
47 | model_dict[k1] = model_param[k2]
48 | model.load_state_dict(model_dict)
49 | model = model.cuda()
50 |
51 | img_path_name = glob.glob(os.path.join(img_path, '*.jpg'))
52 | img_path_name.sort()
53 |
54 | for i in range(len(img_path_name)):
55 | # load rgb images
56 | print(img_path_name[i].split('/')[-1])
57 | rgb = cv2.imread(img_path_name[i])
58 | rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
59 | rgb = np.float32(rgb)
60 | rgb = rgb / rgb.max()
61 | rgb = np.expand_dims(np.transpose(rgb, [2, 0, 1]), axis=0).copy()
62 | rgb = torch.from_numpy(rgb).float()
63 | use_time, temp_hyper = get_reconstruction_gpu(rgb, model)
64 | img_res = temp_hyper.numpy() * 1.0
65 | img_res = np.transpose(np.squeeze(img_res), [1, 2, 0])
66 | img_res_limits = np.minimum(img_res, 1.0)
67 | img_res_limits = np.maximum(img_res_limits, 0)
68 |
69 | mat_name = img_path_name[i].split('/')[-1][:-4] + '.mat'
70 | mat_dir = os.path.join(result_path, mat_name)
71 |     hdf5.savemat(mat_dir, {var_name: img_res_limits}, format='7.3', store_python_metadata=True)  # save the [0, 1]-clipped cube
--------------------------------------------------------------------------------
/train/DRCR.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | import warnings
5 | warnings.filterwarnings("ignore")
6 |
7 |
8 | class Conv3x3(nn.Module):
9 | def __init__(self, in_dim, out_dim, kernel_size, stride, dilation=1):
10 | super(Conv3x3, self).__init__()
11 | reflect_padding = int(dilation * (kernel_size - 1) / 2)
12 | self.reflection_pad = nn.ReflectionPad2d(reflect_padding)
13 | self.conv2d = nn.Conv2d(in_dim, out_dim, kernel_size, stride, dilation=dilation, bias=False)
14 |
15 | def forward(self, x):
16 | out = self.reflection_pad(x)
17 | out = self.conv2d(out)
18 | return out
19 |
20 | class Conv2D(nn.Module):
21 | def __init__(self, in_channel=256, out_channel=8):
22 | super(Conv2D, self).__init__()
23 | self.guide_conv2D = nn.Conv2d(in_channel, out_channel, 3, 1, 1)
24 |
25 | def forward(self, x):
26 | spatial_guidance = self.guide_conv2D(x)
27 | return spatial_guidance
28 |
29 | class Conv2dLayer(nn.Module):
30 | def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation = 1, pad_type = 'zero', activation = 'lrelu', norm = 'none', sn = False):
31 | super(Conv2dLayer, self).__init__()
32 | # Initialize the padding scheme
33 | if pad_type == 'reflect':
34 | self.pad = nn.ReflectionPad2d(padding)
35 | elif pad_type == 'replicate':
36 | self.pad = nn.ReplicationPad2d(padding)
37 | elif pad_type == 'zero':
38 | self.pad = nn.ZeroPad2d(padding)
39 | else:
40 | assert 0, "Unsupported padding type: {}".format(pad_type)
41 |
42 | # Initialize the normalization type
43 | if norm == 'bn':
44 | self.norm = nn.BatchNorm2d(out_channels)
45 | elif norm == 'in':
46 | self.norm = nn.InstanceNorm2d(out_channels)
47 | elif norm == 'none':
48 | self.norm = None
49 | else:
50 | assert 0, "Unsupported normalization: {}".format(norm)
51 |
52 |         # Initialize the activation function
53 | if activation == 'relu':
54 | self.activation = nn.ReLU(inplace = True)
55 | elif activation == 'lrelu':
56 | self.activation = nn.LeakyReLU(0.2, inplace = True)
57 | elif activation == 'prelu':
58 | self.activation = nn.PReLU()
59 | elif activation == 'selu':
60 | self.activation = nn.SELU(inplace = True)
61 | elif activation == 'tanh':
62 | self.activation = nn.Tanh()
63 | elif activation == 'sigmoid':
64 | self.activation = nn.Sigmoid()
65 | elif activation == 'none':
66 | self.activation = None
67 | else:
68 | assert 0, "Unsupported activation: {}".format(activation)
69 |
70 | # Initialize the convolution layers
71 |         # Initialize the convolution layer (spectral normalization is not implemented here)
72 |         if sn:
73 |             raise NotImplementedError('Spectral normalization (sn=True) is not supported in this release.')
74 |         self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation)
75 | def forward(self, x):
76 | x = self.pad(x)
77 | x = self.conv2d(x)
78 | if self.norm:
79 | x = self.norm(x)
80 | if self.activation:
81 | x = self.activation(x)
82 | return x
83 |
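# Non-local Purification Module (NPM): filters the features at three scales
# (full, 1/2 and 1/4 resolution), fuses the three purified branches, and adds
# the result back to the input as a residual.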
84 | class NPM(nn.Module):
85 | def __init__(self, in_channel):
86 | super(NPM, self).__init__()
87 | self.in_channel = in_channel
88 | self.activation = nn.LeakyReLU(0.2, inplace = True)
89 | self.conv0_33 = nn.Conv2d(in_channel, in_channel, 3, 1, 1)
90 | self.conv0_11 = nn.Conv2d(in_channel, in_channel, 1, 1, 0)
91 | self.conv_0_cat = nn.Conv2d(in_channel*2, in_channel, 3, 1, 1)
92 | self.conv2_33 = nn.Conv2d(in_channel, in_channel, 3, 1, 1)
93 | self.conv2_11 = nn.Conv2d(in_channel, in_channel, 1, 1, 0)
94 | self.conv_2_cat = nn.Conv2d(in_channel*2, in_channel, 3, 1, 1)
95 | self.conv4_33 = nn.Conv2d(in_channel, in_channel, 3, 1, 1)
96 | self.conv4_11 = nn.Conv2d(in_channel, in_channel, 1, 1, 0)
97 | self.conv_4_cat = nn.Conv2d(in_channel*2, in_channel, 3, 1, 1)
98 |
99 | self.conv_cat = nn.Conv2d(in_channel*3, in_channel, 3, 1, 1)
100 |
101 | def forward(self, x):
102 |
103 | x_0 = x
104 | x_2 = F.avg_pool2d(x, 2, 2)
105 | x_4 = F.avg_pool2d(x_2, 2, 2)
106 |
107 | x_0 = torch.cat([self.conv0_33(x_0), self.conv0_11(x_0)], 1)
108 | x_0 = self.activation(self.conv_0_cat(x_0))
109 |
110 | x_2 = torch.cat([self.conv2_33(x_2), self.conv2_11(x_2)], 1)
111 | x_2 = F.interpolate(self.activation(self.conv_2_cat(x_2)), scale_factor=2, mode='bilinear')
112 |
113 |         x_4 = torch.cat([self.conv4_33(x_4), self.conv4_11(x_4)], 1)  # 1/4-scale branch uses its own conv4_* layers
114 | x_4 = F.interpolate(self.activation(self.conv_4_cat(x_4)), scale_factor=4, mode='bilinear')
115 |
116 | x = x + self.activation(self.conv_cat(torch.cat([x_0, x_2, x_4], 1)))
117 | return x
118 |
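# Channel Re-calibration Module (CRM): squeeze-and-excitation style channel
# attention (global average pool -> three-layer FC bottleneck -> sigmoid gate).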
119 | class CRM(nn.Module):
120 | def __init__(self, channel, reduction = 8):
121 | super(CRM, self).__init__()
122 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
123 | self.fc = nn.Sequential(
124 | nn.Linear(channel, channel // reduction, bias = False),
125 | nn.ReLU(inplace = True),
126 | nn.Linear(channel // reduction, channel // reduction, bias = False),
127 | nn.ReLU(inplace = True),
128 | nn.Linear(channel // reduction, channel, bias = False),
129 | nn.Sigmoid()
130 | )
131 |
132 | def forward(self, x):
133 | b, c, _, _ = x.size()
134 | y = self.avg_pool(x).view(b, c)
135 | y = self.fc(y).view(b, c, 1, 1)
136 | return x * y.expand_as(x)
137 |
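# DRCR block: three densely connected convolutions, a CRM-recalibrated copy of
# the block input (se1), cross-layer fusions (conv4-conv6), and a second CRM
# (se2) on the residual branch added to the output.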
138 | class DRCR_Block(nn.Module):
139 | def __init__(self, in_channels, latent_channels, kernel_size = 3, stride = 1, padding = 1, dilation = 1, pad_type = 'zero', activation = 'lrelu', norm = 'none', sn = False):
140 | super(DRCR_Block, self).__init__()
141 | # dense convolutions
142 | self.conv1 = Conv2dLayer(in_channels, latent_channels, kernel_size, stride, padding, dilation, pad_type,
143 | activation, norm, sn)
144 | self.conv2 = Conv2dLayer(in_channels, latent_channels, kernel_size, stride, padding, dilation, pad_type,
145 | activation, norm, sn)
146 | self.conv3 = Conv2dLayer(in_channels, latent_channels, kernel_size, stride, padding, dilation, pad_type,
147 | activation, norm, sn)
148 | self.conv4 = Conv2dLayer(in_channels * 2, in_channels, kernel_size, stride, padding, dilation, pad_type,
149 | activation, norm, sn)
150 | self.conv5 = Conv2dLayer(in_channels * 2, in_channels, kernel_size, stride, padding, dilation, pad_type,
151 | activation, norm, sn)
152 | self.conv6 = Conv2dLayer(in_channels * 2, in_channels, kernel_size, stride, padding, dilation, pad_type,
153 | activation, norm, sn)
154 | # self.cspn2_guide = GMLayer(in_channels)
155 | # self.cspn2 = Affinity_Propagate_Channel()
156 | self.se1 = CRM(in_channels)
157 | self.se2 = CRM(in_channels)
158 |
159 | def forward(self, x):
160 | x1 = self.conv1(x)
161 | x2 = self.conv2(x1)
162 | x3 = self.conv3(x2)
163 | # guidance2 = self.cspn2_guide(x3)
164 | # x3_2 = self.cspn2(guidance2, x3)
165 | x3_2 = self.se1(x)
166 | x4 = self.conv4(torch.cat((x3, x3_2), 1))
167 | x5 = self.conv5(torch.cat((x2, x4), 1))
168 | x6 = self.conv6(torch.cat((x1, x5), 1))+self.se2(x3_2)
169 | return x6
170 |
171 | class DRCR(nn.Module):
172 | def __init__(self, inplanes=3, planes=31, channels=200, n_DRBs=8):
173 | super(DRCR, self).__init__()
174 | self.input_conv2D = Conv3x3(inplanes, channels, 3, 1)
175 | self.input_prelu2D = nn.PReLU()
176 | self.head_conv2D = Conv3x3(channels, channels, 3, 1)
177 | self.denosing = NPM(channels)
178 | self.backbone = nn.ModuleList(
179 | [DRCR_Block(channels, channels) for _ in range(n_DRBs)])
180 | self.tail_conv2D = Conv3x3(channels, channels, 3, 1)
181 | self.output_prelu2D = nn.PReLU()
182 | self.output_conv2D = Conv3x3(channels, planes, 3, 1)
183 |
184 | def forward(self, x):
185 | out = self.DRN2D(x)
186 | return out
187 |
188 | def DRN2D(self, x):
189 | out = self.input_prelu2D(self.input_conv2D(x))
190 | out = self.head_conv2D(out)
191 | out = self.denosing(out)
192 |
193 | for i, block in enumerate(self.backbone):
194 | out = block(out)
195 |
196 | out = self.tail_conv2D(out)
197 | out = self.output_conv2D(self.output_prelu2D(out))
198 | return out
199 |
200 |
201 |
202 |
203 |
204 |
205 | if __name__ == "__main__":
206 | # import os
207 | # os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
208 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
209 | input_tensor = torch.rand(1, 3, 128, 128)
210 | model = DRCR(3, 31, 100, 10)
211 | # model = nn.DataParallel(model).cuda()
212 | with torch.no_grad():
213 | output_tensor = model(input_tensor)
214 | print(output_tensor.size())
215 | print('Parameters number is ', sum(param.numel() for param in model.parameters()))
216 | print(torch.__version__)
217 |
218 |
219 |
220 |
221 |
222 |
--------------------------------------------------------------------------------
/train/__pycache__/DRCR.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jojolee6513/DRCR-net/8abdff676a3fe1a28c6d60be50ca34331c39e164/train/__pycache__/DRCR.cpython-38.pyc
--------------------------------------------------------------------------------
/train/__pycache__/dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jojolee6513/DRCR-net/8abdff676a3fe1a28c6d60be50ca34331c39e164/train/__pycache__/dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/train/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jojolee6513/DRCR-net/8abdff676a3fe1a28c6d60be50ca34331c39e164/train/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/train/dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | import h5py
3 | import numpy as np
4 | import torch
5 | import torch.utils.data as udata
6 | import glob
7 | import os
8 |
9 | class HyperDatasetValid(udata.Dataset):
10 | def __init__(self, mode='valid'):
11 | if mode != 'valid':
12 | raise Exception("Invalid mode!", mode)
13 | data_path = './dataset/Valid'
14 | data_names = glob.glob(os.path.join(data_path, '*.mat'))
15 | self.keys = data_names
16 | self.keys.sort()
17 |
18 | def __len__(self):
19 | return len(self.keys)
20 |
21 |     def __getitem__(self, index):
22 |         mat = h5py.File(self.keys[index], 'r')
23 |         hyper = np.float32(np.array(mat['rad']))
24 |         hyper = np.transpose(hyper, [2, 1, 0])
25 |         hyper = torch.Tensor(hyper)[:, :-2, :]  # drop two rows so the height is divisible by 4 (required by the NPM poolings)
26 |         rgb = np.float32(np.array(mat['rgb']))
27 |         rgb = np.transpose(rgb, [2, 1, 0])
28 |         rgb = torch.Tensor(rgb)[:, :-2, :]
29 |         mat.close()
30 |         return rgb, hyper
31 |
32 |
33 | class HyperDatasetTrain(udata.Dataset):
34 | def __init__(self, mode='train'):
35 | if mode != 'train':
36 | raise Exception("Invalid mode!", mode)
37 | data_path = './dataset/Train'
38 | data_names1 = glob.glob(os.path.join(data_path, '*.mat'))
39 |
40 | self.keys = data_names1
41 | random.shuffle(self.keys)
42 | # self.keys.sort()
43 |
44 | def __len__(self):
45 | return len(self.keys)
46 |
47 | def __getitem__(self, index):
48 | mat = h5py.File(self.keys[index], 'r')
49 | hyper = np.float32(np.array(mat['rad']))
50 | hyper = np.transpose(hyper, [2, 1, 0])
51 | hyper = torch.Tensor(hyper)
52 | rgb = np.float32(np.array(mat['rgb']))
53 | rgb = np.transpose(rgb, [2, 1, 0])
54 | rgb = torch.Tensor(rgb)
55 | mat.close()
56 | return rgb, hyper
57 |
58 |
59 |
60 |
61 |
--------------------------------------------------------------------------------
/train/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import argparse
4 | import torch.optim as optim
5 | import torch.backends.cudnn as cudnn
6 | from torch.utils.data import DataLoader
7 | from torch.autograd import Variable
8 | import os
9 | import time
10 | import random
11 | from dataset import HyperDatasetValid, HyperDatasetTrain
12 | from DRCR import DRCR
13 | from utils import AverageMeter, initialize_logger, save_checkpoint, record_loss, Loss_train, Loss_valid
14 | from tqdm import tqdm
15 | import warnings
16 | warnings.filterwarnings('ignore')
17 |
18 | os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
19 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
20 |
21 | parser = argparse.ArgumentParser(description="SSR")
22 | parser.add_argument("--batchSize", type=int, default=16, help="batch size")
23 | parser.add_argument("--end_epoch", type=int, default=130+1, help="number of epochs")
24 | parser.add_argument("--init_lr", type=float, default=1e-4, help="initial learning rate")
25 | parser.add_argument("--decay_power", type=float, default=1.5, help="decay power")
26 | parser.add_argument("--max_iter", type=float, default=300000, help="max_iter") # Needs to be adjusted with the adjustment of batchSize, number of train samples
27 | parser.add_argument("--outf", type=str, default="RealWorldResults", help='path log files')
28 | opt = parser.parse_args()
29 |
30 |
31 | def main():
32 | cudnn.benchmark = True
33 | # Load dataset
34 | print("\nloading dataset ...")
35 | train_data = HyperDatasetTrain(mode='train')
36 | print("Train set samples: ", len(train_data))
37 | val_data = HyperDatasetValid(mode='valid')
38 | print("Validation set samples: ", len(val_data))
39 | # Data Loader (Input Pipeline)
40 | train_loader1 = DataLoader(dataset=train_data, batch_size=opt.batchSize, shuffle=True, num_workers=10, pin_memory=False, drop_last=True)
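    # train() expects a list of loaders and reshuffles that list every epoch, so wrap the single loader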
41 | train_loader = [train_loader1]
42 | val_loader = DataLoader(dataset=val_data, batch_size=1, shuffle=False, num_workers=2, pin_memory=False)
43 |
44 | # Model
45 | print("\nbuilding models_baseline ...")
46 | model = DRCR(3, 31, 100, 10)
47 | print('Parameters number is ', sum(param.numel() for param in model.parameters()))
48 | criterion_train = Loss_train()
49 | criterion_valid = Loss_valid()
50 | if torch.cuda.device_count() > 1:
51 | model = nn.DataParallel(model) # batchsize integer times
52 | if torch.cuda.is_available():
53 | model.cuda()
54 | criterion_train.cuda()
55 | criterion_valid.cuda()
56 |
57 | # Parameters, Loss and Optimizer
58 | start_epoch = 0
59 | iteration = 0
60 | record_val_loss = 1000
61 | optimizer = optim.Adam(model.parameters(), lr=opt.init_lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
62 |
63 | # Record
64 | if not os.path.exists(opt.outf):
65 | os.makedirs(opt.outf)
66 | loss_csv = open(os.path.join(opt.outf, 'loss.csv'), 'a+')
67 | log_dir = os.path.join(opt.outf, 'train.log')
68 | logger = initialize_logger(log_dir)
69 |
70 | # Resume
71 | resume_file = ''
72 | if resume_file:
73 | if os.path.isfile(resume_file):
74 | print("=> loading checkpoint '{}'".format(resume_file))
75 | checkpoint = torch.load(resume_file)
76 | start_epoch = checkpoint['epoch']
77 | iteration = checkpoint['iter']
78 | model.load_state_dict(checkpoint['state_dict'])
79 | optimizer.load_state_dict(checkpoint['optimizer'])
80 |
81 | # Start epoch
82 | for epoch in range(start_epoch+1, opt.end_epoch):
83 | start_time = time.time()
84 | train_loss, iteration, lr = train(train_loader, model, criterion_train, optimizer, epoch, iteration, opt.init_lr, opt.decay_power)
85 | val_loss = validate(val_loader, model, criterion_valid)
86 | # Save model
87 | if torch.abs(val_loss - record_val_loss) < 0.0001 or val_loss < record_val_loss:
88 | save_checkpoint(opt.outf, epoch, iteration, model, optimizer)
89 | if val_loss < record_val_loss:
90 | record_val_loss = val_loss
91 | end_time = time.time()
92 | epoch_time = end_time - start_time
93 | print("Epoch [%02d], Iter[%06d], Time:%.9f, learning rate : %.9f, Train Loss: %.9f Test Loss: %.9f "
94 | % (epoch, iteration, epoch_time, lr, train_loss, val_loss))
95 | # save loss
96 | record_loss(loss_csv,epoch, iteration, epoch_time, lr, train_loss, val_loss)
97 | logger.info("Epoch [%02d], Iter[%06d], Time:%.9f, learning rate : %.9f, Train Loss: %.9f Test Loss: %.9f "
98 | % (epoch, iteration, epoch_time, lr, train_loss, val_loss))
99 |
100 |
101 | # Training
102 | def train(train_loader, model, criterion, optimizer, epoch, iteration, init_lr, decay_power):
103 | model.train()
104 | random.shuffle(train_loader)
105 | losses = AverageMeter()
106 |     for k, train_data_loader in enumerate(train_loader):
107 | for i, (images, labels) in tqdm(enumerate(train_data_loader)):
108 | labels = labels.cuda()
109 | images = images.cuda()
110 | images = Variable(images)
111 | labels = Variable(labels)
112 | # Decaying Learning Rate
113 | lr = poly_lr_scheduler(optimizer, init_lr, iteration, max_iter=opt.max_iter, power=decay_power)
114 | iteration = iteration + 1
115 | # Forward + Backward + Optimize
116 | output = model(images)
117 | loss = criterion(output, labels)
118 | loss_all = loss
119 | optimizer.zero_grad()
120 | loss_all.backward()
121 | optimizer.step()
122 | losses.update(loss.data)
123 | # print('[Epoch:%02d],[Process:%d/%d],[iter:%d],lr=%.9f,train_losses.avg=%.9f'
124 | # % (epoch, k+1, len(train_loader), iteration, lr, losses.avg))
125 |
126 | return losses.avg, iteration, lr
127 |
128 |
129 | # Validate
130 | def validate(val_loader, model, criterion):
131 | model.eval()
132 | losses = AverageMeter()
133 | for i, (input, target) in enumerate(val_loader):
134 | input = input.cuda()
135 | target = target.cuda()
136 | with torch.no_grad():
137 | # compute output
138 | output = model(input)
139 | loss = criterion(output, target)
140 | # record loss
141 | losses.update(loss.data)
142 |
143 | return losses.avg
144 |
145 |
146 | # Learning rate
147 | def poly_lr_scheduler(optimizer, init_lr, iteration, lr_decay_iter=1, max_iter=100, power=0.9):
148 |     """Polynomial decay of the learning rate.
149 |     :param init_lr: base learning rate
150 |     :param iteration: current iteration
151 |     :param lr_decay_iter: how frequently decay occurs, default is 1
152 |     :param max_iter: maximum number of iterations
153 |     :param power: polynomial power
154 |     """
155 |     if iteration % lr_decay_iter or iteration > max_iter:
156 |         # no update on this call; report the current learning rate
157 |         return optimizer.param_groups[0]['lr']
158 | 
159 |     lr = init_lr * (1 - iteration / max_iter) ** power
160 |     for param_group in optimizer.param_groups:
161 |         param_group['lr'] = lr
162 | 
163 |     return lr
164 |
165 |
166 | if __name__ == '__main__':
167 | main()
168 | print(torch.__version__)
169 |
--------------------------------------------------------------------------------
/train/train_data_preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import glob
4 | import numpy as np
5 | import argparse
6 | import hdf5storage as hdf5
7 | import tqdm
8 |
9 | parser = argparse.ArgumentParser(description="SpectralSR")
10 | parser.add_argument("--data_path", type=str, default='../Dataset', help="data path")
11 | parser.add_argument("--patch_size", type=int, default=128, help="data patch size")
12 | parser.add_argument("--stride", type=int, default=64, help="data patch stride")
13 | parser.add_argument("--train_data_path", type=str, default='./dataset/Train', help="preprocess_data_path")
14 | opt = parser.parse_args()
15 |
16 |
17 | def main():
18 | if not os.path.exists(opt.train_data_path):
19 | os.makedirs(opt.train_data_path)
20 |
21 | process_data(patch_size=opt.patch_size, stride=opt.stride, mode='train')
22 |
23 |
24 | def normalize(data, max_val, min_val):
25 | return (data-min_val)/(np.float32(max_val-min_val))
26 |
27 |
28 | def Im2Patch(img, win, stride=1):
29 |     """Extract every win x win patch (with the given stride) from a (C, H, W) array; returns (C, win, win, N)."""
30 |     k = 0
31 |     endc, endw, endh = img.shape
32 |     patch = img[:, 0:endw-win+1:stride, 0:endh-win+1:stride]
33 |     TotalPatNum = patch.shape[1] * patch.shape[2]  # number of patch locations
34 |     Y = np.zeros([endc, win*win, TotalPatNum], np.float32)
35 |     for i in range(win):  # for each offset (i, j) inside the window,
36 |         for j in range(win):  # gather that pixel from every patch location
37 |             patch = img[:, i:endw-win+i+1:stride, j:endh-win+j+1:stride]
38 |             Y[:, k, :] = np.array(patch[:]).reshape(endc, TotalPatNum)
39 |             k = k + 1
40 |     return Y.reshape([endc, win, win, TotalPatNum])
42 |
43 |
44 | def process_data(patch_size, stride, mode):
45 | if mode == 'train':
46 | print("\nprocess training set ...\n")
47 | patch_num = 1
48 | filenames_hyper = glob.glob(os.path.join(opt.data_path, 'Train_spectral', '*.mat'))
49 | filenames_rgb = glob.glob(os.path.join(opt.data_path, 'Train_RGB', '*.jpg'))
50 | filenames_hyper.sort()
51 | filenames_rgb.sort()
52 | print(len(filenames_rgb))
53 | # for k in range(1): # make small dataset
54 | for k in tqdm.tqdm(range(len(filenames_rgb))):
55 | print([filenames_rgb[k][-16:]])
56 |
57 | mat = hdf5.loadmat(filenames_hyper[k])
58 | hyper = np.float32(np.array(mat['cube']))
59 | hyper = np.transpose(hyper, [2, 0, 1])
60 | if hyper.min() <= 0:
61 | print('This file contains non-positive values and is not suitable for Training!')
62 | continue
63 | hyper = normalize(hyper, max_val=1., min_val=0.)
64 | # load rgb image
65 | rgb = cv2.imread(filenames_rgb[k])
66 | rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
67 | rgb = np.transpose(rgb, [2, 0, 1])
68 | rgb = normalize(np.float32(rgb), max_val=rgb.max(), min_val=0.)
69 | # creat patches
70 | patches_hyper = Im2Patch(hyper, win=patch_size, stride=stride)
71 | patches_rgb = Im2Patch(rgb, win=patch_size, stride=stride)
72 | for j in range(patches_rgb.shape[3]):
73 | print("generate training sample #%d" % patch_num)
74 | sub_hyper = patches_hyper[:, :, :, j]
75 | sub_rgb = patches_rgb[:, :, :, j]
76 | train_data_path = os.path.join(opt.train_data_path, 'train'+str(patch_num)+'.mat')
77 | print(train_data_path)
78 |                 hdf5.savemat(train_data_path, {'rad': sub_hyper, 'rgb': sub_rgb}, format='7.3')
80 | patch_num += 1
81 |
82 | print("\ntraining set: # samples %d\n" % (patch_num-1))
83 |
84 |
85 | if __name__ == '__main__':
86 | main()
87 |
88 |
--------------------------------------------------------------------------------
/train/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | import torch
4 | import torch.nn as nn
5 | import logging
6 | import numpy as np
7 | import os
8 | import hdf5storage
9 |
10 |
11 | class AverageMeter(object):
12 | def __init__(self):
13 | self.reset()
14 |
15 | def reset(self):
16 | self.val = 0
17 | self.avg = 0
18 | self.sum = 0
19 | self.count = 0
20 |
21 | def update(self, val, n=1):
22 | self.val = val
23 | self.sum += val * n
24 | self.count += n
25 | self.avg = self.sum / self.count
26 |
27 |
28 | def initialize_logger(file_dir):
29 | logger = logging.getLogger()
30 | fhandler = logging.FileHandler(filename=file_dir, mode='a')
31 | formatter = logging.Formatter('%(asctime)s - %(message)s',"%Y-%m-%d %H:%M:%S")
32 | fhandler.setFormatter(formatter)
33 | logger.addHandler(fhandler)
34 | logger.setLevel(logging.INFO)
35 | return logger
36 |
37 |
38 | def save_checkpoint(model_path, epoch, iteration, model, optimizer):
39 | state = {
40 | 'epoch': epoch,
41 | 'iter': iteration,
42 | 'state_dict': model.state_dict(),
43 | 'optimizer': optimizer.state_dict(),
44 | }
45 |
46 | torch.save(state, os.path.join(model_path, 'net_%depoch.pth' % epoch))
47 |
48 |
49 | def save_matv73(mat_name, var_name, var):
50 | hdf5storage.savemat(mat_name, {var_name: var}, format='7.3', store_python_metadata=True)
51 |
52 |
53 | def record_loss(loss_csv, epoch, iteration, epoch_time, lr, train_loss, test_loss):
54 |     """Append one epoch's results to the CSV log (the file handle stays open across epochs)."""
55 |     loss_csv.write('{},{},{},{},{},{}\n'.format(epoch, iteration, epoch_time, lr, train_loss, test_loss))
56 |     loss_csv.flush()
58 |
59 |
60 | class Loss_train(nn.Module):
61 | def __init__(self):
62 | super(Loss_train, self).__init__()
63 |
64 |     def forward(self, outputs, label):
65 |         # MRAE: mean(|outputs - label| / label)
66 |         error = torch.abs(outputs - label) / label
67 |         mrae = torch.mean(error.view(-1))
68 |         return mrae
69 |
70 |
71 | class Loss_valid(nn.Module):
72 | def __init__(self):
73 | super(Loss_valid, self).__init__()
74 |
75 |     def forward(self, outputs, label):
76 |         # MRAE: mean(|outputs - label| / label)
77 |         error = torch.abs(outputs - label) / label
78 |         mrae = torch.mean(error.view(-1))
79 |         return mrae
80 |
81 |
82 |
83 |
--------------------------------------------------------------------------------
/train/valid_data_preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import cv2
4 | import glob
5 | import numpy as np
6 | import argparse
7 | import hdf5storage
8 |
9 | parser = argparse.ArgumentParser(description="SpectralSR")
10 | parser.add_argument("--data_path", type=str, default='../Dataset', help="data path")
11 | parser.add_argument("--valid_data_path", type=str, default='./dataset/Valid', help="preprocess_data_path")
12 |
13 |
14 | opt = parser.parse_args()
15 |
16 |
17 | def main():
18 | if not os.path.exists(opt.valid_data_path):
19 | os.makedirs(opt.valid_data_path)
20 |
21 | process_data(mode='valid')
22 |
23 |
24 | def normalize(data, max_val, min_val):
25 | return (data-min_val)/(max_val-min_val)
26 |
27 |
28 | def Im2Patch(img, win, stride=1):
29 | k = 0
30 | endc = img.shape[0]
31 | endw = img.shape[1]
32 | endh = img.shape[2]
33 | patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]
34 | TotalPatNum = patch.shape[1] * patch.shape[2]
35 | Y = np.zeros([endc, win*win, TotalPatNum], np.float32)
36 | for i in range(win):
37 | for j in range(win):
38 | patch = img[:, i:endw-win+i+1:stride, j:endh-win+j+1:stride]
39 | Y[:, k, :] = np.array(patch[:]).reshape(endc, TotalPatNum)
40 | k = k + 1
41 | return Y.reshape([endc, win, win, TotalPatNum])
42 |
43 |
44 | def process_data(mode):
45 | if mode == 'valid':
46 | print("\nprocess valid set ...\n")
47 | patch_num = 1
48 | filenames_hyper = glob.glob(os.path.join(opt.data_path, 'Valid_spectral', '*.mat'))
49 | filenames_rgb = glob.glob(os.path.join(opt.data_path, 'Valid_RGB', '*.jpg'))
50 | filenames_hyper.sort()
51 | filenames_rgb.sort()
52 | # for k in range(1): # make small dataset
53 | for k in range(len(filenames_rgb)):
54 | print([filenames_rgb[k]])
55 | # load hyperspectral image
56 | mat = hdf5storage.loadmat(filenames_hyper[k])
57 | hyper = np.float32(np.array(mat['cube']))
58 | hyper = np.transpose(hyper, [2, 0, 1])
59 | if hyper.min() <= 0:
60 | print('This file contains non-positive values and is not suitable for Testing!')
61 | continue
62 | hyper = normalize(hyper, max_val=1., min_val=0.)
63 | # load rgb image
64 | rgb = cv2.imread(filenames_rgb[k])
65 | rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
66 | rgb = np.transpose(rgb, [2, 0, 1])
67 | rgb = normalize(np.float32(rgb), max_val=rgb.max(), min_val=0.)
68 | valid_data_path = os.path.join(opt.valid_data_path, 'valid' + str(patch_num) + '.mat')
69 | print(valid_data_path)
70 |             hdf5storage.savemat(valid_data_path, {'rad': hyper, 'rgb': rgb}, format='7.3')
72 | patch_num += 1
73 |
74 | print("\ntraining set: # samples %d\n" % (patch_num-1))
75 |
76 |
77 | if __name__ == '__main__':
78 | main()
79 |
80 |
--------------------------------------------------------------------------------