├── DataProcess.py
├── Dataset_Train
│   └── README.txt
├── MASK
│   └── README.txt
├── Measurements_Test
│   └── README.txt
├── Model_zoo
│   └── README.txt
├── README.md
├── SplitDataset.py
├── architecture
│   ├── SRNet.py
│   ├── __init__.py
│   └── __pycache__
│       ├── SRNet.cpython-37.pyc
│       └── __init__.cpython-37.pyc
├── getdataset.py
├── my_utils.py
├── test_HyperspecI_V1.py
├── test_HyperspecI_V2.py
├── train_HyperspecI_V1.py
└── train_HyperspecI_V2.py

--------------------------------------------------------------------------------
/DataProcess.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import random

class Data_Process(object):
    def __init__(self):
        self.noise_sigma = 0
        self.hsi_max = []

    def add_noise(self, inputs, sigma):
        # Add Gaussian noise with standard deviation sigma, then clamp to [0, 1]
        noise = torch.zeros_like(inputs)
        noise.normal_(0, sigma)
        noisy = inputs + noise
        noisy = torch.clamp(noisy, 0, 1.0)
        return noisy

    # Randomly extract the mask sub-patches required for training from the full calibrated mask
    def get_random_mask_patches(self, mask, image_size, patch_size, batch_size):
        masks = []
        for i in range(batch_size):
            random_h = random.randint(0, image_size[0] - patch_size[0] - 1)
            random_w = random.randint(0, image_size[1] - patch_size[1] - 1)
            mask_patch = mask[:, random_h:random_h + patch_size[0], random_w:random_w + patch_size[1]]
            mask_patch = mask_patch / mask_patch.max()
            masks.append(mask_patch)

        mask_patches = torch.stack(masks, dim=0)
        return mask_patches

    # Forward model of snapshot hyperspectral imaging: generates the synthesized input measurements from hyperspectral targets
    def get_mos_hsi(self, hsi, mask, sigma=0, mos_size=2048, hsi_input_size=512, hsi_target_size=512, init_div_rat=10):
        if not hsi_input_size == hsi_target_size:
            hsi_out = self.extend_spatial_resolution(hsi, extend_rate=hsi_target_size / hsi_input_size)
        else:
            hsi_out = hsi

        if not mos_size == hsi_input_size:
            hsi_expand = self.extend_spatial_resolution(hsi, extend_rate=mos_size / hsi_input_size)
        else:
            hsi_expand = hsi

        # Measurement: element-wise mask modulation, then summation over the spectral dimension
        mos = torch.sum(hsi_expand * mask, dim=1).unsqueeze(1)
        mos_max = torch.max(mos.view(mos.shape[0], -1), 1)[0].unsqueeze(1).unsqueeze(1).unsqueeze(1)

        # Normalize the input and target data using the adaptive variable
        output_hsi = hsi_out / mos_max * init_div_rat
        input_mos = mos / mos_max

        # sigma may be a single value or a tuple of noise levels to sample from
        if isinstance(sigma, tuple):
            select_noise_sigma = sigma[random.randint(0, len(sigma) - 1)]
        else:
            select_noise_sigma = sigma

        input_mos = self.add_noise(input_mos, select_noise_sigma)

        return input_mos, output_hsi

    def extend_spatial_resolution(self, hsi, extend_rate):
        hsi_extend = torch.nn.functional.interpolate(hsi, recompute_scale_factor=True, scale_factor=extend_rate)
        return hsi_extend

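# ----------------------------------------------------------------------
# Illustrative usage (added sketch, not part of the original file).
# The forward model turns a hyperspectral cube into a single-channel
# measurement by mask modulation and spectral summation. All shapes and
# values below are assumptions chosen for the demo.
def _demo_forward_model():
    dp = Data_Process()
    hsi = torch.rand(2, 61, 512, 512)    # [batch, bands, H, W] hyperspectral targets
    mask = torch.rand(2, 61, 512, 512)   # calibrated sensing-matrix patches
    mos, target = dp.get_mos_hsi(hsi, mask, sigma=1 / 255, mos_size=512,
                                 hsi_input_size=512, hsi_target_size=512)
    print(mos.shape, target.shape)       # -> [2, 1, 512, 512], [2, 61, 512, 512]
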
class Image_Cut(object):
    def __init__(self, image_size, patch_size, stride):
        self.patch_size = patch_size
        self.stride = stride
        self.image_size = image_size

        self.patch_number = []
        self.hsi_max = []

    def image2patch(self, image):
        '''
        image_size = C, H, W
        '''
        patch_size = self.patch_size
        stride = self.stride

        c, h, w = image.shape
        image = image.unsqueeze(0)
        range_h = np.arange(0, h - patch_size[0], stride)
        range_w = np.arange(0, w - patch_size[1], stride)

        # Append the last start position so the patches cover the full image
        range_h = np.append(range_h, h - patch_size[0])
        range_w = np.append(range_w, w - patch_size[1])
        patches = []
        for m in range_h:
            for n in range_w:
                patches.append(image[:, :, m: m + patch_size[0], n: n + patch_size[1]])

        return torch.cat(patches, 0)

    def patch2image(self, patches):
        patch_size = self.patch_size
        stride = self.stride
        c = patches.shape[1]
        h, w = self.image_size

        res = torch.zeros((c, h, w)).to(patches.device)
        weight = torch.zeros((c, h, w)).to(patches.device)

        range_h = np.arange(0, h - patch_size[0], stride)
        range_w = np.arange(0, w - patch_size[1], stride)

        range_h = np.append(range_h, h - patch_size[0])
        range_w = np.append(range_w, w - patch_size[1])

        index = 0
        for m in range_h:
            for n in range_w:
                # Accumulate overlapping patches and count the contributions per pixel
                res[:, m: m + patch_size[0], n: n + patch_size[1]] = res[:, m: m + patch_size[0], n: n + patch_size[1]] + patches[index, ...]
                weight[:, m: m + patch_size[0], n: n + patch_size[1]] = weight[:, m: m + patch_size[0], n: n + patch_size[1]] + 1
                index = index + 1

        # Average the overlapped regions
        image = res / weight
        return image

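# ----------------------------------------------------------------------
# Illustrative round trip (added sketch, not part of the original file):
# cut an image into overlapping patches and stitch it back, averaging
# the overlaps. Sizes below are assumptions for the demo.
def _demo_image_cut():
    cutter = Image_Cut(image_size=(512, 512), patch_size=(256, 256), stride=128)
    image = torch.rand(61, 512, 512)
    patches = cutter.image2patch(image)      # [num_patches, 61, 256, 256]
    restored = cutter.patch2image(patches)   # [61, 512, 512]
    print(torch.allclose(restored, image))   # -> True
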
--------------------------------------------------------------------------------
/Dataset_Train/README.txt:
--------------------------------------------------------------------------------
The training dataset can be downloaded from: https://github.com/bianlab/Hyperspectral-imaging-dataset
--------------------------------------------------------------------------------
/MASK/README.txt:
--------------------------------------------------------------------------------
Download from Google Drive: https://drive.google.com/drive/folders/1x6nZpcTP9RIsENJL566pV9v83e1e4gpn?usp=sharing
--------------------------------------------------------------------------------
/Measurements_Test/README.txt:
--------------------------------------------------------------------------------
Download from Google Drive: https://drive.google.com/drive/folders/1x6nZpcTP9RIsENJL566pV9v83e1e4gpn?usp=sharing
--------------------------------------------------------------------------------
/Model_zoo/README.txt:
--------------------------------------------------------------------------------
Download from Google Drive: https://drive.google.com/drive/folders/1x6nZpcTP9RIsENJL566pV9v83e1e4gpn?usp=sharing
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# A broadband hyperspectral image sensor with high spatio-temporal resolution

[Liheng Bian*](https://scholar.google.com/citations?user=66IFMDEAAAAJ&hl=zh-CN&oi=sra), [Zhen Wang*](https://scholar.google.com/citations?hl=zh-CN&user=DexiDloAAAAJ), [Yuzhe Zhang*](https://scholar.google.com/citations?hl=zh-CN&user=rymYR-wAAAAJ), Lianjie Li, Yinuo Zhang, Chen Yang, Wen Fang, Jiajun Zhao, Chunli Zhu, Qinghao Meng, Xuan Peng, and Jun Zhang. (*Equal contributions)



## 1. System requirements

### 1.1 All software dependencies and operating systems

The project has been tested on Windows 10 and Ubuntu 20.04.1.

### 1.2 Versions the software has been tested on

The project has been tested with CUDA 11.4, PyTorch 1.11.0, torchvision 0.12.0, Python 3.7.13, and opencv-python 4.5.5.64.



## 2. Installation guide

### 2.1 Instructions

- The code for training and testing can be downloaded from the public repository: https://github.com/bianlab/HyperspecI
- The mask, testing measurements, and pre-trained weights can be downloaded from the Google Drive link: https://drive.google.com/drive/folders/1x6nZpcTP9RIsENJL566pV9v83e1e4gpn?usp=sharing
- Because the training dataset is very large, it has been packaged into multiple repositories for storage: https://github.com/bianlab/Hyperspectral-imaging-dataset

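The dependencies listed in Section 1.2 can be installed into a fresh environment, for example as follows (an assumed recipe, not taken from the original repository; the extra packages are those imported by the training and testing scripts):

```bash
conda create -n hyperspeci python=3.7.13
conda activate hyperspeci
pip install torch==1.11.0 torchvision==0.12.0 opencv-python==4.5.5.64
pip install einops h5py hdf5storage scipy matplotlib
```
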
## 3. Program description and testing

Download the masks to `./MASK/HyperspecI_V1.mat` and `./MASK/HyperspecI_V2.mat`;

Download the pre-trained weights to `./Model_zoo/SRNet_V1.pth` and `./Model_zoo/SRNet_V2.pth`;

Download the testing measurements to `./Measurements_Test/HyperspecI_V1/` and `./Measurements_Test/HyperspecI_V2/`;

Download the training datasets to `./Dataset_Train/HSI_400_1000/HSI_all/` and `./Dataset_Train/HSI_400_1700/HSI_all/`.

### 3.1 Main program and data description

- The hyperspectral image reconstruction model: `./architecture/SRNet.py`

- Pre-trained weights of SRNet for HyperspecI-V1: `./Model_zoo/SRNet_V1.pth`

- Pre-trained weights of SRNet for HyperspecI-V2: `./Model_zoo/SRNet_V2.pth`

- Calibrated sensing matrix of HyperspecI-V1: `./MASK/HyperspecI_V1.mat`

- Calibrated sensing matrix of HyperspecI-V2: `./MASK/HyperspecI_V2.mat`

- Measurements collected by our HyperspecI-V1: `./Measurements_Test/HyperspecI_V1/`

- Measurements collected by our HyperspecI-V2: `./Measurements_Test/HyperspecI_V2/`

- The training and test programs: `train_HyperspecI_V1.py`, `train_HyperspecI_V2.py`, `test_HyperspecI_V1.py`, `test_HyperspecI_V2.py`



### 3.2 Model training of SRNet

Run the training programs on the prepared dataset to train the spectral reconstruction models on the PyTorch platform.

● First, download the training dataset of HyperspecI-V1 (400-1000 nm) into `./Dataset_Train/HSI_400_1000/HSI_all/`, and the training dataset of HyperspecI-V2 (400-1700 nm) into `./Dataset_Train/HSI_400_1700/HSI_all/`.

● Second, run `SplitDataset.py` to partition the data into training and validation sets, with 90% allocated for training and 10% for validation.

The detailed command for the HyperspecI-V1 dataset partition:

```bash
python SplitDataset.py --data_folder './Dataset_Train/HSI_400_1000/HSI_all/' --train_folder './Dataset_Train/HSI_400_1000/Train/' --test_folder './Dataset_Train/HSI_400_1000/Valid/'
```

The detailed command for the HyperspecI-V2 dataset partition:

```bash
python SplitDataset.py --data_folder './Dataset_Train/HSI_400_1700/HSI_all/' --train_folder './Dataset_Train/HSI_400_1700/Train/' --test_folder './Dataset_Train/HSI_400_1700/Valid/'
```

● Third, run the training programs to train the spectral reconstruction models.

For training HyperspecI-V1, execute the following command in the terminal; the training results will be saved in the `./exp/HyperspecI_V1/` folder.

```bash
python train_HyperspecI_V1.py
```

For training HyperspecI-V2, execute the following command in the terminal; the training results will be saved in the `./exp/HyperspecI_V2/` folder.

```bash
python train_HyperspecI_V2.py
```



### 3.3 Test hyperspectral reconstruction results in real-world scenes

Run the test programs on the collected measurements to reconstruct hyperspectral images on the PyTorch platform.

(1) For measurements collected with our HyperspecI-V1 imaging sensor, the hyperspectral images can be reconstructed by running the following program in the terminal.

```bash
python test_HyperspecI_V1.py
```

The measurements are read from the folder `./Measurements_Test/HyperspecI_V1/`, and the reconstructed hyperspectral images are saved in `./Measurements_Test/Output_HyperspecI_V1/`.

(2) For measurements collected with our HyperspecI-V2 imaging sensor, the hyperspectral images can be reconstructed by running the following program in the terminal.

```bash
python test_HyperspecI_V2.py
```

The measurements are read from the folder `./Measurements_Test/HyperspecI_V2/`, and the reconstructed hyperspectral images are saved in `./Measurements_Test/Output_HyperspecI_V2/`.
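Each output file stores both the input measurement and the reconstructed cube. A minimal way to inspect a result (the filename below is illustrative; the test scripts name the outputs `HSI_R_<measurement>.h5`):

```python
import h5py

with h5py.File('./Measurements_Test/Output_HyperspecI_V1/HSI_R_example.h5', 'r') as f:
    mos = f['mos'][:]      # input measurement, shape [H, W]
    hsi = f['hsi_R'][:]    # reconstructed cube, shape [61, H, W] (96 bands for V2)
```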
--------------------------------------------------------------------------------
/SplitDataset.py:
--------------------------------------------------------------------------------
import numpy as np
import h5py
import matplotlib.pyplot as plt
import matplotlib.image as mp
import scipy.io as scio
import os
import cv2
import time
import random
import shutil
import argparse


parser = argparse.ArgumentParser(description="Partition the data into training and validation sets")

# Paths of input and output data
parser.add_argument("--data_folder", type=str, default='./Dataset_Train/HSI_400_1000/HSI_all/', help='Original data folder')
parser.add_argument("--train_folder", type=str, default='./Dataset_Train/HSI_400_1000/Train/', help='Training data folder')
parser.add_argument("--test_folder", type=str, default='./Dataset_Train/HSI_400_1000/Valid/', help='Validation data folder')

opt = parser.parse_args()


path_data = os.listdir(opt.data_folder)

random.shuffle(path_data)
train_ratio = 0.9  # the ratio of training data
data_nums = len(path_data)
train_nums = int(data_nums * train_ratio)
train_sample = random.sample(path_data, train_nums)
test_sample = list(set(path_data) - set(train_sample))
print(len(path_data))
print(len(train_sample))
print(len(test_sample))


# Move the selected files into the training data folder
for k in train_sample:
    shutil.move(os.path.join(opt.data_folder, k), os.path.join(opt.train_folder, k))


# Move the remaining files into the validation data folder
for k in test_sample:
    shutil.move(os.path.join(opt.data_folder, k), os.path.join(opt.test_folder, k))

--------------------------------------------------------------------------------
/architecture/SRNet.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch
import torch.nn.functional as F
from einops import rearrange
import math
import warnings
from torch.nn.init import _calculate_fan_in_and_fan_out
import numbers

def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    def norm_cdf(x):
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)
    with torch.no_grad():
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)
        tensor.uniform_(2 * l - 1, 2 * u - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        denom = fan_in
    elif mode == 'fan_out':
        denom = fan_out
    elif mode == 'fan_avg':
        denom = (fan_in + fan_out) / 2
    variance = scale / denom
    if distribution == "truncated_normal":
        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
    elif distribution == "normal":
        tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')



class GELU(nn.Module):
    def forward(self, x):
        return F.gelu(x)


class Spectral_Atten(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.num_heads = heads
        self.to_q = nn.Conv2d(dim, dim, kernel_size=1, bias=False)
        self.to_k = nn.Conv2d(dim, dim, kernel_size=1, bias=False)
        self.to_v = nn.Conv2d(dim, dim, kernel_size=1, bias=False)

        self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=False)
        self.k_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=False)
        self.v_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=False)
        self.rescale = nn.Parameter(torch.ones(heads, 1, 1))
        self.proj = nn.Conv2d(dim, dim, kernel_size=1, bias=False)

    def forward(self, x_in):
        """
        x_in: [b, c, h, w]
        return out: [b, c, h, w]
        """
        b, c, h, w = x_in.shape
        q_in = self.q_dwconv(self.to_q(x_in))
        k_in = self.k_dwconv(self.to_k(x_in))
        v_in = self.v_dwconv(self.to_v(x_in))

        # Attention is computed across the spectral (channel) dimension
        q = rearrange(q_in, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k_in, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v_in, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = F.normalize(q, dim=-1, p=2)
        k = F.normalize(k, dim=-1, p=2)
        atten = (q @ k.transpose(-2, -1)) * self.rescale
        atten = atten.softmax(dim=-1)
        out = (atten @ v)
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        out = self.proj(out)

        return out


class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape =
torch.Size(normalized_shape) 108 | 109 | assert len(normalized_shape) == 1 110 | 111 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 112 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 113 | self.normalized_shape = normalized_shape 114 | 115 | def forward(self, x): 116 | mu = x.mean(-1, keepdim=True) 117 | sigma = x.var(-1, keepdim=True, unbiased=False) 118 | return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias 119 | def to_3d(x): 120 | return rearrange(x, 'b c h w -> b (h w) c') 121 | 122 | def to_4d(x,h,w): 123 | return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 124 | 125 | class LayerNorm(nn.Module): 126 | def __init__(self, dim): 127 | super(LayerNorm, self).__init__() 128 | self.body = WithBias_LayerNorm(dim) 129 | 130 | def forward(self, x): 131 | h, w = x.shape[-2:] 132 | return to_4d(self.body(to_3d(x)), h, w) 133 | 134 | 135 | class PreNorm(nn.Module): 136 | def __init__(self, dim, mult=4): 137 | super().__init__() 138 | self.net1 = nn.Sequential( 139 | nn.Conv2d(dim, dim, 1, 1, bias=False), 140 | GELU(), 141 | nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim), 142 | ) 143 | self.net2 = nn.Sequential( 144 | nn.Conv2d(dim, dim, 1, 1, bias=False), 145 | GELU(), 146 | nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim), 147 | ) 148 | self.relu = nn.GELU() 149 | self.out_conv = nn.Conv2d(dim * 2, dim, 1, 1, bias=False) 150 | # self.norm = nn.LayerNorm(dim) 151 | 152 | def forward(self, x): 153 | out1 = self.net1(x) 154 | out2 = self.net2(x) 155 | out = torch.cat((out1, out2), dim=1) 156 | return self.out_conv(self.relu(out)) 157 | 158 | class SAM_Spectral(nn.Module): 159 | def __init__(self, dim, heads, num_blocks): 160 | super().__init__() 161 | 162 | self.blocks = nn.ModuleList([]) 163 | for _ in range(num_blocks): 164 | self.blocks.append(nn.ModuleList([ 165 | LayerNorm(dim), 166 | Spectral_Atten(dim=dim, heads=heads), 167 | LayerNorm(dim), 168 | PreNorm(dim, mult=4) 169 | ])) 170 | def forward(self, x): 171 | """ 172 | x: [b,c,h,w] 173 | return out: [b,c,h,w] 174 | """ 175 | for (norm1, atten, norm2, ffn) in self.blocks: 176 | x = atten(norm1(x)) + x 177 | x = ffn(norm2(x)) + x 178 | return x 179 | 180 | class SRNet(nn.Module): 181 | def __init__(self, in_channels=1, out_channels=61, dim=32, deep_stage=3, num_blocks=[1, 1, 1], num_heads=[1, 2, 4]): 182 | super(SRNet, self).__init__() 183 | self.dim = dim 184 | self.out_channels = out_channels 185 | self.stage = deep_stage 186 | 187 | self.embedding1 = nn.Conv2d(in_channels, dim, kernel_size=3, padding=1, bias=False) 188 | self.embedding2 = nn.Conv2d(out_channels, dim, kernel_size=3, padding=1, bias=False) 189 | self.embedding = nn.Conv2d(dim * 2, dim, kernel_size=3, padding=1, bias=False) 190 | 191 | self.down_sample = nn.Conv2d(dim, dim, 4, 2, 1, bias=False) 192 | self.up_sample = nn.ConvTranspose2d(dim, dim, stride=2, kernel_size=2, padding=0, output_padding=0) 193 | 194 | self.mapping = nn.Conv2d(dim, out_channels, kernel_size=3, padding=1, bias=False) 195 | 196 | 197 | self.encoder_layers = nn.ModuleList([]) 198 | dim_stage = dim 199 | for i in range(deep_stage): 200 | self.encoder_layers.append(nn.ModuleList([ 201 | SAM_Spectral(dim=dim_stage, heads=num_heads[i], num_blocks=num_blocks[i]), 202 | nn.Conv2d(dim_stage, dim_stage * 2, 4, 2, 1, bias=False), 203 | ])) 204 | dim_stage *= 2 205 | 206 | 207 | self.bottleneck = SAM_Spectral( 208 | dim=dim_stage, heads=num_heads[-1], num_blocks=num_blocks[-1]) 209 | 210 | self.decoder_layers = nn.ModuleList([]) 211 | for i in range(deep_stage): 
            self.decoder_layers.append(nn.ModuleList([
                nn.ConvTranspose2d(dim_stage, dim_stage // 2, stride=2, kernel_size=2, padding=0, output_padding=0),
                nn.Conv2d(dim_stage, dim_stage // 2, 1, 1, bias=False),
                SAM_Spectral(dim=dim_stage // 2, heads=num_heads[deep_stage - 1 - i], num_blocks=num_blocks[deep_stage - 1 - i]),
            ]))
            dim_stage //= 2

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, mask):
        """
        x: [b,c,h,w]
        return out: [b,c,h,w]
        """
        # Embed the measurement and the mask, then fuse them
        x = self.embedding1(x)
        mask = self.embedding2(mask)
        x = torch.cat((x, mask), dim=1)

        fea = self.embedding(x)
        residual = fea
        fea = self.down_sample(fea)

        fea_encoder = []
        for (Attention, FeaDownSample) in self.encoder_layers:
            fea = Attention(fea)
            fea_encoder.append(fea)
            fea = FeaDownSample(fea)

        fea = self.bottleneck(fea)

        # U-shaped decoder with skip connections from the encoder
        for i, (FeaUpSample, Fution, Attention) in enumerate(self.decoder_layers):
            fea = FeaUpSample(fea)
            fea = Fution(torch.cat([fea, fea_encoder[self.stage - 1 - i]], dim=1))
            fea = Attention(fea)

        fea = self.up_sample(fea)
        out = fea + residual
        out = self.mapping(out)

        return out

--------------------------------------------------------------------------------
/architecture/__init__.py:
--------------------------------------------------------------------------------
import torch
from .SRNet import SRNet

def model_generator(method, pretrained_model_path=None):

    if method == 'V1_srnet':
        model = SRNet(in_channels=1, out_channels=61, dim=32, deep_stage=3, num_blocks=[1, 1, 1, 1], num_heads=[1, 2, 4, 8]).cuda()

    elif method == 'V2_srnet':
        model = SRNet(in_channels=1, out_channels=96, dim=32, deep_stage=3, num_blocks=[1, 1, 1, 1], num_heads=[1, 2, 4, 8]).cuda()

    else:
        raise ValueError(f'Method {method} is not defined !!!!')
    if pretrained_model_path is not None:
        print(f'load model from {pretrained_model_path}')
        checkpoint = torch.load(pretrained_model_path)

        model.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()},
                              strict=True)
    return model

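# ----------------------------------------------------------------------
# Illustrative shape check (added sketch, not part of the original file).
# SRNet is fully convolutional; the spatial size only needs to be
# divisible by 16 (one stem down-sampling plus three encoder stages).
# The 128x128 size below is an arbitrary assumption, and the model is
# built directly on CPU because model_generator() calls .cuda().
def _demo_srnet_shapes():
    net = SRNet(in_channels=1, out_channels=61, dim=32, deep_stage=3,
                num_blocks=[1, 1, 1], num_heads=[1, 2, 4])
    mos = torch.rand(1, 1, 128, 128)     # synthesized measurement
    mask = torch.rand(1, 61, 128, 128)   # calibrated sensing matrix
    out = net(mos, mask)
    print(out.shape)                     # -> torch.Size([1, 61, 128, 128])
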
--------------------------------------------------------------------------------
/architecture/__pycache__/SRNet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bianlab/HyperspecI/6e289c209483dd29d0a24cf21be32baab3e24d30/architecture/__pycache__/SRNet.cpython-37.pyc
--------------------------------------------------------------------------------
/architecture/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bianlab/HyperspecI/6e289c209483dd29d0a24cf21be32baab3e24d30/architecture/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/getdataset.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset
import numpy as np
import torch.nn as nn
import torch
import random
import os
import hdf5storage
import matplotlib.pyplot as plt
import cv2
from scipy import interpolate
import torch.nn.functional as F
import h5py


class TrainDataset_V1(Dataset):
    def __init__(self, data_path, patch_size, arg=False):

        self.arg = arg
        self.data_path = data_path
        self.patch_size = patch_size

        data_list = os.listdir(data_path)
        data_list.sort()

        self.data_list = data_list
        self.img_num = len(self.data_list)

    def arguement(self, img, rotTimes, vFlip, hFlip):
        # Random rotation in the H-W plane
        for j in range(rotTimes):
            img = np.rot90(img.copy(), axes=(1, 2))
        # Random flip along the width (last) axis
        for j in range(vFlip):
            img = img[:, :, ::-1].copy()
        # Random flip along the height axis
        for j in range(hFlip):
            img = img[:, ::-1, :].copy()
        return img

    def __getitem__(self, idx):

        f = h5py.File(self.data_path + self.data_list[idx], 'r')
        hsi = f['hsi'][:]
        f.close()

        patch_size_h = self.patch_size[0]
        patch_size_w = self.patch_size[1]

        if self.arg:
            rotTimes = random.randint(0, 3)
            vFlip = random.randint(0, 1)
            hFlip = random.randint(0, 1)
            hsi = self.arguement(hsi, rotTimes, vFlip, hFlip)

        # Randomly crop a training patch and normalize it to [0, 1]
        random_h = random.randint(0, hsi.shape[1] - patch_size_h - 1)
        random_w = random.randint(0, hsi.shape[2] - patch_size_w - 1)
        output_hsi = hsi[:, random_h:random_h + patch_size_h, random_w:random_w + patch_size_w]
        output_hsi = output_hsi.astype(np.float32)
        output_hsi = output_hsi / output_hsi.max()

        return np.ascontiguousarray(output_hsi)

    def __len__(self):
        return self.img_num


class ValidDataset_V1(Dataset):
    def __init__(self, data_path, patch_size, arg=False):

        self.arg = arg
        self.data_paths = []
        self.patch_size = patch_size

        data_list = os.listdir(data_path)
        data_list.sort()
        for i in range(len(data_list)):
            self.data_paths.append(data_path + data_list[i])

        self.img_num = len(self.data_paths)

    def arguement(self, img, rotTimes, vFlip, hFlip):
        # Random rotation in the H-W plane
        for j in range(rotTimes):
            img = np.rot90(img.copy(), axes=(1, 2))
        # Random flip along the width (last) axis
        for j in range(vFlip):
            img = img[:, :, ::-1].copy()
        # Random flip along the height axis
        for j in range(hFlip):
            img = img[:, ::-1, :].copy()
        return img

    def __getitem__(self, idx):

        f = h5py.File(self.data_paths[idx], 'r')
        hsi = f['hsi'][:]
        f.close()

        patch_size_h = self.patch_size[0]
        patch_size_w = self.patch_size[1]

        if self.arg:
            rotTimes = random.randint(0, 3)
            vFlip = random.randint(0, 1)
            hFlip = random.randint(0, 1)
            hsi = self.arguement(hsi, rotTimes, vFlip, hFlip)

        random_h = random.randint(0, hsi.shape[1] - patch_size_h - 1)
        random_w = random.randint(0, hsi.shape[2] - patch_size_w - 1)
        output_hsi = hsi[:, random_h:random_h + patch_size_h, random_w:random_w + patch_size_w]
        output_hsi = output_hsi.astype(np.float32)
        output_hsi = output_hsi / output_hsi.max()

        return np.ascontiguousarray(output_hsi)

    def __len__(self):
        return self.img_num

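# ----------------------------------------------------------------------
# Illustrative usage (added sketch, not part of the original file): wrap
# the dataset in a DataLoader as the training scripts do. The path and
# patch size below are assumptions.
def _demo_dataloader():
    from torch.utils.data import DataLoader
    train_data = TrainDataset_V1(data_path='./Dataset_Train/HSI_400_1000/Train/',
                                 patch_size=(512, 512), arg=True)
    loader = DataLoader(train_data, batch_size=8, shuffle=True,
                        num_workers=8, pin_memory=True, drop_last=True)
    for hsi in loader:        # each batch: [8, 61, 512, 512], max-normalized
        print(hsi.shape)
        break
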
class TrainDataset_V2(Dataset):
    def __init__(self, data_path, patch_size, arg=False):

        self.arg = arg
        self.data_path = data_path
        self.patch_size = patch_size
        # Band selection: keep the first 61 bands plus every second band from 62 to 130 (96 channels in total)
        self.select_index = np.concatenate((np.arange(0, 61, 1), np.arange(62, 132, 2)))
        data_list = os.listdir(data_path)
        data_list.sort()

        self.data_list = data_list
        self.img_num = len(self.data_list)

    def arguement(self, img, rotTimes, vFlip, hFlip):
        # Random rotation in the H-W plane
        for j in range(rotTimes):
            img = np.rot90(img.copy(), axes=(1, 2))
        # Random flip along the width (last) axis
        for j in range(vFlip):
            img = img[:, :, ::-1].copy()
        # Random flip along the height axis
        for j in range(hFlip):
            img = img[:, ::-1, :].copy()
        return img

    def __getitem__(self, idx):

        f = h5py.File(self.data_path + self.data_list[idx], 'r')
        hsi = f['hsi'][:]
        hsi = hsi[self.select_index, :, :]
        f.close()

        patch_size_h = self.patch_size[0]
        patch_size_w = self.patch_size[1]

        if self.arg:
            rotTimes = random.randint(0, 3)
            vFlip = random.randint(0, 1)
            hFlip = random.randint(0, 1)
            hsi = self.arguement(hsi, rotTimes, vFlip, hFlip)

        random_h = random.randint(0, hsi.shape[1] - patch_size_h - 1)
        random_w = random.randint(0, hsi.shape[2] - patch_size_w - 1)
        output_hsi = hsi[:, random_h:random_h + patch_size_h, random_w:random_w + patch_size_w]
        output_hsi = output_hsi.astype(np.float32)
        output_hsi = output_hsi / output_hsi.max()

        return np.ascontiguousarray(output_hsi)

    def __len__(self):
        return self.img_num


class ValidDataset_V2(Dataset):
    def __init__(self, data_path, patch_size, arg=False):

        self.arg = arg
        self.data_paths = []
        self.patch_size = patch_size

        # Band selection: keep the first 61 bands plus every second band from 62 to 130 (96 channels in total)
        self.select_index = np.concatenate((np.arange(0, 61, 1), np.arange(62, 132, 2)))

        data_list = os.listdir(data_path)
        data_list.sort()
        for i in range(len(data_list)):
            self.data_paths.append(data_path + data_list[i])

        self.img_num = len(self.data_paths)

    def arguement(self, img, rotTimes, vFlip, hFlip):
        # Random rotation in the H-W plane
        for j in range(rotTimes):
            img = np.rot90(img.copy(), axes=(1, 2))
        # Random flip along the width (last) axis
        for j in range(vFlip):
            img = img[:, :, ::-1].copy()
        # Random flip along the height axis
        for j in range(hFlip):
            img = img[:, ::-1, :].copy()
        return img

    def __getitem__(self, idx):

        f = h5py.File(self.data_paths[idx], 'r')
        hsi = f['hsi'][:]
        hsi = hsi[self.select_index, :, :]
        f.close()

        patch_size_h = self.patch_size[0]
        patch_size_w = self.patch_size[1]

        if self.arg:
            rotTimes = random.randint(0, 3)
            vFlip = random.randint(0, 1)
            hFlip = random.randint(0, 1)
            hsi = self.arguement(hsi, rotTimes, vFlip, hFlip)

        random_h = random.randint(0, hsi.shape[1] - patch_size_h - 1)
        random_w = random.randint(0, hsi.shape[2] - patch_size_w - 1)
        output_hsi = hsi[:, random_h:random_h + patch_size_h, random_w:random_w + patch_size_w]
        output_hsi = output_hsi.astype(np.float32)
        output_hsi = output_hsi / output_hsi.max()

        return np.ascontiguousarray(output_hsi)

    def __len__(self):
        return self.img_num


class TestDataset_MOS(Dataset):
    def __init__(self, data_path, data_list, start_dir, image_size, arg=False):

        self.arg = arg
245 | self.data_path = data_path 246 | 247 | self.start_dir = start_dir 248 | self.image_size = image_size 249 | 250 | self.data_list = data_list 251 | 252 | self.MOS_list = [] 253 | 254 | for i in range(len(data_list)): 255 | 256 | bmp = cv2.imread(self.data_path + self.data_list[i])[:, :, 0] 257 | bmp = bmp[self.start_dir[0]:self.start_dir[0]+self.image_size[0], self.start_dir[1]:self.start_dir[1] + self.image_size[1]] 258 | bmp = bmp / bmp.max() 259 | bmp = bmp.astype(np.float32) 260 | mos = np.expand_dims(bmp, axis=0) 261 | self.MOS_list.append(mos) 262 | 263 | self.img_num = len(self.data_list) 264 | 265 | def __getitem__(self, idx): 266 | mos_name = self.data_list[idx] 267 | mos = self.MOS_list[idx] 268 | 269 | return np.ascontiguousarray(mos), mos_name 270 | 271 | def __len__(self): 272 | return self.img_num 273 | 274 | -------------------------------------------------------------------------------- /my_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | import torch.nn as nn 5 | import logging 6 | import numpy as np 7 | import os 8 | import hdf5storage 9 | from math import exp 10 | from torch.autograd import Variable 11 | import torch.nn.functional as F 12 | def save_matv73(mat_name, var_name, var): 13 | hdf5storage.savemat(mat_name, {var_name: var}, format='7.3', store_python_metadata=True) 14 | 15 | class AverageMeter(object): 16 | def __init__(self): 17 | self.reset() 18 | 19 | def reset(self): 20 | self.val = 0 21 | self.avg = 0 22 | self.sum = 0 23 | self.count = 0 24 | 25 | def update(self, val, n=1): 26 | self.val = val 27 | self.sum = self.sum + val * n 28 | self.count = self.count + n 29 | self.avg = self.sum / self.count 30 | 31 | 32 | def initialize_logger(file_dir): 33 | logger = logging.getLogger() 34 | fhandler = logging.FileHandler(filename=file_dir, mode='a') 35 | formatter = logging.Formatter('%(asctime)s - %(message)s', "%Y-%m-%d %H:%M:%S") 36 | fhandler.setFormatter(formatter) 37 | logger.addHandler(fhandler) 38 | logger.setLevel(logging.INFO) 39 | return logger 40 | 41 | def save_checkpoint(model_path, epoch, iteration, model, optimizer): 42 | state = { 43 | 'epoch': epoch, 44 | 'iter': iteration, 45 | 'state_dict': model.state_dict(), 46 | 'optimizer': optimizer.state_dict(), 47 | } 48 | 49 | torch.save(state, os.path.join(model_path, 'net_%depoch.pth' % epoch)) 50 | 51 | class Loss_MRAE(nn.Module): 52 | def __init__(self): 53 | super(Loss_MRAE, self).__init__() 54 | 55 | def forward(self, outputs, label): 56 | assert outputs.shape == label.shape 57 | error = torch.abs(outputs - label + 1e-4) / (label + 1e-4) 58 | 59 | mrae = torch.mean(error) 60 | return mrae 61 | 62 | class Loss_RMSE(nn.Module): 63 | def __init__(self): 64 | super(Loss_RMSE, self).__init__() 65 | 66 | def forward(self, outputs, label): 67 | assert outputs.shape == label.shape 68 | error = outputs-label 69 | sqrt_error = torch.pow(error,2) 70 | rmse = torch.sqrt(torch.mean(sqrt_error)) 71 | return rmse 72 | 73 | 74 | 75 | class Loss_SAM(nn.Module): 76 | def __init__(self): 77 | super(Loss_SAM, self).__init__() 78 | 79 | def forward(self, outputs, labels): 80 | assert outputs.shape == labels.shape 81 | num = torch.sum(outputs * labels, 1) 82 | den = torch.sqrt(torch.sum(outputs * outputs, 1)) * torch.sqrt(torch.sum(labels * labels, 1)) 83 | sam = torch.arccos((num) / (den)).mean() 84 | return sam 85 | 86 | 87 | 88 | class Loss_Fidelity(nn.Module): 89 | def __init__(self): 90 | 
super(Loss_Fidelity, self).__init__() 91 | 92 | def forward(self, outputs, labels): 93 | assert outputs.shape == labels.shape 94 | num = torch.sum(outputs * labels, 1) 95 | den = torch.sqrt(torch.sum(outputs * outputs, 1)) * torch.sqrt(torch.sum(labels * labels, 1)) 96 | fidelity = ((num) / (den)).mean() 97 | return fidelity 98 | 99 | 100 | class Loss_TV(nn.Module): 101 | def __init__(self, TVLoss_weight: float=1): 102 | super(Loss_TV, self).__init__() 103 | self.weight = TVLoss_weight 104 | 105 | def forward(self, outputs, labels): 106 | 107 | _, _, h, w = outputs.shape 108 | 109 | h_tv = torch.abs(outputs[:, :, 1:, :] - labels[:, :, :h-1, :]).mean() 110 | w_tv = torch.abs(outputs[:, :, :, 1:] - labels[:, :, :, :w-1]).mean() 111 | 112 | loss = self.weight*(h_tv + w_tv) 113 | 114 | return loss 115 | 116 | 117 | 118 | class Loss_MSE(nn.Module): 119 | def __init__(self): 120 | super(Loss_MSE, self).__init__() 121 | 122 | def forward(self, outputs, label): 123 | assert outputs.shape == label.shape 124 | error = outputs-label 125 | sqrt_error = torch.pow(error,2) 126 | mse = torch.mean(sqrt_error) 127 | return mse 128 | 129 | 130 | class Loss_MAE(nn.Module): 131 | def __init__(self): 132 | super(Loss_MAE, self).__init__() 133 | 134 | def forward(self, outputs, label): 135 | assert outputs.shape == label.shape 136 | error = outputs-label 137 | l1_error = torch.abs(error) 138 | mae = torch.mean(l1_error) 139 | return mae 140 | 141 | 142 | class Loss_PSNR(nn.Module): 143 | def __init__(self): 144 | super(Loss_PSNR, self).__init__() 145 | 146 | def forward(self, im_true, im_fake, data_range=1.0): 147 | N = im_true.size()[0] 148 | C = im_true.size()[1] 149 | H = im_true.size()[2] 150 | W = im_true.size()[3] 151 | Itrue = im_true.clamp(0., 1.).mul_(data_range) 152 | Itrue = Itrue.reshape(N, C * H * W) 153 | Ifake = im_fake.clamp(0., 1.).mul_(data_range) 154 | Ifake = Ifake.reshape(N, C * H * W) 155 | 156 | mse = nn.MSELoss(reduction='none') 157 | err = mse(Itrue, Ifake).sum(dim=1, keepdim=True).div_(C * H * W) 158 | 159 | psnr = 10. * torch.log((data_range ** 2) / err) / np.log(10.) 
160 | return torch.mean(psnr) 161 | 162 | 163 | 164 | 165 | #When traning or testing the HyperspecI-V2, we must eliminate the zero in HSIs 166 | class Loss_MRAE_V2(nn.Module): 167 | def __init__(self): 168 | super(Loss_MRAE_V2, self).__init__() 169 | 170 | def forward(self, outputs, labels): 171 | assert outputs.shape == labels.shape 172 | 173 | #Remove zero elements from the denominator 174 | 175 | b, c, h, w = labels.size() 176 | labels = labels.permute(0, 2, 3, 1) 177 | labels = labels.reshape(-1, c) 178 | 179 | outputs = outputs.permute(0, 2, 3, 1) 180 | outputs = outputs.reshape(-1, c) 181 | column_sum = labels.sum(dim=1) 182 | non_zero_columns = column_sum != 0 183 | non_zero_column_indices = torch.nonzero(non_zero_columns).squeeze() 184 | filtered_labels = labels[non_zero_column_indices, :] 185 | filtered_outputs = outputs[non_zero_column_indices, :] 186 | error = torch.abs(filtered_outputs - filtered_labels + 1e-4) / (filtered_labels + 1e-4) 187 | mrae = torch.mean(error) 188 | return mrae 189 | 190 | 191 | 192 | class Loss_SAM_V2(nn.Module): 193 | def __init__(self): 194 | super(Loss_SAM_V2, self).__init__() 195 | 196 | def forward(self, outputs, labels): 197 | assert outputs.shape == labels.shape 198 | 199 | 200 | b, c, h, w = outputs.size() 201 | labels = labels.permute(0, 2, 3, 1) 202 | labels = labels.reshape(-1, c) 203 | outputs = outputs.permute(0, 2, 3, 1) 204 | outputs = outputs.reshape(-1, c) 205 | 206 | #Remove zero elements from the denominator 207 | 208 | column_sum = labels.sum(dim=1) 209 | non_zero_columns = column_sum != 0 210 | non_zero_column_indices = torch.nonzero(non_zero_columns).squeeze() 211 | 212 | filtered_labels = labels[non_zero_column_indices, :] 213 | filtered_outputs = outputs[non_zero_column_indices, :] 214 | 215 | 216 | num = torch.sum(filtered_outputs * filtered_labels, 1) 217 | den = torch.sqrt(torch.sum(filtered_outputs * filtered_outputs, 1)) * torch.sqrt(torch.sum(filtered_labels * filtered_labels, 1)) 218 | sam = torch.arccos((num) / (den)).mean() 219 | return sam 220 | 221 | 222 | class Loss_Fidelity_V2(nn.Module): 223 | def __init__(self): 224 | super(Loss_Fidelity_V2, self).__init__() 225 | 226 | def forward(self, outputs, labels): 227 | assert outputs.shape == labels.shape 228 | 229 | b, c, h, w = outputs.size() 230 | labels = labels.permute(0, 2, 3, 1) 231 | labels = labels.reshape(-1, c) 232 | outputs = outputs.permute(0, 2, 3, 1) 233 | outputs = outputs.reshape(-1, c) 234 | 235 | #Remove zero elements from the denominator 236 | 237 | column_sum = labels.sum(dim=1) 238 | 239 | non_zero_columns = column_sum != 0 240 | non_zero_column_indices = torch.nonzero(non_zero_columns).squeeze() 241 | 242 | filtered_labels = labels[non_zero_column_indices, :] 243 | filtered_outputs = outputs[non_zero_column_indices, :] 244 | 245 | num = torch.sum(filtered_outputs * filtered_labels, 1) 246 | den = torch.sqrt(torch.sum(filtered_outputs * filtered_outputs, 1)) * torch.sqrt(torch.sum(filtered_labels * filtered_labels, 1)) 247 | fidelity = ((num) / (den)).mean() 248 | return fidelity 249 | 250 | 251 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 252 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 253 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 254 | 255 | mu1_sq = mu1.pow(2) 256 | mu2_sq = mu2.pow(2) 257 | mu1_mu2 = mu1 * mu2 258 | 259 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 260 | sigma2_sq = F.conv2d(img2 * img2, 
window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


class Loss_SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(Loss_SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        # Re-create the window if the channel count or dtype changed
        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


def time2file_name(time):
    year = time[0:4]
    month = time[5:7]
    day = time[8:10]
    hour = time[11:13]
    minute = time[14:16]
    second = time[17:19]
    time_filename = year + '_' + month + '_' + day + '_' + hour + '_' + minute + '_' + second
    return time_filename


def record_loss(loss_csv, epoch, iteration, epoch_time, lr, train_loss, test_loss):
    """Record many results."""
    loss_csv.write('{},{},{},{},{},{}\n'.format(epoch, iteration, epoch_time, lr, train_loss, test_loss))
    loss_csv.flush()

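# ----------------------------------------------------------------------
# Illustrative usage (added sketch, not part of the original file): the
# metric modules operate on 4-D [batch, channels, H, W] tensors; the
# random tensors below are assumptions for the demo.
def _demo_metrics():
    pred = torch.rand(1, 61, 64, 64)
    gt = torch.rand(1, 61, 64, 64)
    print('RMSE:', Loss_RMSE()(pred, gt).item())
    print('PSNR:', Loss_PSNR()(pred, gt).item())
    print('SSIM:', Loss_SSIM()(pred, gt).item())
    print('SAM :', Loss_SAM()(pred, gt).item())
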
parser.add_argument("--image_size", type=int, default=(2048, 2048), help="size of test image") 21 | parser.add_argument("--pretrained_model_path", type=str, default='./Model_zoo/SRNet_V1.pth', help='path log files') 22 | parser.add_argument("--image_folder", type=str, default= './Measurements_Test/HyperspecI_V1/', help='path log files') 23 | parser.add_argument("--save_folder", type=str, default= './Measurements_Test/Output_HyperspecI_V1/', help='path log files') 24 | 25 | 26 | opt = parser.parse_args() 27 | os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID' 28 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id 29 | 30 | 31 | def main(): 32 | cudnn.benchmark = True 33 | mask_init = hdf5storage.loadmat(opt.mask_path)['mask'] 34 | mask = np.maximum(mask_init, 0) 35 | mask = mask / mask.max() 36 | mask = mask.astype(np.float32) 37 | mask = torch.from_numpy(mask) 38 | mask = mask.cuda() 39 | mask = mask.unsqueeze(0) 40 | model = model_generator(opt.method, opt.pretrained_model_path) 41 | total_params = sum(p.numel() for p in model.parameters()) 42 | print(f'{total_params:,} total parameters.') 43 | if torch.cuda.is_available(): 44 | model.cuda() 45 | if not os.path.exists(opt.save_folder): 46 | os.makedirs(opt.save_folder) 47 | 48 | test_list = os.listdir(opt.image_folder) 49 | test_list.sort() 50 | 51 | test_data = TestDataset_MOS(data_path=opt.image_folder, data_list=test_list, start_dir=opt.start_dir, image_size=opt.image_size, arg=False) 52 | test_loader = DataLoader(dataset=test_data, batch_size=opt.batch_size, shuffle=False, num_workers=2, pin_memory=True, drop_last=True) 53 | mask_test = mask.repeat(opt.batch_size, 1, 1, 1) 54 | model.eval() 55 | 56 | 57 | for i, (MOS, mos_name) in enumerate(test_loader): 58 | 59 | MOS = MOS.cuda() 60 | with torch.no_grad(): 61 | outputs = model(MOS, mask_test) 62 | for k in range(len(mos_name)): 63 | output_hsi = outputs[k, :, :, :].squeeze() 64 | output_hsi = torch.maximum(output_hsi, torch.tensor(0)) 65 | output_hsi = output_hsi / output_hsi.max() 66 | output_hsi = output_hsi.cpu().numpy() 67 | input_mos = MOS[k, :, :, :].squeeze() 68 | input_mos = input_mos.cpu().numpy() 69 | print('input_mos>>>>>>>>>>', mos_name[k], input_mos.shape, input_mos.max(), input_mos.mean(), input_mos.min()) 70 | print('outputs>>>>>>>>>>', mos_name[k], output_hsi.shape, output_hsi.max(), output_hsi.mean(), output_hsi.min()) 71 | 72 | 73 | f = h5py.File(opt.save_folder + 'HSI_R_' + mos_name[k][:-4] + '.h5', 'w') 74 | f['mos'] = input_mos 75 | f['hsi_R'] = output_hsi 76 | f.close() 77 | 78 | if __name__ == '__main__': 79 | main() 80 | 81 | 82 | -------------------------------------------------------------------------------- /test_HyperspecI_V2.py: -------------------------------------------------------------------------------- 1 | import hdf5storage 2 | import torch 3 | import argparse 4 | import os 5 | import torch.backends.cudnn as cudnn 6 | from torch.utils.data import DataLoader 7 | from getdataset import TestDataset_MOS 8 | from my_utils import initialize_logger 9 | import torch.utils.data 10 | from architecture import model_generator 11 | import numpy as np 12 | import h5py 13 | 14 | parser = argparse.ArgumentParser(description="Reconstruct hypersepctral images from measurements") 15 | parser.add_argument("--method", type=str, default='V2_srnet', help='Model') 16 | parser.add_argument("--gpu_id", type=str, default='0', help='path log files') 17 | parser.add_argument('--batch_size', type=int, default=1, help='batch size') 18 | parser.add_argument("--mask_path", type=str, 
--------------------------------------------------------------------------------
/test_HyperspecI_V2.py:
--------------------------------------------------------------------------------
import hdf5storage
import torch
import argparse
import os
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from getdataset import TestDataset_MOS
from my_utils import initialize_logger
import torch.utils.data
from architecture import model_generator
import numpy as np
import h5py

parser = argparse.ArgumentParser(description="Reconstruct hyperspectral images from measurements")
parser.add_argument("--method", type=str, default='V2_srnet', help='Model')
parser.add_argument("--gpu_id", type=str, default='0', help='select gpu')
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument("--mask_path", type=str, default='./MASK/Mask_HyperspecI_V2.mat', help='path of calibrated sensing matrix')
parser.add_argument("--start_dir", type=int, default=(0, 0), help="start coordinate of the test image crop")
parser.add_argument("--image_size", type=int, default=(1024, 1024), help="size of test image")
parser.add_argument("--pretrained_model_path", type=str, default='./Model_zoo/SRNet_V2.pth', help='pre-trained model path')
parser.add_argument("--image_folder", type=str, default='./Measurements_Test/HyperspecI_V2/', help='folder of input measurements')
parser.add_argument("--save_folder", type=str, default='./Measurements_Test/Output_HyperspecI_V2/', help='folder for reconstructed outputs')

opt = parser.parse_args()
os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id


def main():
    cudnn.benchmark = True

    # Load, rectify, and normalize the calibrated sensing matrix
    mask_init = hdf5storage.loadmat(opt.mask_path)['mask']
    mask = np.maximum(mask_init, 0)
    mask = mask / mask.max()
    mask = mask.astype(np.float32)
    mask = torch.from_numpy(mask)
    mask = mask.cuda()
    mask = mask.unsqueeze(0)
    model = model_generator(opt.method, opt.pretrained_model_path)

    total_params = sum(p.numel() for p in model.parameters())
    print(f'{total_params:,} total parameters.')
    if torch.cuda.is_available():
        model.cuda()

    if not os.path.exists(opt.save_folder):
        os.makedirs(opt.save_folder)

    test_list = os.listdir(opt.image_folder)
    test_list.sort()

    test_data = TestDataset_MOS(data_path=opt.image_folder, data_list=test_list, start_dir=opt.start_dir, image_size=opt.image_size, arg=False)

    test_loader = DataLoader(dataset=test_data, batch_size=opt.batch_size, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)

    mask_test = mask.repeat(opt.batch_size, 1, 1, 1)

    model.eval()

    for i, (MOS, mos_name) in enumerate(test_loader):

        MOS = MOS.cuda()
        with torch.no_grad():
            outputs = model(MOS, mask_test)

        for k in range(len(mos_name)):
            output_hsi = outputs[k, :, :, :].squeeze()
            output_hsi = torch.maximum(output_hsi, torch.tensor(0))
            output_hsi = output_hsi / output_hsi.max()
            output_hsi = output_hsi.cpu().numpy()

            input_mos = MOS[k, :, :, :].squeeze()
            input_mos = input_mos.cpu().numpy()

            print('input_mos>>>>>>>>>>', mos_name[k], input_mos.shape, input_mos.max(), input_mos.mean(), input_mos.min())
            print('outputs>>>>>>>>>>', mos_name[k], output_hsi.shape, output_hsi.max(), output_hsi.mean(), output_hsi.min())

            # Save the measurement and the reconstructed cube
            f = h5py.File(opt.save_folder + 'HSI_R_' + mos_name[k][:-4] + '.h5', 'w')
            f['mos'] = input_mos
            f['hsi_R'] = output_hsi
            f.close()

if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/train_HyperspecI_V1.py:
--------------------------------------------------------------------------------
import hdf5storage
import torch
import argparse
import os
import time
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from getdataset import TrainDataset_V1, ValidDataset_V1
from my_utils import AverageMeter, initialize_logger, save_checkpoint, Loss_RMSE, Loss_PSNR, Loss_TV, Loss_MRAE, Loss_SAM
from DataProcess import Data_Process
import torch.utils.data
from architecture import model_generator
import
numpy as np 15 | import torch.nn as nn 16 | 17 | 18 | parser = argparse.ArgumentParser(description="Model training of HyperspecI-V1") 19 | parser.add_argument("--method", type=str, default='V1_srnet', help='Model') 20 | parser.add_argument('--batch_size', type=int, default=8, help='batch size') 21 | parser.add_argument("--end_epoch", type=int, default=200, help="number of epochs") 22 | parser.add_argument("--epoch_sam_num", type=int, default=5000, help="per_epoch_iteration") 23 | parser.add_argument("--init_lr", type=float, default=4e-4, help="initial learning rate") 24 | parser.add_argument("--gpu_id", type=str, default='0', help='select gpu') 25 | parser.add_argument("--pretrained_model_path", type=str, default=None, help='pre-trained model path') 26 | parser.add_argument("--sigma", type=float, default=(0, 1 / 255, 2/255, 3/255), help="Sigma of Gaussian Noise") 27 | parser.add_argument("--mask_path", type=str, default='./MASK/Mask_HyperspecI_V1.mat', help='path of calibrated sensing matrix') 28 | parser.add_argument("--output_folder", type=str, default='./exp/HyperspecI_V1/', help='output path') 29 | parser.add_argument("--start_dir", type=int, default=(0, 0), help="size of test image coordinate") 30 | parser.add_argument("--image_size", type=int, default=(2048, 2048), help="size of test image") 31 | parser.add_argument("--train_patch_size", type=int, default=(512, 512), help="size of patch") 32 | parser.add_argument("--valid_patch_size", type=int, default=(512, 512), help="size of patch") 33 | parser.add_argument("--train_data_path", type=str, default="./Dataset_Train/HSI_400_1000/Train/", help='path datasets') 34 | parser.add_argument("--valid_data_path", type=str, default="./Dataset_Train/HSI_400_1000/Valid/", help='path datasets') 35 | 36 | 37 | 38 | opt = parser.parse_args() 39 | os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID' 40 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id 41 | criterion_rmse = Loss_RMSE() 42 | criterion_psnr = Loss_PSNR() 43 | criterion_mrae = Loss_MRAE() 44 | criterion_sam = Loss_SAM() 45 | criterion_tv = Loss_TV(TVLoss_weight=float(0.5)) 46 | data_processing = Data_Process() 47 | 48 | 49 | mask_init = hdf5storage.loadmat(opt.mask_path)['mask'] 50 | print('mask_init:', mask_init.shape) 51 | mask = mask_init[:, opt.start_dir[0]:opt.start_dir[0]+opt.image_size[0], opt.start_dir[1]:opt.start_dir[1] + opt.image_size[1]] 52 | mask = np.maximum(mask, 0) 53 | mask = mask / mask.max() 54 | mask = torch.from_numpy(mask) 55 | mask = mask.cuda() 56 | print('mask:', mask.dtype, mask.shape, mask.max(), mask.mean(), mask.min()) 57 | 58 | def main(): 59 | cudnn.benchmark = True 60 | 61 | print("\nloading dataset ...") 62 | train_data = TrainDataset_V1(data_path=opt.train_data_path, patch_size=opt.train_patch_size, arg=True) 63 | print('len(train_data):', len(train_data)) 64 | print(f"Iteration per epoch: {len(train_data)}") 65 | val_data = ValidDataset_V1(data_path=opt.valid_data_path, patch_size=opt.valid_patch_size, arg=True) 66 | print('len(valid_data):', len(val_data)) 67 | output_path = opt.output_folder 68 | 69 | # iterations 70 | per_epoch_iteration = opt.epoch_sam_num // opt.batch_size 71 | total_iteration = per_epoch_iteration*opt.end_epoch 72 | 73 | if not os.path.exists(output_path): 74 | os.makedirs(output_path) 75 | 76 | model = model_generator(opt.method, opt.pretrained_model_path) 77 | 78 | total_params = sum(p.numel() for p in model.parameters()) 79 | print(f'{total_params:,} total parameters.') 80 | if torch.cuda.is_available(): 81 | criterion_rmse.cuda() 
        criterion_psnr.cuda()
        criterion_tv.cuda()
        criterion_mrae.cuda()

    start_epoch = 0
    iteration = start_epoch * per_epoch_iteration

    # opt.init_lr
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=opt.init_lr,
                                 betas=(0.9, 0.999))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_iteration - iteration, eta_min=1e-6)

    log_dir = os.path.join(output_path, 'train.log')
    logger = initialize_logger(log_dir)

    record_rmse_loss = 10000
    start_time = time.time()

    while iteration < total_iteration:
        model.train()
        losses = AverageMeter()

        train_loader = DataLoader(dataset=train_data, batch_size=opt.batch_size, shuffle=True, num_workers=8,
                                  pin_memory=True, drop_last=True)
        val_loader = DataLoader(dataset=val_data, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)

        for i, (HSIs) in enumerate(train_loader):

            HSIs = HSIs.cuda()
            # Randomly select the mask sub-patches
            mask_patch = data_processing.get_random_mask_patches(mask=mask, image_size=opt.image_size, patch_size=opt.train_patch_size, batch_size=opt.batch_size)
            # Generate the measurements from the training HSIs and the selected mask patches
            inputs, targets = data_processing.get_mos_hsi(hsi=HSIs, mask=mask_patch, sigma=opt.sigma, mos_size=opt.train_patch_size[0], hsi_input_size=opt.train_patch_size[0], hsi_target_size=opt.train_patch_size[0])

            inputs = Variable(inputs)
            targets = Variable(targets)

            lr = optimizer.param_groups[0]['lr']
            outputs = model(inputs, mask_patch)

            # Calculate the hybrid loss
            loss_rmse = criterion_rmse(outputs, targets)
            loss_tv = criterion_tv(outputs, targets)
            loss_mrae = criterion_mrae(outputs, targets) * 0.2
            loss = loss_rmse + loss_tv + loss_mrae
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

            losses.update(loss.data)
            iteration = iteration + 1

            if iteration % per_epoch_iteration == 0:
                epoch = iteration // per_epoch_iteration
                end_time = time.time()
                epoch_time = end_time - start_time
                start_time = time.time()
                rmse_loss, psnr_loss, mrae_loss, sam_loss = Validate(val_loader, model, mask)

                # Save the model
                if torch.abs(
                        record_rmse_loss - rmse_loss) < 0.0001 or rmse_loss < record_rmse_loss or iteration % 10000 == 0:
                    print(f'Saving to {output_path}')
                    save_checkpoint(output_path, (epoch), iteration, model, optimizer)
                    if rmse_loss < record_rmse_loss:
                        record_rmse_loss = rmse_loss
                # Print the losses
                print(" Iter[%06d/%06d], Epoch[%06d], Time[%06d], learning rate : %.9f, Train Loss: %.9f, "
                      "Test RMSE: %.9f, Test PSNR: %.9f, Test MRAE: %.9f, Test SAM: %.9f "
                      % (iteration, total_iteration, epoch, epoch_time, lr, losses.avg, rmse_loss, psnr_loss, mrae_loss, sam_loss))

                logger.info(" Iter[%06d/%06d], Epoch[%06d], Time[%06d], learning rate : %.9f, Train Loss: %.9f, "
                            "Test RMSE: %.9f, Test PSNR: %.9f, Test MRAE: %.9f, Test SAM: %.9f "
                            % (iteration, total_iteration, epoch, epoch_time, lr, losses.avg, rmse_loss, psnr_loss, mrae_loss, sam_loss))

enumerate(val_loader): 165 | HSIs = HSIs.cuda() 166 | 167 | # Randomly select sub-patches of the sensing matrix 168 | mask_patch = data_processing.get_random_mask_patches(mask=mask, image_size=opt.image_size, patch_size=opt.valid_patch_size, batch_size=opt.batch_size) 169 | 170 | # Generate the measurements from the validation HSIs and the selected sub-patterns 171 | inputs, targets = data_processing.get_mos_hsi(hsi=HSIs, mask=mask_patch, sigma=opt.sigma, mos_size=opt.valid_patch_size[0], hsi_input_size=opt.valid_patch_size[0], hsi_target_size=opt.valid_patch_size[0]) 172 | 173 | with torch.no_grad(): 174 | outputs = model(inputs, mask_patch) 175 | 176 | loss_rmse = criterion_rmse(outputs, targets) 177 | loss_psnr = criterion_psnr(outputs, targets) 178 | loss_mrae = criterion_mrae(outputs, targets) 179 | loss_sam = criterion_sam(outputs, targets) 180 | losses_psnr.update(loss_psnr.data) 181 | losses_rmse.update(loss_rmse.data) 182 | losses_mrae.update(loss_mrae.data) 183 | losses_sam.update(loss_sam.data) 184 | 185 | return losses_rmse.avg, losses_psnr.avg, losses_mrae.avg, losses_sam.avg 186 | 187 | 188 | if __name__ == '__main__': 189 | main() 190 | 191 | 192 | -------------------------------------------------------------------------------- /train_HyperspecI_V2.py: -------------------------------------------------------------------------------- 1 | import hdf5storage 2 | import torch 3 | import argparse 4 | import os 5 | import time 6 | from torch.autograd import Variable 7 | import torch.backends.cudnn as cudnn 8 | from torch.utils.data import DataLoader 9 | from getdataset import TrainDataset_V2, ValidDataset_V2 10 | from my_utils import AverageMeter, initialize_logger, save_checkpoint, Loss_RMSE, Loss_PSNR, Loss_TV, Loss_MRAE_V2, Loss_SAM_V2 11 | from DataProcess import Data_Process 12 | import torch.utils.data 13 | from architecture import model_generator 14 | import numpy as np 15 | 16 | 17 | parser = argparse.ArgumentParser(description="Model training of HyperspecI-V2") 18 | parser.add_argument("--method", type=str, default='V2_srnet', help='model architecture') 19 | parser.add_argument('--batch_size', type=int, default=8, help='batch size') 20 | parser.add_argument("--end_epoch", type=int, default=200, help="number of epochs") 21 | parser.add_argument("--epoch_sam_num", type=int, default=5000, help="number of samples per epoch") 22 | parser.add_argument("--init_lr", type=float, default=4e-4, help="initial learning rate") 23 | parser.add_argument("--gpu_id", type=str, default='0', help='select gpu') 24 | parser.add_argument("--pretrained_model_path", type=str, default=None, help='pre-trained model path') 25 | parser.add_argument("--sigma", type=float, default=(0, 1/255, 3/255, 5/255), help="sigma of Gaussian noise; one value is sampled per batch when a tuple is given") 26 | parser.add_argument("--mask_path", type=str, default='./MASK/Mask_HyperspecI_V2.mat', help='path of calibrated sensing matrix') 27 | parser.add_argument("--output_folder", type=str, default='./exp/HyperspecI_V2/', help='output path') 28 | parser.add_argument("--start_dir", type=int, default=(0, 0), help="top-left coordinate of the cropped sensing matrix") 29 | parser.add_argument("--image_size", type=int, default=(1024, 1024), help="spatial size of the cropped sensing matrix") 30 | parser.add_argument("--train_patch_size", type=int, default=(512, 512), help="size of training patch") 31 | parser.add_argument("--valid_patch_size", type=int, default=(512, 512), help="size of validation patch") 32 | parser.add_argument("--train_data_path", type=str, default="./Dataset_Train/HSI_400_1700/Train/", help='path of training dataset') 33 | parser.add_argument("--valid_data_path", type=str, 
default="./Dataset_Train/HSI_400_1700/Valid/", help='path datasets') 34 | 35 | opt = parser.parse_args() 36 | os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID' 37 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id 38 | 39 | criterion_rmse = Loss_RMSE() 40 | criterion_psnr = Loss_PSNR() 41 | criterion_mrae = Loss_MRAE_V2() 42 | criterion_sam = Loss_SAM_V2() 43 | criterion_tv = Loss_TV(TVLoss_weight=float(0.5)) 44 | data_processing = Data_Process() 45 | 46 | 47 | mask_init = hdf5storage.loadmat(opt.mask_path)['mask'] 48 | mask = mask_init[:, opt.start_dir[0]:opt.start_dir[0]+opt.image_size[0], opt.start_dir[1]:opt.start_dir[1] + opt.image_size[1]] 49 | mask = np.maximum(mask, 0) 50 | mask = mask / mask.max() 51 | mask = torch.from_numpy(mask) 52 | mask = mask.cuda() 53 | 54 | def main(): 55 | cudnn.benchmark = True 56 | 57 | print("\nloading dataset ...") 58 | train_data = TrainDataset_V2(data_path=opt.train_data_path, patch_size=opt.train_patch_size, arg=True) 59 | print('len(train_data):', len(train_data)) 60 | print(f"Iteration per epoch: {len(train_data)}") 61 | val_data = ValidDataset_V2(data_path=opt.valid_data_path, patch_size=opt.valid_patch_size, arg=True) 62 | 63 | print('len(valid_data):', len(val_data)) 64 | output_path = opt.output_folder 65 | 66 | # iterations 67 | per_epoch_iteration = opt.epoch_sam_num // opt.batch_size 68 | total_iteration = per_epoch_iteration*opt.end_epoch 69 | 70 | if not os.path.exists(output_path): 71 | os.makedirs(output_path) 72 | 73 | model = model_generator(opt.method, opt.pretrained_model_path) 74 | 75 | total_params = sum(p.numel() for p in model.parameters()) 76 | print(f'{total_params:,} total parameters.') 77 | if torch.cuda.is_available(): 78 | criterion_rmse.cuda() 79 | criterion_psnr.cuda() 80 | criterion_tv.cuda() 81 | criterion_mrae.cuda() 82 | 83 | start_epoch = 0 84 | iteration = start_epoch * per_epoch_iteration 85 | 86 | #opt.init_lr 87 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=opt.init_lr, 88 | betas=(0.9, 0.999)) 89 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_iteration - iteration, eta_min=1e-6) 90 | 91 | log_dir = os.path.join(output_path, 'train.log') 92 | logger = initialize_logger(log_dir) 93 | 94 | record_rmse_loss = 10000 95 | strat_time = time.time() 96 | 97 | while iteration < total_iteration: 98 | model.train() 99 | losses = AverageMeter() 100 | 101 | train_loader = DataLoader(dataset=train_data, batch_size=opt.batch_size, shuffle=True, num_workers=8, 102 | pin_memory=True, drop_last=True) 103 | val_loader = DataLoader(dataset=val_data, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) 104 | 105 | for i, (HSIs) in enumerate(train_loader): 106 | 107 | HSIs = HSIs.cuda() 108 | #selecte the sub-patches radomly 109 | mask_patch = data_processing.get_random_mask_patches(mask=mask, image_size=opt.image_size, patch_size=opt.train_patch_size, batch_size=opt.batch_size) 110 | #Generate the measurements using traning HSIs and selected sub-pattern 111 | inputs, targets = data_processing.get_mos_hsi(hsi=HSIs, mask=mask_patch, sigma=opt.sigma, mos_size=opt.train_patch_size[0], hsi_input_size=opt.train_patch_size[0], hsi_target_size=opt.train_patch_size[0]) 112 | 113 | inputs = Variable(inputs) 114 | targets = Variable(targets) 115 | lr = optimizer.param_groups[0]['lr'] 116 | outputs = model(inputs, mask_patch) 117 | targets_VIS = targets[:, :55, :, :] 118 | targets_NIR = targets[:, 55:, :, :] 119 | outputs_VIS = outputs[:, :55, :, :] 120 | outputs_NIR = 
outputs[:, 55:, :, :] 121 | loss_vis_rmse = criterion_rmse(outputs_VIS, targets_VIS) 122 | loss_vis_tv = criterion_tv(outputs_VIS, targets_VIS) * 0.5 123 | loss_vis_mrae = criterion_mrae(outputs_VIS, targets_VIS) * 0.05 124 | loss_nir = criterion_rmse(outputs_NIR, targets_NIR) 125 | loss = loss_vis_rmse + loss_vis_tv + loss_vis_mrae + loss_nir  # hybrid VIS/NIR loss 126 | loss.backward() 127 | optimizer.step() 128 | optimizer.zero_grad() 129 | scheduler.step() 130 | 131 | losses.update(loss.data) 132 | iteration = iteration + 1 133 | 134 | if iteration % per_epoch_iteration == 0: 135 | epoch = iteration // per_epoch_iteration 136 | end_time = time.time() 137 | epoch_time = end_time - start_time 138 | start_time = time.time() 139 | rmse_loss, psnr_loss, mrae_loss, sam_loss = Validate(val_loader, model, mask) 140 | 141 | # Save model when validation RMSE stalls or improves, or every 10000 iterations 142 | if torch.abs( 143 | record_rmse_loss - rmse_loss) < 0.0001 or rmse_loss < record_rmse_loss or iteration % 10000 == 0: 144 | print(f'Saving to {output_path}') 145 | 146 | save_checkpoint(output_path, (epoch), iteration, model, optimizer) 147 | if rmse_loss < record_rmse_loss: 148 | record_rmse_loss = rmse_loss 149 | 150 | # print and log the per-epoch metrics 151 | print(" Iter[%06d/%06d], Epoch[%06d], Time[%06d], learning rate : %.9f, Train Loss: %.9f, " 152 | "Test RMSE: %.9f, Test PSNR: %.9f, Test MRAE: %.9f, Test SAM: %.9f " 153 | % (iteration, total_iteration, epoch, epoch_time, lr, losses.avg, rmse_loss, psnr_loss, mrae_loss, sam_loss)) 154 | 155 | logger.info(" Iter[%06d/%06d], Epoch[%06d], Time[%06d], learning rate : %.9f, Train Loss: %.9f, " 156 | "Test RMSE: %.9f, Test PSNR: %.9f, Test MRAE: %.9f, Test SAM: %.9f " 157 | % (iteration, total_iteration, epoch, epoch_time, lr, losses.avg, rmse_loss, psnr_loss, mrae_loss, sam_loss)) 158 | 159 | def Validate(val_loader, model, mask): 160 | model.eval() 161 | losses_rmse = AverageMeter() 162 | losses_psnr = AverageMeter() 163 | losses_sam = AverageMeter() 164 | losses_mrae = AverageMeter() 165 | for i, (HSIs) in enumerate(val_loader): 166 | HSIs = HSIs.cuda() 167 | 168 | # Randomly select sub-patches of the sensing matrix 169 | mask_patch = data_processing.get_random_mask_patches(mask=mask, image_size=opt.image_size, patch_size=opt.valid_patch_size, batch_size=opt.batch_size) 170 | 171 | # Generate the measurements from the validation HSIs and the selected sub-patterns 172 | inputs, targets = data_processing.get_mos_hsi(hsi=HSIs, mask=mask_patch, sigma=opt.sigma, mos_size=opt.valid_patch_size[0], hsi_input_size=opt.valid_patch_size[0], hsi_target_size=opt.valid_patch_size[0]) 173 | 174 | with torch.no_grad(): 175 | outputs = model(inputs, mask_patch) 176 | 177 | loss_rmse = criterion_rmse(outputs, targets) 178 | loss_psnr = criterion_psnr(outputs, targets) 179 | loss_mrae = criterion_mrae(outputs, targets) 180 | loss_sam = criterion_sam(outputs, targets) 181 | losses_psnr.update(loss_psnr.data) 182 | losses_rmse.update(loss_rmse.data) 183 | losses_sam.update(loss_sam.data) 184 | losses_mrae.update(loss_mrae.data) 185 | 186 | return losses_rmse.avg, losses_psnr.avg, losses_mrae.avg, losses_sam.avg 187 | 188 | 189 | if __name__ == '__main__': 190 | main() 191 | 192 | 193 | --------------------------------------------------------------------------------
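Note: the measurement-synthesis step shared by both training scripts can be exercised on its own with random data. The sketch below is a minimal CPU shape check of Data_Process.get_random_mask_patches and Data_Process.get_mos_hsi from DataProcess.py; the 61-band cube, the 256x256 mask, the 128x128 patches, and the sigma values are illustrative assumptions chosen to keep the demo small, not values taken from the repository.

import torch
from DataProcess import Data_Process

data_processing = Data_Process()
bands, image_size, patch_size = 61, (256, 256), (128, 128)

hsi = torch.rand(2, bands, patch_size[0], patch_size[1])  # fake batch of hyperspectral patches
mask = torch.rand(bands, image_size[0], image_size[1])    # stand-in for the calibrated sensing matrix

# One normalized random sub-patch of the sensing matrix per sample in the batch
mask_patch = data_processing.get_random_mask_patches(mask=mask, image_size=image_size,
                                                     patch_size=patch_size, batch_size=2)

# Forward model: spectrally coded single-channel measurement plus sampled Gaussian noise,
# with the target HSI rescaled by the same adaptive normalization factor
inputs, targets = data_processing.get_mos_hsi(hsi=hsi, mask=mask_patch, sigma=(0, 1 / 255),
                                              mos_size=patch_size[0], hsi_input_size=patch_size[0],
                                              hsi_target_size=patch_size[0])

print(inputs.shape)   # torch.Size([2, 1, 128, 128])
print(targets.shape)  # torch.Size([2, 61, 128, 128])

The resulting inputs/targets pair is exactly what the training loops above feed to the model and the hybrid loss: V1 combines RMSE, TV, and 0.2-weighted MRAE over all bands, while V2 splits the output at band 55 and applies the weighted RMSE/TV/MRAE terms to the VIS bands plus a plain RMSE term to the NIR bands.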