├── README.md
├── config.py
├── dataset.py
├── inference.py
├── isp
│   └── ISP_CNN.pth
├── load_data.py
├── models.py
├── netloss.py
├── structure.py
├── torchsummary.py
├── train.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # EMVD
2 | Efficient Multi-Stage Video Denoising With Recurrent Spatio-Temporal Fusion.
3 | 
4 | EMVD is an efficient video denoising method which recursively exploits the spatio-temporal correlation inherently present in natural videos through multiple cascading processing stages applied in a recurrent fashion, namely temporal fusion, spatial denoising, and spatio-temporal refinement.
5 | 
6 | # Overview
7 | This repo is an ***unofficial*** implementation of EMVD, proposed by **Matteo Maggioni, Yibin Huang, Cheng Li, Shuai Xiao, Zhongqian Fu, Fenglong Song** in CVPR 2021.
8 | 
9 | It is a **PyTorch** implementation.
10 | 
11 | # Paper
12 | - https://openaccess.thecvf.com/content/CVPR2021/papers/Maggioni_Efficient_Multi-Stage_Video_Denoising_With_Recurrent_Spatio-Temporal_Fusion_CVPR_2021_paper.pdf
13 | - https://openaccess.thecvf.com/content/CVPR2021/supplemental/Maggioni_Efficient_Multi-Stage_Video_CVPR_2021_supplemental.pdf
14 | 
15 | # Requirements
16 | 1. PyTorch>=1.6
17 | 2. NumPy
18 | 3. scikit-image
19 | 4. tensorboardX (for visualization of loss, PSNR and images)
20 | 5. torchstat (for computing GFLOPs)
21 | 
22 | # Code
23 | 1. `config.py` sets the hyperparameters.
24 | 2. `dataset.py` and `load_data.py` load data from the dataset.
25 | 3. `train.py` runs the training process.
26 | 4. `inference.py` runs the validation process.
27 | 5. `models.py` and `./isp/ISP_CNN.pth` are called by `inference.py` to convert .tiff outputs to .png; both are taken from RViDeNet (https://github.com/cao-cong/RViDeNet).
28 | 
29 | # Dataset
30 | CRVD Dataset (https://github.com/cao-cong/RViDeNet)
31 | 
32 | # Usage
33 | Modify `data_root` in `config.py` and `gt_name`/`noisy_name` in function `decode_data` in `load_data.py`, then run `train.py` to train. After convergence, run `inference.py` for validation.
34 | 
35 | # Results
36 | ISO-averaged raw PSNR: 42.02, ISO-averaged raw SSIM: 0.9800 on the CRVD dataset (~5.38 GFLOPs), which is still lower than the results reported in the paper.
37 | 
38 | # Acknowledgement
39 | This implementation is inspired by the following project:
40 | - [RViDeNet](https://github.com/cao-cong/RViDeNet)
41 | 
42 | 
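For reference, restoring trained weights for evaluation follows the standard PyTorch pattern used in `inference.py` (a minimal sketch, assuming a checkpoint written by `train.py` exists at `./model/model_best.pth`):

```python
import torch
import structure

# build the network and load the checkpoint dict saved by train.py
model = structure.MainDenoise()
checkpoint = torch.load('./model/model_best.pth', map_location='cpu')
model.load_state_dict(checkpoint['model'])
model.eval()
```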
43 | *Many thanks for coming here! It will be highly appreciated if you offer any suggestion.*
44 | 
45 | *Please support me by starring or forking this repo.*
46 | 
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | 
4 | debug = 1
5 | 
6 | # gpu
7 | ngpu = 1
8 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9 | 
10 | # generate data
11 | data_root = ['./dataset/', './CRVD_dataset/']
12 | output_root = './results/'
13 | 
14 | image_height = 128
15 | image_width = 128
16 | batch_size = 16
17 | frame_num = 25
18 | num_workers = 4
19 | 
20 | 
21 | BLACK_LEVEL = 64
22 | VALID_VALUE = 959
23 | 
24 | 
25 | # log
26 | model_name = './model/'
27 | debug_dir = os.path.join('debug')
28 | log_dir = os.path.join(debug_dir, 'log')
29 | log_step = 10  # write a train log entry every log_step iterations
30 | 
31 | # model store
32 | model_save_root = os.path.join(model_name, 'model.pth')
33 | best_model_save_root = os.path.join(model_name, 'model_best.pth')
34 | 
35 | # pretrained model path
36 | checkpoint = None if not os.path.exists(os.path.join(model_name, 'model.pth')) else os.path.join(model_name, 'model.pth')
37 | start_epoch = 0
38 | start_iter = 0
39 | 
40 | # training parameters
41 | learning_rate = 0.0001
42 | epoch = int(1e8)
43 | 
44 | # validation
45 | valid_start_iter = 500
46 | valid_step = 50
47 | vis_data = 1  # whether to visualize noisy and gt data
48 | # clip threshold
49 | image_min_value = 0
50 | image_max_value = 1
51 | image_norm_value = 1
52 | 
53 | label_min_value = 0
54 | label_max_value = 255
55 | label_norm_value = 1
56 | 
57 | 
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from scipy.stats import poisson
4 | from skimage.metrics import structural_similarity as compare_ssim
5 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr
6 | import torch.nn.functional as F
7 | import config as cfg
8 | import torch.nn as nn
9 | 
10 | def pack_gbrg_raw(raw):
11 |     # pack GBRG Bayer raw to 4 channels
12 |     black_level = 240
13 |     white_level = 2**12-1
14 |     im = raw.astype(np.float32)
15 |     im = np.maximum(im - black_level, 0) / (white_level-black_level)
16 | 
17 |     im = np.expand_dims(im, axis=2)
18 |     img_shape = im.shape
19 |     H = img_shape[0]
20 |     W = img_shape[1]
21 | 
22 |     out = np.concatenate((im[1:H:2, 0:W:2, :],  # r
23 |                           im[1:H:2, 1:W:2, :],  # gr
24 |                           im[0:H:2, 1:W:2, :],  # b
25 |                           im[0:H:2, 0:W:2, :]), axis=2)  # gb
26 |     return out
27 | 
28 | def depack_gbrg_raw(raw):
29 |     H = raw.shape[1]
30 |     W = raw.shape[2]
31 |     output = np.zeros((H*2, W*2))
32 |     for i in range(H):
33 |         for j in range(W):
34 |             output[2*i, 2*j] = raw[0, i, j, 3]      # gb
35 |             output[2*i, 2*j+1] = raw[0, i, j, 2]    # b
36 |             output[2*i+1, 2*j] = raw[0, i, j, 0]    # r
37 |             output[2*i+1, 2*j+1] = raw[0, i, j, 1]  # gr
38 |     return output
39 | 
40 | 
41 | def compute_sigma(input, a, b):
42 |     sigma = np.sqrt((input - 240) * a + b)
43 |     return sigma
44 | 
45 | 
46 | def preprocess(raw):
47 |     input_full = raw.transpose((0, 3, 1, 2))
48 |     input_full = torch.from_numpy(input_full)
49 |     input_full = input_full.cuda()
50 |     return input_full
51 | 
52 | def tensor2numpy(raw):  # raw: 1 * 4 * H * W
53 |     input_full = raw.permute((0, 2, 3, 1))  # 1 * H * W * 4
54 |     input_full = input_full.data.cpu().numpy()
55 |     output = np.clip(input_full, 0, 1)
56 |     return output
57 | 
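`pack_gbrg_raw` and `depack_gbrg_raw` above are exact inverses on the normalized data; a quick round-trip check (a minimal sketch, assuming an even-sized synthetic GBRG raw) looks like:

```python
import numpy as np
from dataset import pack_gbrg_raw, depack_gbrg_raw

# synthetic 12-bit GBRG raw above the black level of 240 used by pack_gbrg_raw
raw = np.random.randint(240, 2 ** 12, size=(8, 8)).astype(np.uint16)
packed = pack_gbrg_raw(raw)                    # (4, 4, 4): R, Gr, B, Gb planes in [0, 1]
restored = depack_gbrg_raw(packed[None, ...])  # back to (8, 8), still normalized
expected = (raw.astype(np.float32) - 240) / (2 ** 12 - 1 - 240)
assert np.allclose(restored, expected)
```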
58 | def pack_rggb_raw_for_compute_ssim(raw):
59 | 
60 |     im = raw.astype(np.float32)
61 |     im = np.expand_dims(im, axis=2)
62 |     img_shape = im.shape
63 |     H = img_shape[0]
64 |     W = img_shape[1]
65 |     out = np.concatenate((im[0:H:2, 0:W:2, :],
66 |                           im[0:H:2, 1:W:2, :],
67 |                           im[1:H:2, 1:W:2, :],
68 |                           im[1:H:2, 0:W:2, :]), axis=2)
69 |     return out
70 | 
71 | def compute_ssim_for_packed_raw(raw1, raw2):
72 |     raw1_pack = pack_rggb_raw_for_compute_ssim(raw1)
73 |     raw2_pack = pack_rggb_raw_for_compute_ssim(raw2)
74 |     test_raw_ssim = 0
75 |     for i in range(4):
76 |         test_raw_ssim += compare_ssim(raw1_pack[:, :, i], raw2_pack[:, :, i], data_range=1.0)
77 | 
78 |     return test_raw_ssim / 4
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | from PIL import ImageFile
2 | ImageFile.LOAD_TRUNCATED_IMAGES = True
3 | import os
4 | import cv2
5 | import warnings
6 | warnings.filterwarnings('ignore')
7 | from dataset import *
8 | import config as cfg
9 | import time
10 | from PIL import Image
11 | import torch
12 | from torch import nn
13 | import torch.backends.cudnn as cudnn
14 | import structure
15 | from torch.nn import functional as F
16 | 
17 | 
18 | def test_big_size_raw(input_data, block_size, denoiser, a, b):  # tile-wise inference over a full-resolution frame
19 |     stack_image = input_data
20 |     hgt = np.shape(stack_image)[1]
21 |     wid = np.shape(stack_image)[2]
22 | 
23 |     border = 32
24 | 
25 |     expand_raw = np.zeros(shape=[1, int(hgt * 2.5), int(wid * 2.5), 8], dtype=np.float32)
26 | 
27 | 
28 |     expand_raw[:, 0:hgt, 0:wid, :] = stack_image
29 |     expand_raw[:, 0:hgt, border * 2:wid + border * 2, :] = stack_image
30 |     expand_raw[:, border * 2:hgt + border * 2, 0:wid, :] = stack_image
31 |     expand_raw[:, border * 2:hgt + border * 2, border * 2:wid + border * 2, :] = stack_image
32 | 
33 |     expand_raw[:, border:border + hgt, border:border + wid, :] = stack_image
34 | 
35 |     expand_res = np.zeros([1, int(hgt * 2.5), int(wid * 2.5), 4], dtype=np.float32)
36 |     expand_fusion = np.zeros([1, int(hgt * 2.5), int(wid * 2.5), 4], dtype=np.float32)
37 |     expand_denoise = np.zeros([1, int(hgt * 2.5), int(wid * 2.5), 4], dtype=np.float32)
38 |     expand_gamma = np.zeros([1, int(hgt * 2.5), int(wid * 2.5), 1], dtype=np.float32)
39 |     expand_omega = np.zeros([1, int(hgt * 2.5), int(wid * 2.5), 1], dtype=np.float32)
40 | 
41 |     '''process'''
42 |     for i in range(0 + border, hgt + border, int(block_size)):
43 |         index = '%.2f' % (float(i) / float(hgt + border) * 100)
44 |         print('run model : ', index, '%')
45 |         for j in range(0 + border, wid + border, int(block_size)):
46 |             block = expand_raw[:, i - border:i + block_size + border, j - border:j + block_size + border, :]  # t-1 fusion frame + t frame, 8 channels
47 |             block = preprocess(block).float()
48 |             input = block
49 | 
50 |             with torch.no_grad():
51 |                 gamma, fusion_out, denoise_out, omega, refine_out = denoiser(input, a, b)
52 |                 fusion_out = tensor2numpy(fusion_out)
53 |                 refine_out = tensor2numpy(refine_out)
54 |                 denoise_out = tensor2numpy(denoise_out)
55 |                 gamma = tensor2numpy(F.interpolate(gamma, scale_factor=2))
56 |                 omega = tensor2numpy(F.interpolate(omega, scale_factor=2))
57 |             expand_res[:, i:i + block_size, j:j + block_size, :] = refine_out[:, border:-border, border:-border, :]
58 |             expand_fusion[:, i:i + block_size, j:j + block_size, :] = fusion_out[:, border:-border, border:-border, :]
59 |             expand_denoise[:, i:i + block_size, j:j + block_size, :] = denoise_out[:, border:-border, border:-border, :]
60 |             expand_gamma[:, i:i + block_size, j:j + block_size, :] = gamma[:, border:-border, border:-border, 0:1]
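            # Only the interior of each processed block is written back: the 32-pixel
            # margin on every side overlaps with neighbouring tiles and is discarded,
            # which hides seams at tile boundaries. gamma and omega come out at half
            # resolution, hence the x2 interpolation above before cropping.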
61 |             expand_omega[:, i:i + block_size, j:j + block_size, :] = omega[:, border:-border, border:-border, 0:1]
62 | 
63 |     refine_result = expand_res[:, border:hgt + border, border:wid + border, :]
64 |     fusion_result = expand_fusion[:, border:hgt + border, border:wid + border, :]
65 |     denoise_result = expand_denoise[:, border:hgt + border, border:wid + border, :]
66 |     gamma_result = expand_gamma[:, border:hgt + border, border:wid + border, :]
67 |     omega_result = expand_omega[:, border:hgt + border, border:wid + border, :]
68 |     print('------------- Run Model Successfully -------------')
69 | 
70 |     return refine_result, fusion_result, denoise_result, gamma_result, omega_result
71 | 
72 | 
73 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
74 | ngpu = cfg.ngpu
75 | cudnn.benchmark = True
76 | 
77 | '''network'''
78 | checkpoint = torch.load(cfg.best_model_save_root)
79 | model = structure.MainDenoise()
80 | model = model.to(device)
81 | model.load_state_dict(checkpoint['model'])
82 | 
83 | 
84 | # multi gpu test
85 | if torch.cuda.is_available() and ngpu > 1:
86 |     model = nn.DataParallel(model, device_ids=list(range(ngpu)))
87 | 
88 | model.eval()
89 | output_dir = cfg.output_root
90 | 
91 | if not os.path.exists(output_dir):
92 |     os.mkdir(output_dir)
93 | 
94 | iso_list = [1600, 3200, 6400, 12800, 25600]
95 | 
96 | isp = torch.load('isp/ISP_CNN.pth')
97 | iso_average_raw_psnr = 0
98 | iso_average_raw_ssim = 0
99 | 
100 | # for iso_ind, iso in enumerate(iso_list):
101 | for iso_ind in range(0, len(iso_list)):
102 |     iso = iso_list[iso_ind]
103 |     print('processing iso={}'.format(iso))
104 | 
105 |     if not os.path.isdir(output_dir + 'ISO{}'.format(iso)):
106 |         os.makedirs(output_dir + 'ISO{}'.format(iso))
107 | 
108 |     f = open('denoise_model_test_psnr_and_ssim_on_iso{}.txt'.format(iso), 'w')
109 | 
110 |     context = 'ISO{}'.format(iso) + '\n'
111 |     f.write(context)
112 | 
113 |     scene_avg_raw_psnr = 0
114 |     scene_avg_raw_ssim = 0
115 |     frame_list = [1, 2, 3, 4, 5, 6, 7, 6, 5, 4, 3, 2, 1, 2, 3, 4, 5, 6, 7, 6, 5, 4, 3, 2, 1, 2, 3, 4, 5, 6, 7]
116 |     a_list = [3.513262, 6.955588, 13.486051, 26.585953, 52.032536]
117 |     b_list = [11.917691, 38.117816, 130.818508, 484.539790, 1819.818657]
118 | 
119 |     for scene_id in range(7, 11+1):  # CRVD test scenes 7 to 11
120 |         context = 'scene{}'.format(scene_id) + '\n'
121 |         f.write(context)
122 | 
123 |         frame_avg_raw_psnr = 0
124 |         frame_avg_raw_ssim = 0
125 |         block_size = 512
126 |         ft0_fusion_data = np.zeros([1, 540, 960, 4 * 7])  # packed CRVD resolution: 1080x1920 -> 540x960x4
127 |         gt_fusion_data = np.zeros([1, 540, 960, 4 * 7])
128 |         for time_ind in range(0, 7):
129 |             raw_name = os.path.join(cfg.data_root[1], 'indoor_raw_noisy/indoor_raw_noisy_scene{}/scene{}/ISO{}/frame{}_noisy0.tiff'.format(scene_id, scene_id, iso, frame_list[time_ind]))
130 |             raw = cv2.imread(raw_name, -1)
131 |             input_full = np.expand_dims(pack_gbrg_raw(raw), axis=0)
132 | 
133 |             gt_raw = cv2.imread(os.path.join(cfg.data_root[1],
134 |                 'indoor_raw_gt/indoor_raw_gt_scene{}/scene{}/ISO{}/frame{}_clean_and_slightly_denoised.tiff'.format(
135 |                     scene_id, scene_id, iso, frame_list[time_ind])), -1).astype(np.float32)
136 |             fgt = np.expand_dims(pack_gbrg_raw(gt_raw), axis=0)
137 | 
138 |             if time_ind == 0:
139 |                 ft0_fusion = input_full  # 1 * 540 * 960 * 4
140 |             else:
141 |                 ft0_fusion = ft0_fusion_data[:, :, :, (time_ind-1) * 4: (time_ind) * 4]  # 1 * 540 * 960 * 4
142 | 
143 |             input_data = np.concatenate([ft0_fusion, input_full], axis=3)
144 |             coeff_a = a_list[iso_ind] / (2 ** 12 - 1 - 240)
145 |             coeff_b = b_list[iso_ind] / (2 ** 12 - 1 - 240) ** 2
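            # a_list / b_list are per-ISO noise parameters fitted in 12-bit DN units,
            # modelling signal-dependent variance as a * (x - black_level) + b.
            # Dividing a by the dynamic range and b by its square rescales the model
            # to the [0, 1]-normalized frames the network consumes.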
146 |             refine_out, fusion_out, denoise_out, gamma_out, omega_out = test_big_size_raw(input_data, block_size, model, coeff_a, coeff_b)
147 | 
148 |             ft0_fusion_data[:, :, :, time_ind * 4: (time_ind+1) * 4] = fusion_out
149 | 
150 |             test_result = depack_gbrg_raw(refine_out)
151 |             test_fusion = depack_gbrg_raw(fusion_out)
152 |             test_denoise = depack_gbrg_raw(denoise_out)
153 | 
154 |             test_gt = (gt_raw - 240) / (2 ** 12 - 1 - 240)
155 | 
156 |             test_raw_psnr = compare_psnr(test_gt, (
157 |                 np.uint16(test_result * (2 ** 12 - 1 - 240) + 240).astype(np.float32) - 240) / (
158 |                 2 ** 12 - 1 - 240), data_range=1.0)
159 |             test_raw_ssim = compute_ssim_for_packed_raw(test_gt, (
160 |                 np.uint16(test_result * (2 ** 12 - 1 - 240) + 240).astype(np.float32) - 240) / (
161 |                 2 ** 12 - 1 - 240))
162 |             test_raw_psnr_input = compare_psnr(test_gt, (raw - 240) / (2 ** 12 - 1 - 240), data_range=1.0)
163 |             print('scene {} frame{} test raw psnr : {}, test raw input psnr : {}, test raw ssim : {} '.format(scene_id, time_ind, test_raw_psnr, test_raw_psnr_input, test_raw_ssim))
164 |             context = 'raw psnr/ssim: {}/{}, input_psnr:{}'.format(test_raw_psnr, test_raw_ssim, test_raw_psnr_input) + '\n'
165 |             f.write(context)
166 |             frame_avg_raw_psnr += test_raw_psnr
167 |             frame_avg_raw_ssim += test_raw_ssim
168 | 
169 |             output = test_result * (2 ** 12 - 1 - 240) + 240
170 |             fusion = test_fusion * (2 ** 12 - 1 - 240) + 240
171 |             denoise = test_denoise * (2 ** 12 - 1 - 240) + 240
172 | 
173 | 
174 |             if cfg.vis_data:
175 |                 noisy_raw_frame = preprocess(np.expand_dims(pack_gbrg_raw(raw), axis=0))
176 |                 noisy_srgb_frame = tensor2numpy(isp(noisy_raw_frame))[0]
177 |                 gt_raw_frame = np.expand_dims(pack_gbrg_raw(test_gt * (2 ** 12 - 1 - 240) + 240), axis=0)
178 |                 gt_srgb_frame = tensor2numpy(isp(preprocess(gt_raw_frame)))[0]
179 |                 cv2.imwrite(output_dir + 'ISO{}/scene{}_frame{}_noisy_sRGB.png'.format(iso, scene_id, time_ind),
180 |                             np.uint8(noisy_srgb_frame * 255))
181 |                 cv2.imwrite(output_dir + 'ISO{}/scene{}_frame{}_gt_sRGB.png'.format(iso, scene_id, time_ind),
182 |                             np.uint8(gt_srgb_frame * 255))
183 | 
184 |             denoised_raw_frame = preprocess(np.expand_dims(pack_gbrg_raw(output), axis=0))
185 |             denoised_srgb_frame = tensor2numpy(isp(denoised_raw_frame))[0]
186 |             cv2.imwrite(output_dir + 'ISO{}/scene{}_frame{}_denoised_sRGB.png'.format(iso, scene_id, time_ind),
187 |                         np.uint8(denoised_srgb_frame * 255))
188 | 
189 |             if cfg.vis_data:
190 |                 fusion_raw_frame = preprocess(np.expand_dims(pack_gbrg_raw(fusion), axis=0))
191 |                 fusion_srgb_frame = tensor2numpy(isp(fusion_raw_frame))[0]
192 |                 cv2.imwrite(output_dir + 'ISO{}/scene{}_frame{}_fusion_sRGB.png'.format(iso, scene_id, time_ind),
193 |                             np.uint8(fusion_srgb_frame * 255))
194 | 
195 |                 denoise_midres_raw_frame = preprocess(np.expand_dims(pack_gbrg_raw(denoise), axis=0))
196 |                 denoised_mid_res_srgb_frame = tensor2numpy(isp(denoise_midres_raw_frame))[0]
197 |                 cv2.imwrite(output_dir + 'ISO{}/scene{}_frame{}_denoised_midres_sRGB.png'.format(iso, scene_id, time_ind),
198 |                             np.uint8(denoised_mid_res_srgb_frame * 255))
199 | 
200 |             cv2.imwrite(output_dir + 'ISO{}/scene{}_frame{}_gamma.png'.format(iso, scene_id, time_ind), np.uint8(gamma_out[0] * 255))
201 |             cv2.imwrite(output_dir + 'ISO{}/scene{}_frame{}_omega.png'.format(iso, scene_id, time_ind),
202 |                         np.uint8(omega_out[0] * 255))
203 |             print('gamma.max:', gamma_out.max())
204 |             print('gamma.min:', gamma_out.min())
205 | 
206 |         frame_avg_raw_psnr = frame_avg_raw_psnr / 7
207 |         frame_avg_raw_ssim = frame_avg_raw_ssim / 7
208 |         context = 'frame average raw 
psnr:{},frame average raw ssim:{}'.format(frame_avg_raw_psnr, 209 | frame_avg_raw_ssim) + '\n' 210 | f.write(context) 211 | 212 | scene_avg_raw_psnr += frame_avg_raw_psnr 213 | scene_avg_raw_ssim += frame_avg_raw_ssim 214 | 215 | scene_avg_raw_psnr = scene_avg_raw_psnr / 5 216 | scene_avg_raw_ssim = scene_avg_raw_ssim / 5 217 | context = 'scene average raw psnr:{},scene frame average raw ssim:{}'.format(scene_avg_raw_psnr, 218 | scene_avg_raw_ssim) + '\n' 219 | print(context) 220 | f.write(context) 221 | iso_average_raw_psnr += scene_avg_raw_psnr 222 | iso_average_raw_ssim += scene_avg_raw_ssim 223 | 224 | iso_average_raw_psnr = iso_average_raw_psnr / len(iso_list) 225 | iso_average_raw_ssim = iso_average_raw_ssim / len(iso_list) 226 | 227 | context = 'iso average raw psnr:{},iso frame average raw ssim:{}'.format(iso_average_raw_psnr, iso_average_raw_ssim) + '\n' 228 | f.write(context) 229 | print(context) 230 | 231 | 232 | 233 | -------------------------------------------------------------------------------- /isp/ISP_CNN.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baymax-chen/EMVD/975a2f46b20798fc981bceccc1885f63aad6d870/isp/ISP_CNN.pth -------------------------------------------------------------------------------- /load_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import config as cfg 5 | from dataset import * 6 | from torch.utils.data import Dataset 7 | import time 8 | iso_list = [1600, 3200, 6400, 12800, 25600] 9 | a_list = [3.513262, 6.955588, 13.486051, 26.585953, 52.032536] 10 | b_list = [11.917691, 38.117816, 130.818508, 484.539790, 1819.818657] 11 | 12 | def load_cvrd_data(shift, noisy_level, scene_ind, frame_ind, xx, yy): 13 | 14 | frame_list = [1, 2, 3, 4, 5, 6, 7, 6, 5, 4, 3, 2, 1, 2, 3, 4, 5, 6, 7, 6, 5, 4, 3, 2, 1, 2, 3, 4, 5, 6, 7] 15 | 16 | gt_name = os.path.join(cfg.data_root[1], 17 | 'indoor_raw_gt/indoor_raw_gt_scene{}/scene{}/ISO{}/frame{}_clean_and_slightly_denoised.tiff'.format( 18 | scene_ind, scene_ind, iso_list[noisy_level], 19 | frame_list[frame_ind + shift])) 20 | gt_raw = cv2.imread(gt_name, -1) 21 | gt_raw_full = gt_raw 22 | gt_raw_patch = gt_raw_full[yy:yy + cfg.image_height * 2, 23 | xx:xx + cfg.image_width * 2] # 256 * 256 24 | gt_raw_pack = np.expand_dims(pack_gbrg_raw(gt_raw_patch), axis=0) # 1* 128 * 128 * 4 25 | 26 | noisy_frame_index_for_current = np.random.randint(0, 10) 27 | input_name = os.path.join(cfg.data_root[1], 28 | 'indoor_raw_noisy/indoor_raw_noisy_scene{}/scene{}/ISO{}/frame{}_noisy{}.tiff'.format( 29 | scene_ind, scene_ind, iso_list[noisy_level], 30 | frame_list[frame_ind + shift], noisy_frame_index_for_current)) 31 | noisy_raw = cv2.imread(input_name, -1) 32 | noisy_raw_full = noisy_raw 33 | noisy_patch = noisy_raw_full[yy:yy + cfg.image_height * 2, xx:xx + cfg.image_width * 2] 34 | input_pack = np.expand_dims(pack_gbrg_raw(noisy_patch), axis=0) 35 | return input_pack, gt_raw_pack 36 | 37 | 38 | def load_eval_data(noisy_level, scene_ind): 39 | input_batch_list = [] 40 | gt_raw_batch_list = [] 41 | 42 | input_pack_list = [] 43 | gt_raw_pack_list = [] 44 | 45 | xx = 200 46 | yy = 200 47 | 48 | for shift in range(0, cfg.frame_num): 49 | # load gt raw 50 | frame_ind = 0 51 | input_pack, gt_raw_pack = load_cvrd_data(shift, noisy_level, scene_ind, frame_ind, xx, yy) 52 | input_pack_list.append(input_pack) 53 | gt_raw_pack_list.append(gt_raw_pack) 54 | 55 | 
input_pack_frames = np.concatenate(input_pack_list, axis=3) 56 | gt_raw_pack_frames = np.concatenate(gt_raw_pack_list, axis=3) 57 | 58 | input_batch_list.append(input_pack_frames) 59 | gt_raw_batch_list.append(gt_raw_pack_frames) 60 | 61 | input_batch = np.concatenate(input_batch_list, axis=0) 62 | gt_raw_batch = np.concatenate(gt_raw_batch_list, axis=0) 63 | 64 | in_data = torch.from_numpy(input_batch.copy()).permute(0, 3, 1, 2).cuda() # 1 * (4*25) * 128 * 128 65 | gt_raw_data = torch.from_numpy(gt_raw_batch.copy()).permute(0, 3, 1, 2).cuda() # 1 * (4*25) * 128 * 128 66 | return in_data, gt_raw_data 67 | 68 | def generate_file_list(scene_list): 69 | iso_list = [1600, 3200, 6400, 12800, 25600] 70 | file_num = 0 71 | data_name = [] 72 | for scene_ind in scene_list: 73 | for iso in iso_list: 74 | for frame_ind in range(1,8): 75 | gt_name = os.path.join('ISO{}/scene{}_frame{}_gt_sRGB.png'.format( 76 | iso, scene_ind, frame_ind-1)) 77 | data_name.append(gt_name) 78 | file_num += 1 79 | 80 | random_index = np.random.permutation(file_num) 81 | data_random_list = [] 82 | for i,idx in enumerate(random_index): 83 | data_random_list.append(data_name[idx]) 84 | return data_random_list 85 | 86 | def read_img(img_name, xx, yy): 87 | raw = cv2.imread(img_name, -1) 88 | raw_full = raw 89 | raw_patch = raw_full[yy:yy + cfg.image_height * 2, 90 | xx:xx + cfg.image_width * 2] # 256 * 256 91 | raw_pack_data = pack_gbrg_raw(raw_patch) 92 | return raw_pack_data 93 | 94 | def decode_data(data_name): 95 | frame_list = [1, 2, 3, 4, 5, 6, 7, 6, 5, 4, 3, 2, 1, 2, 3, 4, 5, 6, 7, 6, 5, 4, 3, 2, 1, 2, 3, 4, 5, 6, 7] 96 | H = 1080 97 | W = 1920 98 | xx = np.random.randint(0, (W - cfg.image_width * 2 + 1) / 2) * 2 99 | yy = np.random.randint(0, (H - cfg.image_height * 2 + 1) / 2) * 2 100 | 101 | scene_ind = data_name.split('/')[1].split('_')[0] 102 | frame_ind = int(data_name.split('/')[1].split('_')[1][5:]) 103 | iso_ind = data_name.split('/')[0] 104 | 105 | noisy_level_ind = iso_list.index(int(iso_ind[3:])) 106 | noisy_level = [a_list[noisy_level_ind], b_list[noisy_level_ind]] 107 | 108 | gt_name_list = [] 109 | noisy_name_list = [] 110 | xx_list = [] 111 | yy_list = [] 112 | for shift in range(0, cfg.frame_num): 113 | gt_name = os.path.join(cfg.data_root[1],'indoor_raw_gt/indoor_raw_gt_{}/{}/{}/frame{}_clean_and_slightly_denoised.tiff'.format( 114 | scene_ind,scene_ind,iso_ind,frame_list[frame_ind + shift])) 115 | 116 | noisy_frame_index_for_current = np.random.randint(0, 10) 117 | noisy_name = os.path.join(cfg.data_root[1], 118 | 'indoor_raw_noisy/indoor_raw_noisy_{}/{}/{}/frame{}_noisy{}.tiff'.format( 119 | scene_ind,scene_ind, iso_ind, frame_list[frame_ind + shift], noisy_frame_index_for_current)) 120 | 121 | gt_name_list.append(gt_name) 122 | noisy_name_list.append(noisy_name) 123 | 124 | xx_list.append(xx) 125 | yy_list.append(yy) 126 | 127 | gt_raw_data_list = list(map(read_img, gt_name_list, xx_list, yy_list)) 128 | noisy_data_list = list(map(read_img, noisy_name_list, xx_list, yy_list)) 129 | gt_raw_batch = np.concatenate(gt_raw_data_list, axis=2) 130 | noisy_raw_batch = np.concatenate(noisy_data_list, axis=2) 131 | 132 | return noisy_raw_batch, gt_raw_batch, noisy_level 133 | 134 | 135 | class loadImgs(Dataset): 136 | def __init__(self, filelist): 137 | self.filelist = filelist 138 | 139 | def __len__(self): 140 | return len(self.filelist) 141 | 142 | def __getitem__(self, item): 143 | self.data_name = self.filelist[item] 144 | image, label, noisy_level = decode_data(self.data_name) 145 | self.image = image 
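        # image / label: (image_height, image_width, 4 * frame_num) arrays holding
        # frame_num packed GBRG frames concatenated along the channel axis;
        # noisy_level is the per-ISO (a, b) noise-model pair for this sample.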
146 |         self.label = label
147 |         self.noisy_level = noisy_level
148 |         return self.image, self.label, self.noisy_level
149 | 
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | 
5 | 
6 | class ISP(nn.Module):
7 | 
8 |     def __init__(self):
9 |         super(ISP, self).__init__()
10 | 
11 |         self.conv1_1 = nn.Conv2d(4, 32, kernel_size=3, stride=1, padding=1)
12 |         self.conv1_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
13 |         self.pool1 = nn.MaxPool2d(kernel_size=2)
14 | 
15 |         self.conv2_1 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
16 |         self.conv2_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
17 |         self.pool2 = nn.MaxPool2d(kernel_size=2)
18 | 
19 |         self.conv3_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
20 |         self.conv3_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
21 | 
22 |         self.upv4 = nn.ConvTranspose2d(128, 64, 2, stride=2)
23 |         self.conv4_1 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
24 |         self.conv4_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
25 | 
26 |         self.upv5 = nn.ConvTranspose2d(64, 32, 2, stride=2)
27 |         self.conv5_1 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
28 |         self.conv5_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
29 | 
30 |         self.conv6_1 = nn.Conv2d(32, 12, kernel_size=1, stride=1)
31 | 
32 |     def forward(self, x):
33 |         conv1 = self.lrelu(self.conv1_1(x))
34 |         conv1 = self.lrelu(self.conv1_2(conv1))
35 |         pool1 = self.pool1(conv1)
36 | 
37 |         conv2 = self.lrelu(self.conv2_1(pool1))
38 |         conv2 = self.lrelu(self.conv2_2(conv2))
39 |         pool2 = self.pool2(conv2)
40 | 
41 |         conv3 = self.lrelu(self.conv3_1(pool2))
42 |         conv3 = self.lrelu(self.conv3_2(conv3))
43 | 
44 |         up4 = self.upv4(conv3)
45 |         up4 = torch.cat([up4, conv2], 1)
46 |         conv4 = self.lrelu(self.conv4_1(up4))
47 |         conv4 = self.lrelu(self.conv4_2(conv4))
48 | 
49 |         up5 = self.upv5(conv4)
50 |         up5 = torch.cat([up5, conv1], 1)
51 |         conv5 = self.lrelu(self.conv5_1(up5))
52 |         conv5 = self.lrelu(self.conv5_2(conv5))
53 | 
54 |         conv6 = self.conv6_1(conv5)
55 |         out = nn.functional.pixel_shuffle(conv6, 2)
56 |         return out
57 | 
58 |     def _initialize_weights(self):
59 |         for m in self.modules():
60 |             if isinstance(m, nn.Conv2d):
61 |                 m.weight.data.normal_(0.0, 0.02)
62 |                 if m.bias is not None:
63 |                     m.bias.data.normal_(0.0, 0.02)
64 |             if isinstance(m, nn.ConvTranspose2d):
65 |                 m.weight.data.normal_(0.0, 0.02)
66 | 
67 |     def lrelu(self, x):
68 |         outt = torch.max(0.2 * x, x)  # leaky ReLU with slope 0.2
69 |         return outt
70 | 
--------------------------------------------------------------------------------
/netloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | class L1Loss(nn.Module):
7 |     def __init__(self):
8 |         super(L1Loss, self).__init__()
9 | 
10 |     def forward(self, predict, label):
11 |         l1loss = torch.mean(torch.abs(predict - label))
12 |         return l1loss
13 | 
14 | class PSNR(nn.Module):
15 |     def __init__(self):
16 |         super(PSNR, self).__init__()
17 | 
18 |     def forward(self, image, label):
19 |         MSE = (image - label) * (image - label)
20 |         MSE = torch.mean(MSE)
21 |         PSNR = 10 * torch.log(1 / MSE) / torch.log(torch.Tensor([10.])).cuda()  # torch.log is natural log, so divide by log(10)
22 | 
23 |         return PSNR
24 | 
25 | 
26 | def loss_color(model, layers, device):  # Color Transform
27 |     '''
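    Orthogonality regularizer for the learned color transform: penalizes deviation of ct @ cti from the identity matrix.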
28 |     :param model:
29 |     :param layers: layer name we want to use orthogonal regularization
30 |     :param device: cpu or gpu
31 |     :return: loss
32 |     '''
33 |     loss_orth = torch.tensor(0., dtype=torch.float32, device=device)
34 |     params = {}
35 |     for name, param in model.named_parameters():
36 |         params[name] = param
37 |     ct = params['ct.net1.weight'].squeeze()
38 |     cti = params['cti.net1.weight'].squeeze()
39 |     weight_squared = torch.matmul(ct, cti)
40 |     diag = torch.eye(weight_squared.shape[0], dtype=torch.float32, device=device)
41 |     loss = ((weight_squared - diag) ** 2).sum()
42 |     loss_orth += loss
43 |     return loss_orth
44 | 
45 | def loss_wavelet(model, device):  # Frequency Transform
46 |     '''
47 |     :param model:
48 |     :param device: cpu or gpu
49 |     :return: loss
50 |     '''
51 |     loss_orth = torch.tensor(0., dtype=torch.float32, device=device)
52 |     params = {}
53 |     for name, param in model.named_parameters():
54 |         params[name] = param
55 |     ft = params['ft.net1.weight'].squeeze()
56 |     fti = torch.cat([params['fti.net1.weight'], params['fti.net2.weight']], dim=0).squeeze()
57 |     weight_squared = torch.matmul(ft, fti)
58 |     diag = torch.eye(weight_squared.shape[1], dtype=torch.float32, device=device)
59 |     loss = ((weight_squared - diag) ** 2).sum()
60 |     loss_orth += loss
61 |     return loss_orth
62 | 
--------------------------------------------------------------------------------
/structure.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import config as cfg
7 | 
8 | device = cfg.device
9 | # device = 'cpu'
10 | 
11 | cfa = np.array(  # orthogonal 4x4 color transform initialization
12 |     [[0.5, 0.5, 0.5, 0.5], [-0.5, 0.5, 0.5, -0.5], [0.65, 0.2784, -0.2784, -0.65], [-0.2784, 0.65, -0.65, 0.2784]])
13 | 
14 | cfa = np.expand_dims(cfa, axis=2)
15 | cfa = np.expand_dims(cfa, axis=3)
16 | cfa = torch.tensor(cfa).float()  # .cuda()
17 | cfa_inv = cfa.transpose(0, 1)
18 | 
19 | # dwt dec
20 | h0 = np.array([1 / math.sqrt(2), 1 / math.sqrt(2)])
21 | h1 = np.array([-1 / math.sqrt(2), 1 / math.sqrt(2)])
22 | h0 = np.array(h0[::-1]).ravel()
23 | h1 = np.array(h1[::-1]).ravel()
24 | h0 = torch.tensor(h0).float().reshape((1, 1, -1))
25 | h1 = torch.tensor(h1).float().reshape((1, 1, -1))
26 | h0_col = h0.reshape((1, 1, -1, 1))  # col lowpass
27 | h1_col = h1.reshape((1, 1, -1, 1))  # col highpass
28 | h0_row = h0.reshape((1, 1, 1, -1))  # row lowpass
29 | h1_row = h1.reshape((1, 1, 1, -1))  # row highpass
30 | ll_filt = torch.cat([h0_row, h1_row], dim=0)
31 | 
32 | # dwt rec
33 | g0 = np.array([1 / math.sqrt(2), 1 / math.sqrt(2)])
34 | g1 = np.array([1 / math.sqrt(2), -1 / math.sqrt(2)])
35 | g0 = np.array(g0).ravel()
36 | g1 = np.array(g1).ravel()
37 | g0 = torch.tensor(g0).float().reshape((1, 1, -1))
38 | g1 = torch.tensor(g1).float().reshape((1, 1, -1))
39 | g0_col = g0.reshape((1, 1, -1, 1))
40 | g1_col = g1.reshape((1, 1, -1, 1))
41 | g0_row = g0.reshape((1, 1, 1, -1))
42 | g1_row = g1.reshape((1, 1, 1, -1))
43 | 
44 | 
45 | class ColorTransfer(nn.Module):
46 |     def __init__(self):
47 |         super(ColorTransfer, self).__init__()
48 |         self.net1 = nn.Conv2d(4, 4, kernel_size=1, stride=1, padding=0, bias=None)
49 |         self.net1.weight = torch.nn.Parameter(cfa)
50 | 
51 |     def forward(self, x):
52 |         out = self.net1(x)
53 |         return out
54 | 
55 | 
56 | class ColorTransferInv(nn.Module):
57 |     def __init__(self):
58 |         super(ColorTransferInv, self).__init__()
59 |         self.net1 = nn.Conv2d(4, 4, kernel_size=1, stride=1, padding=0, bias=None)
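        # cfa_inv is the transpose of cfa; for an orthogonal matrix the transpose equals
        # the inverse, and loss_color in netloss.py keeps ct/cti near-orthogonal during
        # training so this inversion stays consistent.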
60 | self.net1.weight = torch.nn.Parameter(cfa_inv) 61 | 62 | def forward(self, x): 63 | out = self.net1(x) 64 | return out 65 | 66 | 67 | class FreTransfer(nn.Module): 68 | def __init__(self): 69 | super(FreTransfer, self).__init__() 70 | self.net1 = nn.Conv2d(1, 2, kernel_size=(1, 2), stride=(1, 2), padding=0, 71 | bias=None) # Cin = 1, Cout = 4, kernel_size = (1,2) 72 | self.net1.weight = torch.nn.Parameter(ll_filt) # torch.Size([2, 1, 1, 2]) 73 | 74 | def forward(self, x): 75 | B, C, H, W = x.shape 76 | ll = torch.ones([B, 4, int(H / 2), int(W / 2)], device=device) 77 | hl = torch.ones([B, 4, int(H / 2), int(W / 2)], device=device) 78 | lh = torch.ones([B, 4, int(H / 2), int(W / 2)], device=device) 79 | hh = torch.ones([B, 4, int(H / 2), int(W / 2)], device=device) 80 | 81 | for i in range(C): 82 | ll_ = self.net1(x[:, i:(i + 1) * 1, :, :]) # 1 * 2 * 128 * 64 83 | y = [] 84 | for j in range(2): 85 | weight = self.net1.weight.transpose(2, 3) 86 | y_out = F.conv2d(ll_[:, j:(j + 1) * 1, :, :], weight, stride=(2, 1), padding=0, bias=None) 87 | y.append(y_out) 88 | y_ = torch.cat([y[0], y[1]], dim=1) 89 | ll[:, i:(i + 1), :, :] = y_[:, 0:1, :, :] 90 | hl[:, i:(i + 1), :, :] = y_[:, 1:2, :, :] 91 | lh[:, i:(i + 1), :, :] = y_[:, 2:3, :, :] 92 | hh[:, i:(i + 1), :, :] = y_[:, 3:4, :, :] 93 | 94 | out = torch.cat([ll, hl, lh, hh], dim=1) 95 | return out 96 | 97 | 98 | class FreTransferInv(nn.Module): 99 | def __init__(self): 100 | super(FreTransferInv, self).__init__() 101 | self.net1 = nn.ConvTranspose2d(1, 1, kernel_size=(2, 1), stride=(2, 1), padding=0, bias=None) 102 | self.net1.weight = torch.nn.Parameter(g0_col) # torch.Size([1,1,2,1]) 103 | self.net2 = nn.ConvTranspose2d(1, 1, kernel_size=(2, 1), stride=(2, 1), padding=0, bias=None) 104 | self.net2.weight = torch.nn.Parameter(g1_col) # torch.Size([1,1,2,1]) 105 | 106 | def forward(self, x): 107 | lls = x[:, 0:4, :, :] 108 | hls = x[:, 4:8, :, :] 109 | lhs = x[:, 8:12, :, :] 110 | hhs = x[:, 12:16, :, :] 111 | B, C, H, W = lls.shape 112 | out = torch.ones([B, C, int(H * 2), int(W * 2)], device=device) 113 | for i in range(C): 114 | ll = lls[:, i:i + 1, :, :] 115 | hl = hls[:, i:i + 1, :, :] 116 | lh = lhs[:, i:i + 1, :, :] 117 | hh = hhs[:, i:i + 1, :, :] 118 | 119 | lo = self.net1(ll) + self.net2(hl) # 1 * 1 * 128 * 64 120 | hi = self.net1(lh) + self.net2(hh) # 1 * 1 * 128 * 64 121 | weight_l = self.net1.weight.transpose(2, 3) 122 | weight_h = self.net2.weight.transpose(2, 3) 123 | l = F.conv_transpose2d(lo, weight_l, stride=(1, 2), padding=0, bias=None) 124 | h = F.conv_transpose2d(hi, weight_h, stride=(1, 2), padding=0, bias=None) 125 | out[:, i:i + 1, :, :] = l + h 126 | return out 127 | 128 | 129 | class Fusion_down(nn.Module): 130 | def __init__(self): 131 | super(Fusion_down, self).__init__() 132 | self.net1 = nn.Conv2d(5, 16, kernel_size=3, stride=1, padding=1) 133 | self.net2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1) 134 | self.net3 = nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1) 135 | 136 | def forward(self, x): 137 | net1 = F.relu(self.net1(x)) 138 | net2 = F.relu(self.net2(net1)) 139 | out = F.sigmoid(self.net3(net2)) 140 | return out 141 | 142 | 143 | class Fusion_up(nn.Module): 144 | def __init__(self): 145 | super(Fusion_up, self).__init__() 146 | self.net1 = nn.Conv2d(6, 16, kernel_size=3, stride=1, padding=1) 147 | self.net2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1) 148 | self.net3 = nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1) 149 | 150 | def forward(self, x): 151 | net1 = 
F.relu(self.net1(x)) 152 | net2 = F.relu(self.net2(net1)) 153 | out = F.sigmoid(self.net3(net2)) 154 | return out 155 | 156 | 157 | class Denoise_down(nn.Module): 158 | 159 | def __init__(self): 160 | super(Denoise_down, self).__init__() 161 | self.net1 = nn.Conv2d(21, 16, kernel_size=3, stride=1, padding=1) 162 | self.net2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1) 163 | self.net3 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1) 164 | 165 | def forward(self, x): 166 | net1 = F.relu(self.net1(x)) 167 | net2 = F.relu(self.net2(net1)) 168 | out = self.net3(net2) 169 | return out 170 | 171 | 172 | class Denoise_up(nn.Module): 173 | 174 | def __init__(self): 175 | super(Denoise_up, self).__init__() 176 | self.net1 = nn.Conv2d(25, 16, kernel_size=3, stride=1, padding=1) 177 | self.net2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1) 178 | self.net3 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1) 179 | 180 | def forward(self, x): 181 | net1 = F.relu(self.net1(x)) 182 | net2 = F.relu(self.net2(net1)) 183 | out = self.net3(net2) 184 | return out 185 | 186 | 187 | class Refine(nn.Module): 188 | 189 | def __init__(self): 190 | super(Refine, self).__init__() 191 | self.net1 = nn.Conv2d(33, 16, kernel_size=3, stride=1, padding=1) 192 | self.net2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1) 193 | self.net3 = nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1) 194 | 195 | def forward(self, x): 196 | net1 = F.relu(self.net1(x)) 197 | net2 = F.relu(self.net2(net1)) 198 | out = F.sigmoid(self.net3(net2)) 199 | return out 200 | 201 | 202 | class VideoDenoise(nn.Module): 203 | def __init__(self): 204 | super(VideoDenoise, self).__init__() 205 | 206 | self.fusion = Fusion_down() 207 | self.denoise = Denoise_down() 208 | 209 | def forward(self, ft0, ft1, coeff_a, coeff_b): 210 | ll0 = ft0[:, 0:4, :, :] 211 | ll1 = ft1[:, 0:4, :, :] 212 | 213 | # fusion 214 | sigma_ll1 = torch.clamp(ll1[:, 0:1, :, :], 0, 1) * coeff_a + coeff_b 215 | fusion_in = torch.cat([abs(ll1 - ll0), sigma_ll1], dim=1) 216 | gamma = self.fusion(fusion_in) 217 | fusion_out = torch.mul(ft0, (1 - gamma)) + torch.mul(ft1, gamma) 218 | 219 | # denoise 220 | sigma_ll0 = torch.clamp(ll0[:, 0:1, :, :], 0, 1) * coeff_a + coeff_b 221 | sigma = (1 - gamma) * (1 - gamma) * sigma_ll0 + gamma * gamma * sigma_ll1 222 | denoise_in = torch.cat([fusion_out, ll1, sigma], dim=1) 223 | denoise_out = self.denoise(denoise_in) 224 | return gamma, denoise_out 225 | 226 | 227 | class MultiVideoDenoise(nn.Module): 228 | def __init__(self): 229 | super(MultiVideoDenoise, self).__init__() 230 | self.fusion = Fusion_up() 231 | self.denoise = Denoise_up() 232 | 233 | def forward(self, ft0, ft1, gamma_up, denoise_down, coeff_a, coeff_b): 234 | ll0 = ft0[:, 0:4, :, :] 235 | ll1 = ft1[:, 0:4, :, :] 236 | 237 | # fusion 238 | sigma_ll1 = torch.clamp(ll1[:, 0:1, :, :], 0, 1) * coeff_a + coeff_b 239 | fusion_in = torch.cat([abs(ll1 - ll0), gamma_up, sigma_ll1], dim=1) 240 | gamma = self.fusion(fusion_in) 241 | fusion_out = torch.mul(ft0, (1 - gamma)) + torch.mul(ft1, gamma) 242 | 243 | # denoise 244 | sigma_ll0 = torch.clamp(ll0[:, 0:1, :, :], 0, 1) * coeff_a + coeff_b 245 | sigma = (1 - gamma) * (1 - gamma) * sigma_ll0 + gamma * gamma * sigma_ll1 246 | denoise_in = torch.cat([fusion_out, denoise_down, ll1, sigma], dim=1) 247 | denoise_out = self.denoise(denoise_in) 248 | 249 | return gamma, fusion_out, denoise_out, sigma 250 | 251 | 252 | class MainDenoise(nn.Module): 253 | def __init__(self): 254 | super(MainDenoise, 
self).__init__() 255 | self.ct = ColorTransfer() 256 | self.cti = ColorTransferInv() 257 | self.ft = FreTransfer() 258 | self.fti = FreTransferInv() 259 | self.vd = VideoDenoise() 260 | self.md1 = MultiVideoDenoise() 261 | self.md0 = MultiVideoDenoise() 262 | self.refine = Refine() 263 | 264 | def transform(self, x): 265 | net1 = self.ct(x) 266 | out = self.ft(net1) 267 | return out 268 | 269 | def transforminv(self, x): 270 | net1 = self.fti(x) 271 | out = self.cti(net1) 272 | return out 273 | 274 | def forward(self, x, coeff_a=1, coeff_b=1): 275 | ft0 = x[:, 0:4, :, :] # 1*4*128*128, the t-1 fusion frame 276 | ft1 = x[:, 4:8, :, :] # 1*4*128*128, the t frame 277 | 278 | ft0_d0 = self.transform(ft0) # scale0, torch.Size([1, 16, 256, 256]) 279 | ft1_d0 = self.transform(ft1) 280 | 281 | ft0_d1 = self.ft(ft0_d0[:,0:4,:,:]) # scale1,torch.Size([1, 16, 128, 128]) 282 | ft1_d1 = self.ft(ft1_d0[:, 0:4, :, :]) 283 | 284 | ft0_d2 = self.ft(ft0_d1[:,0:4,:,:]) # scale2, torch.Size([1, 16, 64, 64]) 285 | ft1_d2 = self.ft(ft1_d1[:, 0:4, :, :]) 286 | 287 | 288 | gamma, denoise_out = self.vd(ft0_d2, ft1_d2, coeff_a, coeff_b) 289 | denoise_out_d2 = self.fti(denoise_out) 290 | gamma_up_d2 = F.upsample(gamma, scale_factor=2) 291 | 292 | 293 | gamma, fusion_out, denoise_out, sigma = self.md1(ft0_d1, ft1_d1, gamma_up_d2, denoise_out_d2, coeff_a, coeff_b) 294 | denoise_up_d1 = self.fti(denoise_out) 295 | gamma_up_d1 = F.upsample(gamma, scale_factor=2) 296 | 297 | gamma, fusion_out, denoise_out, sigma = self.md0(ft0_d0, ft1_d0, gamma_up_d1, denoise_up_d1, coeff_a, coeff_b) 298 | 299 | # refine 300 | refine_in = torch.cat([fusion_out, denoise_out, sigma], axis=1) # 1 * 36 * 128 * 128 301 | omega = self.refine(refine_in) # 1 * 16 * 128 * 128 302 | refine_out = torch.mul(denoise_out, (1 - omega)) + torch.mul(fusion_out, omega) 303 | 304 | fusion_out = self.transforminv(fusion_out) 305 | refine_out = self.transforminv(refine_out) 306 | denoise_out = self.transforminv(denoise_out) 307 | 308 | return gamma, fusion_out, denoise_out, omega, refine_out 309 | 310 | 311 | 312 | -------------------------------------------------------------------------------- /torchsummary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from collections import OrderedDict 5 | import numpy as np 6 | 7 | 8 | ''' 9 | Licensed... 10 | pytorch-summary is MIT-licensed. 11 | Using... 
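The call below prints per-layer output shapes and parameter counts. For this repo's MainDenoise the input_size would be (8, 128, 128): two packed 4-channel frames stacked along the channel axis.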
12 | summary(your_model, input_size=(channels, H, W)) 13 | ''' 14 | 15 | def summary(model, input_size, batch_size = -1, device = torch.device('cuda:0'), dtypes = None): 16 | result, params_info = summary_string( 17 | model, input_size, batch_size, device, dtypes 18 | ) 19 | print (result) 20 | 21 | return params_info 22 | 23 | def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None): 24 | if dtypes == None: 25 | dtypes = [torch.FloatTensor]*len(input_size) 26 | 27 | summary_str = '' 28 | 29 | def register_hook(module): 30 | def hook(module, input, output): 31 | class_name = str(module.__class__).split(".")[-1].split("'")[0] 32 | module_idx = len(summary) 33 | 34 | m_key = "%s-%i" % (class_name, module_idx + 1) 35 | summary[m_key] = OrderedDict() 36 | summary[m_key]["input_shape"] = list(input[0].size()) 37 | summary[m_key]["input_shape"][0] = batch_size 38 | if isinstance(output, (list, tuple)): 39 | summary[m_key]["output_shape"] = [ 40 | [-1] + list(o.size())[1:] for o in output 41 | ] 42 | else: 43 | summary[m_key]["output_shape"] = list(output.size()) 44 | summary[m_key]["output_shape"][0] = batch_size 45 | 46 | params = 0 47 | # compute the nums of parameters 48 | if hasattr(module, "weight") and hasattr(module.weight, "size"): 49 | params += torch.prod(torch.LongTensor(list(module.weight.size()))) 50 | summary[m_key]["trainable"] = module.weight.requires_grad 51 | if hasattr(module, "bias") and hasattr(module.bias, "size"): 52 | params += torch.prod(torch.LongTensor(list(module.bias.size()))) 53 | summary[m_key]["nb_params"] = params 54 | 55 | if ( 56 | not isinstance(module, nn.Sequential) 57 | and not isinstance(module, nn.ModuleList) 58 | ): 59 | hooks.append(module.register_forward_hook(hook)) 60 | 61 | # multiple inputs to the network 62 | if isinstance(input_size, tuple): 63 | input_size = [input_size] 64 | 65 | # batch_size of 2 for batchnorm 66 | x = [torch.rand(2, *in_size).type(dtype).to(device=device) 67 | for in_size, dtype in zip(input_size, dtypes)] 68 | 69 | # create properties 70 | summary = OrderedDict() 71 | hooks = [] 72 | 73 | # register hook 74 | model.apply(register_hook) 75 | 76 | # make a forward pass 77 | # print(x.shape) 78 | model(*x) 79 | 80 | # remove these hooks 81 | for h in hooks: 82 | h.remove() 83 | 84 | summary_str += "----------------------------------------------------------------" + "\n" 85 | line_new = "{:>20} {:>25} {:>15}".format( 86 | "Layer (type)", "Output Shape", "Param #") 87 | summary_str += line_new + "\n" 88 | summary_str += "================================================================" + "\n" 89 | total_params = 0 90 | total_output = 0 91 | trainable_params = 0 92 | for layer in summary: 93 | # input_shape, output_shape, trainable, nb_params 94 | line_new = "{:>20} {:>25} {:>15}".format( 95 | layer, 96 | str(summary[layer]["output_shape"]), 97 | "{0:,}".format(summary[layer]["nb_params"]), 98 | ) 99 | total_params += summary[layer]["nb_params"] 100 | 101 | total_output += np.prod(summary[layer]["output_shape"]) 102 | if "trainable" in summary[layer]: 103 | if summary[layer]["trainable"] == True: 104 | trainable_params += summary[layer]["nb_params"] 105 | summary_str += line_new + "\n" 106 | 107 | # assume 4 bytes/number (float on cuda). 108 | total_input_size = abs(np.prod(sum(input_size, ())) 109 | * batch_size * 4. / (1024 ** 2.)) 110 | total_output_size = abs(2. * total_output * 4. / 111 | (1024 ** 2.)) # x2 for gradients 112 | total_params_size = abs(total_params * 4. 
/ (1024 ** 2.)) 113 | total_size = total_params_size + total_output_size + total_input_size 114 | 115 | summary_str += "================================================================" + "\n" 116 | summary_str += "Total params: {0:,}".format(total_params) + "\n" 117 | summary_str += "Trainable params: {0:,}".format(trainable_params) + "\n" 118 | summary_str += "Non-trainable params: {0:,}".format(total_params - 119 | trainable_params) + "\n" 120 | summary_str += "----------------------------------------------------------------" + "\n" 121 | summary_str += "Input size (MB): %0.2f" % total_input_size + "\n" 122 | summary_str += "Forward/backward pass size (MB): %0.2f" % total_output_size + "\n" 123 | summary_str += "Params size (MB): %0.2f" % total_params_size + "\n" 124 | summary_str += "Estimated Total Size (MB): %0.2f" % total_size + "\n" 125 | summary_str += "----------------------------------------------------------------" + "\n" 126 | # return summary 127 | return summary_str, (total_params, trainable_params) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch.backends.cudnn as cudnn 2 | from tensorboardX import SummaryWriter 3 | from torchvision.utils import make_grid 4 | import shutil 5 | from PIL import ImageFile 6 | ImageFile.LOAD_TRUNCATED_IMAGES = True 7 | import os 8 | import cv2 9 | import warnings 10 | warnings.filterwarnings('ignore') 11 | from torchstat import stat 12 | 13 | import utils 14 | from dataset import * 15 | import config as cfg 16 | import structure as structure 17 | import netloss as netloss 18 | from load_data import * 19 | import time 20 | 21 | iso_list = [1600, 3200, 6400, 12800, 25600] 22 | a_list = [3.513262, 6.955588, 13.486051, 26.585953, 52.032536] 23 | b_list = [11.917691, 38.117816, 130.818508, 484.539790, 1819.818657] 24 | 25 | def initialize(): 26 | """ 27 | # clear some dir if necessary 28 | make some dir if necessary 29 | make sure training from scratch 30 | :return: 31 | """ 32 | ## 33 | if not os.path.exists(cfg.model_name): 34 | os.mkdir(cfg.model_name) 35 | 36 | if not os.path.exists(cfg.debug_dir): 37 | os.mkdir(cfg.debug_dir) 38 | 39 | if not os.path.exists(cfg.log_dir): 40 | os.mkdir(cfg.log_dir) 41 | 42 | 43 | if cfg.checkpoint == None: 44 | s = input('Are you sure training the model from scratch? 
y/n \n') 45 | if not (s=='y'): 46 | return 47 | 48 | 49 | def duplicate_output_to_log(name): 50 | tee = utils.Tee(name) 51 | return tee 52 | 53 | 54 | def train(in_data, gt_raw_data, noisy_level, model, loss, device, optimizer): 55 | l1loss_list = [] 56 | l1loss_total = 0 57 | coeff_a = (noisy_level[0] / (2 ** 12 - 1 - 240)).float().to(device) 58 | coeff_a = coeff_a[:,None,None,None] 59 | coeff_b = (noisy_level[1] / (2 ** 12 - 1 - 240) ** 2).float().to(device) 60 | coeff_b = coeff_b[:, None, None, None] 61 | for time_ind in range(cfg.frame_num): 62 | ft1 = in_data[:, time_ind * 4: (time_ind + 1) * 4, :, :] # the t-th input frame 63 | fgt = gt_raw_data[:, time_ind * 4: (time_ind + 1) * 4, :, :] # the t-th gt frame 64 | if time_ind == 0: 65 | ft0_fusion = ft1 66 | else: 67 | ft0_fusion = ft0_fusion_data # the t-1 fusion frame 68 | 69 | input = torch.cat([ft0_fusion, ft1], dim=1) 70 | 71 | model.train() 72 | gamma, fusion_out, denoise_out, omega, refine_out = model(input, coeff_a, coeff_b) 73 | loss_refine = loss(refine_out, fgt) 74 | loss_fusion = loss(fusion_out, fgt) 75 | loss_denoise = loss(denoise_out, fgt) 76 | 77 | l1loss = loss_refine 78 | 79 | l1loss_list.append(l1loss) 80 | l1loss_total += l1loss 81 | 82 | ft0_fusion_data = fusion_out 83 | 84 | loss_ct = netloss.loss_color(model, ['ct.net1.weight', 'cti.net1.weight'], device) 85 | loss_ft = netloss.loss_wavelet(model, device) 86 | total_loss = l1loss_total / (cfg.frame_num) + loss_ct + loss_ft 87 | 88 | 89 | optimizer.zero_grad() 90 | total_loss.backward() 91 | torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=5, norm_type=2) 92 | optimizer.step() 93 | 94 | print('Loss | ', ('%.8f' % total_loss.item())) 95 | del in_data, gt_raw_data 96 | return ft1, fgt, refine_out, fusion_out, denoise_out, gamma, omega, total_loss, loss_ct, loss_ft, loss_fusion, loss_denoise 97 | 98 | def evaluate(model, psnr, writer, iter): 99 | print('Evaluate...') 100 | cnt = 0 101 | total_psnr = 0 102 | total_psnr_raw = 0 103 | model.eval() 104 | with torch.no_grad(): 105 | for scene_ind in range(7,9): 106 | for noisy_level in range(0,5): 107 | in_data, gt_raw_data = load_eval_data(noisy_level, scene_ind) 108 | frame_psnr = 0 109 | frame_psnr_raw = 0 110 | for time_ind in range(cfg.frame_num): 111 | ft1 = in_data[:, time_ind * 4: (time_ind + 1) * 4, :, :] 112 | fgt = gt_raw_data[:, time_ind * 4: (time_ind + 1) * 4, :, :] 113 | if time_ind == 0: 114 | ft0_fusion = ft1 115 | else: 116 | ft0_fusion = ft0_fusion_data 117 | 118 | coeff_a = a_list[noisy_level] / (2 ** 12 - 1 - 240) 119 | coeff_b = b_list[noisy_level] / (2 ** 12 - 1 - 240) ** 2 120 | input = torch.cat([ft0_fusion, ft1], dim=1) 121 | 122 | gamma, fusion_out, denoise_out, omega, refine_out = model(input, coeff_a, coeff_b) 123 | 124 | ft0_fusion_data = fusion_out 125 | 126 | frame_psnr += psnr(refine_out, fgt) 127 | frame_psnr_raw += psnr(ft1, fgt) 128 | 129 | frame_psnr = frame_psnr / (cfg.frame_num) 130 | frame_psnr_raw = frame_psnr_raw / (cfg.frame_num) 131 | print('---------') 132 | print('Scene: ', ('%02d' % scene_ind), 'Noisy_level: ', ('%02d' % noisy_level), 'PSNR: ', '%.8f' % frame_psnr.item()) 133 | total_psnr += frame_psnr 134 | total_psnr_raw += frame_psnr_raw 135 | cnt += 1 136 | del in_data, gt_raw_data 137 | total_psnr = total_psnr / cnt 138 | total_psnr_raw = total_psnr_raw / cnt 139 | print('Eval_Total_PSNR | ', ('%.8f' % total_psnr.item())) 140 | writer.add_scalar('PSNR', total_psnr.item(), iter) 141 | writer.add_scalar('PSNR_RAW', total_psnr_raw.item(), iter) 142 | 
writer.add_scalar('PSNR_IMP', total_psnr.item() - total_psnr_raw.item(), iter) 143 | torch.cuda.empty_cache() 144 | return total_psnr, total_psnr_raw 145 | 146 | def main(): 147 | """ 148 | Train, Valid, Write Log, Write Predict ,etc 149 | :return: 150 | """ 151 | checkpoint = cfg.checkpoint 152 | start_epoch = cfg.start_epoch 153 | start_iter = cfg.start_iter 154 | best_psnr = 0 155 | 156 | ## use gpu 157 | device = cfg.device 158 | ngpu = cfg.ngpu 159 | cudnn.benchmark = True 160 | 161 | ## tensorboard --logdir runs 162 | writer = SummaryWriter(cfg.log_dir) 163 | 164 | ## initialize model 165 | model = structure.MainDenoise() 166 | 167 | ## compute GFLOPs 168 | # stat(model, (8,512,512)) 169 | 170 | model = model.to(device) 171 | loss = netloss.L1Loss().to(device) 172 | psnr = netloss.PSNR().to(device) 173 | 174 | learning_rate = cfg.learning_rate 175 | optimizer = torch.optim.Adam(params = filter(lambda p: p.requires_grad, model.parameters()), lr = learning_rate) 176 | 177 | ## load pretrained model 178 | if checkpoint is not None: 179 | print('--- Loading Pretrained Model ---') 180 | checkpoint = torch.load(checkpoint) 181 | start_epoch = checkpoint['epoch'] 182 | start_iter = checkpoint['iter'] 183 | model.load_state_dict(checkpoint['model']) 184 | optimizer.load_state_dict(checkpoint['optimizer']) 185 | iter = start_iter 186 | 187 | if torch.cuda.is_available() and ngpu > 1: 188 | model = nn.DataParallel(model, device_ids=list(range(ngpu))) 189 | 190 | shutil.copy('structure.py', os.path.join(cfg.model_name)) 191 | shutil.copy('train.py', os.path.join(cfg.model_name)) 192 | shutil.copy('netloss.py', os.path.join(cfg.model_name)) 193 | 194 | train_data_name_queue = generate_file_list(['1', '2', '3', '4', '5', '6']) 195 | train_dataset = loadImgs(train_data_name_queue) 196 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = cfg.batch_size, num_workers = cfg.num_workers, shuffle = True, pin_memory = True) 197 | 198 | eval_data_name_queue = generate_file_list(['7', '8']) 199 | eval_dataset = loadImgs(eval_data_name_queue) 200 | eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size = cfg.batch_size, num_workers = cfg.num_workers, shuffle = True, pin_memory = True) 201 | 202 | for epoch in range(start_epoch, cfg.epoch): 203 | print('------------------------------------------------') 204 | print('Epoch | ', ('%08d' % epoch)) 205 | for i, (input, label, noisy_level) in enumerate(train_loader): 206 | print('------------------------------------------------') 207 | print('Iter | ', ('%08d' % iter)) 208 | in_data = input.permute(0, 3, 1, 2).to(device) 209 | gt_raw_data = label.permute(0, 3, 1, 2).to(device) 210 | 211 | ft1, fgt, refine_out, fusion_out, denoise_out, gamma, omega, \ 212 | total_loss, loss_ct, loss_ft, loss_fusion, loss_denoise = train(in_data, gt_raw_data, noisy_level, model, loss, device, optimizer) 213 | iter = iter + 1 214 | if iter % cfg.log_step == 0: 215 | input_gray = torch.mean(ft1, 1, True) 216 | label_gray = torch.mean(fgt, 1, True) 217 | predict_gray = torch.mean(refine_out, 1, True) 218 | fusion_gray = torch.mean(fusion_out, 1, True) 219 | denoise_gray = torch.mean(denoise_out, 1, True) 220 | gamma_gray = torch.mean(gamma[:, 0:1, :, :], 1, True) 221 | omega_gray = torch.mean(omega[:, 0:1, :, :], 1, True) 222 | 223 | writer.add_image('input', make_grid(input_gray.cpu(), nrow=4, normalize=True), iter) 224 | writer.add_image('fusion_out', make_grid(fusion_gray.cpu(), nrow=4, normalize=True), iter) 225 | writer.add_image('denoise_out', 
make_grid(denoise_gray.cpu(), nrow=4, normalize=True), iter) 226 | writer.add_image('refine_out', make_grid(predict_gray.cpu(), nrow=4, normalize=True), iter) 227 | writer.add_image('label', make_grid(label_gray.cpu(), nrow=4, normalize=True), iter) 228 | 229 | writer.add_image('gamma', make_grid(gamma_gray.cpu(), nrow=4, normalize=True), iter) 230 | writer.add_image('omega', make_grid(omega_gray.cpu(), nrow=4, normalize=True), iter) 231 | 232 | writer.add_scalar('L1Loss', total_loss.item(), iter) 233 | writer.add_scalar('L1Color', loss_ct.item(), iter) 234 | writer.add_scalar('L1Wavelet', loss_ft.item(), iter) 235 | writer.add_scalar('L1Denoise', loss_denoise.item(), iter) 236 | writer.add_scalar('L1Fusion', loss_fusion.item(), iter) 237 | 238 | torch.save({ 239 | 'epoch': epoch, 240 | 'iter': iter, 241 | 'model': model.state_dict(), 242 | 'optimizer': optimizer.state_dict()}, 243 | cfg.model_save_root) 244 | 245 | if iter % cfg.valid_step == 0 and iter > cfg.valid_start_iter: 246 | eval_psnr, eval_psnr_raw = evaluate(model, psnr, writer, iter) 247 | 248 | if eval_psnr>best_psnr: 249 | best_psnr = eval_psnr 250 | torch.save({ 251 | 'epoch': epoch, 252 | 'iter': iter, 253 | 'model': model.state_dict(), 254 | 'optimizer': optimizer.state_dict(), 255 | 'best_psnr': best_psnr}, 256 | os.path.join(cfg.model_name, 'model_best.pth')) 257 | writer.close() 258 | 259 | 260 | if __name__ == '__main__': 261 | initialize() 262 | main() 263 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | class Tee(object): 4 | def __init__(self, name): 5 | 6 | self.file = open(name, 'w') 7 | self.stdout = sys.stdout 8 | self.stderr = sys.stderr 9 | sys.stdout = self 10 | sys.stderr = self 11 | 12 | def __del__(self): 13 | self.file.close() 14 | 15 | def write(self,data): 16 | self.file.write(data) 17 | self.stdout.write(data) 18 | self.file.flush() 19 | self.stdout.flush() 20 | 21 | def write_to_file(self,data): 22 | self.file.write(data) 23 | 24 | class AverageMeter(object): 25 | """Computes and stores the average and current value""" 26 | def __init__(self): 27 | self.reset() 28 | 29 | def reset(self): 30 | self.val = 0 31 | self.avg = 0 32 | self.sum = 0 33 | self.count = 0 34 | 35 | def update(self, val, n=1): 36 | self.val = val 37 | self.sum += val * n 38 | self.count += n 39 | self.avg = self.sum / self.count 40 | 41 | 42 | def print_params(name): 43 | f = open(name, 'r', encoding='utf-8') 44 | for line in f.readlines(): 45 | if line == '\n': 46 | continue 47 | print(line) 48 | f.close() --------------------------------------------------------------------------------
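For reference, `Tee` and `AverageMeter` are used roughly like this (a minimal sketch; `train.log` is an arbitrary example path):

```python
import utils

tee = utils.Tee('train.log')   # from here on, stdout/stderr also go to train.log
print('training started')      # appears on screen and in the log file

meter = utils.AverageMeter()
for loss in [0.9, 0.7, 0.4]:
    meter.update(loss)         # keeps a running sum, count, and average
print('average loss:', meter.avg)
```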