├── LICENSE ├── assets └── model.png ├── data └── prepare_data.md ├── datasets ├── __init__.py ├── lu.py ├── middlebury.py ├── noisy_middlebury.py └── nyu.py ├── main.py ├── models ├── __init__.py ├── djf.py ├── dkn.py ├── edsr.py └── jiif.py ├── readme.md ├── scripts ├── test_denoise_jiif.sh ├── test_denoise_jiif_pretrained.sh ├── test_jiif.sh ├── test_jiif_pretrained.sh ├── train_denoise_jiif.sh └── train_jiif.sh └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 hawkey 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /assets/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/jiif/acfa3f99b20d087ee8215dd13a71a8ffea5067c0/assets/model.png -------------------------------------------------------------------------------- /data/prepare_data.md: -------------------------------------------------------------------------------- 1 | # data preparation 2 | 3 | ### NYU 4 | We use a [preprocessed version](https://drive.google.com/drive/folders/1_1HpmoCsshNCMQdXhSNOq8Y-deIDcbKS?usp=sharing) provided [here](https://github.com/charlesCXK/RGBD_Semantic_Segmentation_PyTorch#data-preparation). Just download the file and extract it to `data/nyu_labeled` to use. 5 | 6 | The official NYU Depth V2 data can be downloaded [here](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html). If you prefer to use the official data, you need to extract the depth and RGB images from the mat file and save them to `data/nyu_labeled/Depth/*.npy` and `data/nyu_labeled/RGB/*.jpg` respectively. 7 | 8 | 9 | 10 | ### MiddleBury & Lu 11 | 12 | For these two datasets, we follow [Su et al. (Depth Enhancement via Low-rank Matrix Completion)](http://web.cecs.pdx.edu/~fliu/project/depth-enhance/) and use the data provided [here](http://web.cecs.pdx.edu/~fliu/project/depth-enhance/Depth_Enh.zip). Download it and extract it to `data/depth_enhance` to use. 13 | 14 | ### 15 | 16 | ### Noisy MiddleBury 17 | 18 | For the three images (`Art, Books, Moebius`) used in the noisy super-resolution experiment, we download the RGB images from the official [middlebury 2005 datasets site](https://vision.middlebury.edu/stereo/data/scenes2005/). For the GT depth, we follow [Park et al. 
(High Quality Depth Map Upsampling for 3D-TOF Cameras)](http://jaesik.info/publications/depthups/index.html) and use the data provided [here](http://jaesik.info/publications/depthups/iccv11_dataset.zip). Download the RGBs (view1) and the GT depths, then put them under `data/noisy_depth/middlebury/rgb/` and `data/noisy_depth/middlebury/gt/` respectively. Also, the file names should be modified to match each pair of RGB and GT depth. 19 | 20 | We also provide a copy of the processed data [here](https://drive.google.com/file/d/1Bz0NcFdRzjN2CnWZlJzNSBFOBsnRZNKE/view?usp=sharing). Download it and extract to `data/noisy_depth/middlebury` to use. 21 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .lu import LuDataset 2 | from .middlebury import MiddleburyDataset 3 | from .nyu import NYUDataset 4 | from .noisy_middlebury import NoisyMiddleburyDataset -------------------------------------------------------------------------------- /datasets/lu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.utils.data import Dataset, DataLoader 5 | 6 | import numpy as np 7 | import cv2 8 | import matplotlib.pyplot as plt 9 | 10 | import os 11 | import glob 12 | import random 13 | from PIL import Image 14 | import tqdm 15 | 16 | from utils import make_coord, to_pixel_samples, visualize_2d 17 | 18 | class LuDataset(Dataset): 19 | def __init__(self, root, split='test', scale=8, augment=True, downsample='bicubic', pre_upsample=False, to_pixel=False, sample_q=None, input_size=None): 20 | super().__init__() 21 | self.root = root 22 | self.split = split 23 | self.scale = scale 24 | self.augment = augment 25 | self.downsample = downsample 26 | self.pre_upsample = pre_upsample 27 | self.to_pixel = to_pixel 28 | self.sample_q = sample_q 29 | self.input_size = input_size 30 | 31 | if self.split == 'train': 32 | raise AttributeError('Lu dataset only support test mode.') 33 | else: 34 | self.image_files = sorted(glob.glob(os.path.join(root, '*ouput_color*'))) # the name escape a `t`... 
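# note: the color images shipped with the Lu dataset are literally named '*ouput_color*'
# (the 't' is missing in the files themselves), so the glob above intentionally matches that spelling.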
35 | self.depth_files = sorted(glob.glob(os.path.join(root, '*output_depth*'))) 36 | assert len(self.image_files) == len(self.depth_files) 37 | self.size = len(self.image_files) 38 | 39 | def __getitem__(self, idx): 40 | 41 | image_file = self.image_files[idx] 42 | depth_file = self.depth_files[idx] 43 | 44 | image = cv2.imread(image_file).astype(np.uint8) # [H, W, 3] 45 | 46 | depth_hr = cv2.imread(depth_file)[:,:,0].astype(np.float32) # [H, W] 47 | depth_min = depth_hr.min() 48 | depth_max = depth_hr.max() 49 | depth_hr = (depth_hr - depth_min) / (depth_max - depth_min) 50 | 51 | # crop after rescale 52 | if self.input_size is not None: 53 | x0 = random.randint(0, image.shape[0] - self.input_size) 54 | y0 = random.randint(0, image.shape[1] - self.input_size) 55 | image = image[x0:x0+self.input_size, y0:y0+self.input_size] 56 | depth_hr = depth_hr[x0:x0+self.input_size, y0:y0+self.input_size] 57 | 58 | h, w = image.shape[:2] 59 | 60 | if self.downsample == 'bicubic': 61 | depth_lr = np.array(Image.fromarray(depth_hr).resize((w//self.scale, h//self.scale), Image.BICUBIC)) 62 | image_lr = np.array(Image.fromarray(image).resize((w//self.scale, h//self.scale), Image.BICUBIC)) 63 | elif self.downsample == 'nearest-right-bottom': 64 | depth_lr = depth_hr[(self.scale - 1)::self.scale, (self.scale - 1)::self.scale] 65 | image_lr = image[(self.scale - 1)::self.scale, (self.scale - 1)::self.scale] 66 | elif self.downsample == 'nearest-center': 67 | depth_lr = np.array(Image.fromarray(depth_hr).resize((w//self.scale, h//self.scale), Image.NEAREST)) 68 | image_lr = np.array(Image.fromarray(image).resize((w//self.scale, h//self.scale), Image.NEAREST)) 69 | elif self.downsample == 'nearest-left-top': 70 | depth_lr = depth_hr[::self.scale, ::self.scale] 71 | image_lr = image[::self.scale, ::self.scale] 72 | else: 73 | raise NotImplementedError 74 | 75 | image = image.astype(np.float32).transpose(2,0,1) / 255 76 | image_lr = image_lr.astype(np.float32).transpose(2,0,1) / 255 # [3, H, W] 77 | 78 | image = (image - np.array([0.485, 0.456, 0.406]).reshape(3,1,1)) / np.array([0.229, 0.224, 0.225]).reshape(3,1,1) 79 | image_lr = (image_lr - np.array([0.485, 0.456, 0.406]).reshape(3,1,1)) / np.array([0.229, 0.224, 0.225]).reshape(3,1,1) 80 | 81 | # follow DKN, use bicubic upsampling of PIL 82 | depth_lr_up = np.array(Image.fromarray(depth_lr).resize((w, h), Image.BICUBIC)) 83 | 84 | if self.pre_upsample: 85 | depth_lr = depth_lr_up 86 | 87 | # to tensor 88 | image = torch.from_numpy(image).float() 89 | image_lr = torch.from_numpy(image_lr).float() 90 | depth_hr = torch.from_numpy(depth_hr).unsqueeze(0).float() 91 | depth_lr = torch.from_numpy(depth_lr).unsqueeze(0).float() 92 | depth_lr_up = torch.from_numpy(depth_lr_up).unsqueeze(0).float() 93 | 94 | # transform 95 | if self.augment: 96 | hflip = random.random() < 0.5 97 | vflip = random.random() < 0.5 98 | 99 | def augment(x): 100 | if hflip: 101 | x = x.flip(-2) 102 | if vflip: 103 | x = x.flip(-1) 104 | return x 105 | 106 | image = augment(image) 107 | image_lr = augment(image_lr) 108 | depth_hr = augment(depth_hr) 109 | depth_lr = augment(depth_lr) 110 | depth_lr_up = augment(depth_lr_up) 111 | 112 | image = image.contiguous() 113 | image_lr = image_lr.contiguous() 114 | depth_hr = depth_hr.contiguous() 115 | depth_lr = depth_lr.contiguous() 116 | depth_lr_up = depth_lr_up.contiguous() 117 | 118 | # to pixel 119 | if self.to_pixel: 120 | 121 | hr_coord, hr_pixel = to_pixel_samples(depth_hr) 122 | 123 | lr_pixel = depth_lr_up.view(-1, 1) 124 | 125 | if 
self.sample_q is not None: 126 | sample_lst = np.random.choice(len(hr_coord), self.sample_q, replace=False) 127 | hr_coord = hr_coord[sample_lst] 128 | hr_pixel = hr_pixel[sample_lst] 129 | lr_pixel = lr_pixel[sample_lst] 130 | 131 | cell = torch.ones_like(hr_coord) 132 | cell[:, 0] *= 2 / depth_hr.shape[-2] 133 | cell[:, 1] *= 2 / depth_hr.shape[-1] 134 | 135 | return { 136 | 'image': image, 137 | 'lr_image': image_lr, 138 | 'lr': depth_lr, 139 | 'hr': hr_pixel, 140 | 'hr_depth': depth_hr, 141 | 'lr_pixel': lr_pixel, 142 | 'hr_coord': hr_coord, 143 | 'min': depth_min, 144 | 'max': depth_max, 145 | 'cell': cell, 146 | 'idx': idx, 147 | } 148 | 149 | 150 | else: 151 | return { 152 | 'image': image, 153 | 'lr': depth_lr, 154 | 'hr': depth_hr, 155 | 'min': depth_min, 156 | 'max': depth_max, 157 | 'idx': idx, 158 | } 159 | 160 | def __len__(self): 161 | return self.size 162 | 163 | 164 | if __name__ == '__main__': 165 | print('===== test direct bicubic upsampling =====') 166 | for method in ['bicubic']: 167 | for scale in [8]: 168 | print(f'[INFO] scale = {scale}, method = {method}') 169 | d = LuDataset(root='data/depth_enhance/03_RGBD_Dataset', split='test', pre_upsample=True, augment=False, scale=scale, downsample=method) 170 | rmses = [] 171 | for i in tqdm.trange(len(d)): 172 | x = d[i] 173 | lr = ((x['lr'].numpy() * (x['max'] - x['min'])) + x['min']) 174 | hr = ((x['hr'].numpy() * (x['max'] - x['min'])) + x['min']) 175 | rmse = np.sqrt(np.mean(np.power(lr - hr, 2))) 176 | rmses.append(rmse) 177 | print('RMSE = ', np.mean(rmses)) -------------------------------------------------------------------------------- /datasets/middlebury.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.utils.data import Dataset, DataLoader 5 | 6 | import numpy as np 7 | import cv2 8 | import matplotlib.pyplot as plt 9 | 10 | import os 11 | import glob 12 | import random 13 | from PIL import Image 14 | import tqdm 15 | 16 | from utils import make_coord, to_pixel_samples, visualize_2d, add_noise 17 | 18 | class MiddleburyDataset(Dataset): 19 | def __init__(self, root, split='test', scale=8, augment=True, downsample='bicubic', pre_upsample=False, to_pixel=False, sample_q=None, input_size=None, noisy=False): 20 | super().__init__() 21 | self.root = root 22 | self.split = split 23 | self.scale = scale 24 | self.augment = augment 25 | self.downsample = downsample 26 | self.pre_upsample = pre_upsample 27 | self.to_pixel = to_pixel 28 | self.sample_q = sample_q 29 | self.input_size = input_size 30 | self.noisy = noisy 31 | 32 | if self.split == 'train': 33 | raise AttributeError('Middlebury dataset only support test mode.') 34 | else: 35 | self.image_files = sorted(glob.glob(os.path.join(root, '*output_color*'))) 36 | self.depth_files = sorted(glob.glob(os.path.join(root, '*output_depth*'))) 37 | assert len(self.image_files) == len(self.depth_files) 38 | self.size = len(self.image_files) 39 | 40 | def __getitem__(self, idx): 41 | 42 | image_file = self.image_files[idx] 43 | depth_file = self.depth_files[idx] 44 | 45 | image = cv2.imread(image_file).astype(np.uint8) # [H, W, 3] 46 | 47 | depth_hr = cv2.imread(depth_file)[:,:,0].astype(np.float32) # [H, W] 48 | depth_min = depth_hr.min() 49 | depth_max = depth_hr.max() 50 | depth_hr = (depth_hr - depth_min) / (depth_max - depth_min) 51 | 52 | # crop to make divisible 53 | h, w = image.shape[:2] 54 | h = h - h % self.scale 55 | w = w - w % 
self.scale 56 | image = image[:h, :w] 57 | depth_hr = depth_hr[:h, :w] 58 | 59 | 60 | # crop after rescale 61 | if self.input_size is not None: 62 | x0 = random.randint(0, image.shape[0] - self.input_size) 63 | y0 = random.randint(0, image.shape[1] - self.input_size) 64 | image = image[x0:x0+self.input_size, y0:y0+self.input_size] 65 | depth_hr = depth_hr[x0:x0+self.input_size, y0:y0+self.input_size] 66 | 67 | h, w = image.shape[:2] 68 | 69 | if self.downsample == 'bicubic': 70 | depth_lr = np.array(Image.fromarray(depth_hr).resize((w//self.scale, h//self.scale), Image.BICUBIC)) 71 | image_lr = np.array(Image.fromarray(image).resize((w//self.scale, h//self.scale), Image.BICUBIC)) 72 | elif self.downsample == 'nearest-right-bottom': 73 | depth_lr = depth_hr[(self.scale - 1)::self.scale, (self.scale - 1)::self.scale] 74 | image_lr = image[(self.scale - 1)::self.scale, (self.scale - 1)::self.scale] 75 | elif self.downsample == 'nearest-center': 76 | depth_lr = np.array(Image.fromarray(depth_hr).resize((w//self.scale, h//self.scale), Image.NEAREST)) 77 | image_lr = np.array(Image.fromarray(image).resize((w//self.scale, h//self.scale), Image.NEAREST)) 78 | elif self.downsample == 'nearest-left-top': 79 | depth_lr = depth_hr[::self.scale, ::self.scale] 80 | image_lr = image[::self.scale, ::self.scale] 81 | else: 82 | raise NotImplementedError 83 | 84 | if self.noisy: 85 | print(depth_lr.min(), depth_lr.max()) 86 | depth_lr = add_noise(depth_lr, sigma=0.01) 87 | 88 | image = image.astype(np.float32).transpose(2,0,1) / 255 89 | image_lr = image_lr.astype(np.float32).transpose(2,0,1) / 255 # [3, H, W] 90 | 91 | image = (image - np.array([0.485, 0.456, 0.406]).reshape(3,1,1)) / np.array([0.229, 0.224, 0.225]).reshape(3,1,1) 92 | image_lr = (image_lr - np.array([0.485, 0.456, 0.406]).reshape(3,1,1)) / np.array([0.229, 0.224, 0.225]).reshape(3,1,1) 93 | 94 | # follow DKN, use bicubic upsampling of PIL 95 | depth_lr_up = np.array(Image.fromarray(depth_lr).resize((w, h), Image.BICUBIC)) 96 | 97 | if self.pre_upsample: 98 | depth_lr = depth_lr_up 99 | 100 | # to tensor 101 | image = torch.from_numpy(image).float() 102 | image_lr = torch.from_numpy(image_lr).float() 103 | depth_hr = torch.from_numpy(depth_hr).unsqueeze(0).float() 104 | depth_lr = torch.from_numpy(depth_lr).unsqueeze(0).float() 105 | depth_lr_up = torch.from_numpy(depth_lr_up).unsqueeze(0).float() 106 | 107 | # transform 108 | if self.augment: 109 | hflip = random.random() < 0.5 110 | vflip = random.random() < 0.5 111 | 112 | def augment(x): 113 | if hflip: 114 | x = x.flip(-2) 115 | if vflip: 116 | x = x.flip(-1) 117 | return x 118 | 119 | image = augment(image) 120 | image_lr = augment(image_lr) 121 | depth_hr = augment(depth_hr) 122 | depth_lr = augment(depth_lr) 123 | depth_lr_up = augment(depth_lr_up) 124 | 125 | image = image.contiguous() 126 | image_lr = image_lr.contiguous() 127 | depth_hr = depth_hr.contiguous() 128 | depth_lr = depth_lr.contiguous() 129 | depth_lr_up = depth_lr_up.contiguous() 130 | 131 | # to pixel 132 | if self.to_pixel: 133 | 134 | hr_coord, hr_pixel = to_pixel_samples(depth_hr) 135 | 136 | lr_pixel = depth_lr_up.view(-1, 1) 137 | 138 | if self.sample_q is not None: 139 | sample_lst = np.random.choice(len(hr_coord), self.sample_q, replace=False) 140 | hr_coord = hr_coord[sample_lst] 141 | hr_pixel = hr_pixel[sample_lst] 142 | lr_pixel = lr_pixel[sample_lst] 143 | 144 | cell = torch.ones_like(hr_coord) 145 | cell[:, 0] *= 2 / depth_hr.shape[-2] 146 | cell[:, 1] *= 2 / depth_hr.shape[-1] 147 | 148 | return { 149 
| 'image': image, 150 | 'lr_image': image_lr, 151 | 'lr': depth_lr, 152 | 'hr': hr_pixel, 153 | 'hr_depth': depth_hr, 154 | 'lr_pixel': lr_pixel, 155 | 'hr_coord': hr_coord, 156 | 'min': depth_min, 157 | 'max': depth_max, 158 | 'cell': cell, 159 | 'idx': idx, 160 | } 161 | 162 | else: 163 | return { 164 | 'image': image, 165 | 'lr': depth_lr, 166 | 'hr': depth_hr, 167 | 'min': depth_min, 168 | 'max': depth_max, 169 | 'idx': idx, 170 | } 171 | 172 | 173 | def __len__(self): 174 | return self.size 175 | 176 | 177 | if __name__ == '__main__': 178 | print('===== test direct bicubic upsampling =====') 179 | for method in ['bicubic']: 180 | for scale in [8]: 181 | print(f'[INFO] scale = {scale}, method = {method}') 182 | d = MiddleburyDataset(root='./data/depth_enhance/01_Middlebury_Dataset', split='test', pre_upsample=True, augment=False, scale=scale, downsample=method, noisy=False) 183 | rmses = [] 184 | for i in tqdm.trange(len(d)): 185 | x = d[i] 186 | lr = ((x['lr'].numpy() * (x['max'] - x['min'])) + x['min']) 187 | hr = ((x['hr'].numpy() * (x['max'] - x['min'])) + x['min']) 188 | rmse = np.sqrt(np.mean(np.power(lr - hr, 2))) 189 | rmses.append(rmse) 190 | print('RMSE = ', np.mean(rmses)) -------------------------------------------------------------------------------- /datasets/noisy_middlebury.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.utils.data import Dataset, DataLoader 5 | 6 | import numpy as np 7 | import cv2 8 | import matplotlib.pyplot as plt 9 | 10 | import os 11 | import glob 12 | import random 13 | from PIL import Image 14 | import tqdm 15 | 16 | from utils import make_coord, to_pixel_samples, visualize_2d, add_noise, seed_everything 17 | 18 | class NoisyMiddleburyDataset(Dataset): 19 | def __init__(self, root, split='test', scale=8, augment=True, downsample='bicubic', pre_upsample=False, to_pixel=False, sample_q=None, input_size=None, noisy=True): 20 | super().__init__() 21 | self.root = root 22 | self.split = split 23 | self.scale = scale 24 | self.augment = augment 25 | self.downsample = downsample 26 | self.pre_upsample = pre_upsample 27 | self.to_pixel = to_pixel 28 | self.sample_q = sample_q 29 | self.input_size = input_size 30 | self.noisy = noisy 31 | 32 | if self.split == 'train': 33 | raise AttributeError('Middlebury dataset only support test mode.') 34 | else: 35 | self.image_files = sorted(glob.glob(os.path.join(root, 'rgb/*.png'))) 36 | self.depth_files = sorted(glob.glob(os.path.join(root, 'gt/*.png'))) 37 | assert len(self.image_files) == len(self.depth_files) 38 | self.size = len(self.image_files) 39 | 40 | def __getitem__(self, idx): 41 | 42 | image_file = self.image_files[idx] 43 | depth_file = self.depth_files[idx] 44 | 45 | image = cv2.imread(image_file) # [H, W, 3] 46 | 47 | depth_hr = cv2.imread(depth_file)[:,:,0].astype(np.float32) # [H, W] 48 | 49 | # crop to make divisible 50 | image = image[11:-11, 7:-7] 51 | h, w = image.shape[:2] 52 | 53 | # crop after rescale 54 | if self.input_size is not None: 55 | x0 = random.randint(0, image.shape[0] - self.input_size) 56 | y0 = random.randint(0, image.shape[1] - self.input_size) 57 | image = image[x0:x0+self.input_size, y0:y0+self.input_size] 58 | depth_hr = depth_hr[x0:x0+self.input_size, y0:y0+self.input_size] 59 | 60 | h, w = image.shape[:2] 61 | 62 | if self.downsample == 'bicubic': 63 | depth_lr = np.array(Image.fromarray(depth_hr).resize((w//self.scale, h//self.scale), 
Image.BICUBIC)) 64 | image_lr = np.array(Image.fromarray(image).resize((w//self.scale, h//self.scale), Image.BICUBIC)) 65 | elif self.downsample == 'nearest-right-bottom': 66 | depth_lr = depth_hr[(self.scale - 1)::self.scale, (self.scale - 1)::self.scale] 67 | image_lr = image[(self.scale - 1)::self.scale, (self.scale - 1)::self.scale] 68 | elif self.downsample == 'nearest-center': 69 | depth_lr = np.array(Image.fromarray(depth_hr).resize((w//self.scale, h//self.scale), Image.NEAREST)) 70 | image_lr = np.array(Image.fromarray(image).resize((w//self.scale, h//self.scale), Image.NEAREST)) 71 | elif self.downsample == 'nearest-left-top': 72 | depth_lr = depth_hr[::self.scale, ::self.scale] 73 | image_lr = image[::self.scale, ::self.scale] 74 | else: 75 | raise NotImplementedError 76 | 77 | if self.noisy: 78 | #print(depth_lr.min(), depth_lr.max()) 79 | depth_lr = add_noise(depth_lr, sigma=651) 80 | 81 | # normalize 82 | depth_min = depth_hr.min() 83 | depth_max = depth_hr.max() 84 | depth_hr = (depth_hr - depth_min) / (depth_max - depth_min) 85 | depth_lr = (depth_lr - depth_min) / (depth_max - depth_min) 86 | 87 | image = image.astype(np.float32).transpose(2,0,1) / 255 88 | image_lr = image_lr.astype(np.float32).transpose(2,0,1) / 255 # [3, H, W] 89 | 90 | image = (image - np.array([0.485, 0.456, 0.406]).reshape(3,1,1)) / np.array([0.229, 0.224, 0.225]).reshape(3,1,1) 91 | image_lr = (image_lr - np.array([0.485, 0.456, 0.406]).reshape(3,1,1)) / np.array([0.229, 0.224, 0.225]).reshape(3,1,1) 92 | 93 | # follow DKN, use bicubic upsampling of PIL 94 | depth_lr_up = np.array(Image.fromarray(depth_lr).resize((w, h), Image.BICUBIC)) 95 | 96 | if self.pre_upsample: 97 | depth_lr = depth_lr_up 98 | 99 | # to tensor 100 | image = torch.from_numpy(image).float() 101 | image_lr = torch.from_numpy(image_lr).float() 102 | depth_hr = torch.from_numpy(depth_hr).unsqueeze(0).float() 103 | depth_lr = torch.from_numpy(depth_lr).unsqueeze(0).float() 104 | depth_lr_up = torch.from_numpy(depth_lr_up).unsqueeze(0).float() 105 | 106 | # transform 107 | if self.augment: 108 | hflip = random.random() < 0.5 109 | vflip = random.random() < 0.5 110 | 111 | def augment(x): 112 | if hflip: 113 | x = x.flip(-2) 114 | if vflip: 115 | x = x.flip(-1) 116 | return x 117 | 118 | image = augment(image) 119 | image_lr = augment(image_lr) 120 | depth_hr = augment(depth_hr) 121 | depth_lr = augment(depth_lr) 122 | depth_lr_up = augment(depth_lr_up) 123 | 124 | image = image.contiguous() 125 | image_lr = image_lr.contiguous() 126 | depth_hr = depth_hr.contiguous() 127 | depth_lr = depth_lr.contiguous() 128 | depth_lr_up = depth_lr_up.contiguous() 129 | 130 | # to pixel 131 | if self.to_pixel: 132 | 133 | hr_coord, hr_pixel = to_pixel_samples(depth_hr) 134 | 135 | lr_pixel = depth_lr_up.view(-1, 1) 136 | 137 | if self.sample_q is not None: 138 | sample_lst = np.random.choice(len(hr_coord), self.sample_q, replace=False) 139 | hr_coord = hr_coord[sample_lst] 140 | hr_pixel = hr_pixel[sample_lst] 141 | lr_pixel = lr_pixel[sample_lst] 142 | 143 | cell = torch.ones_like(hr_coord) 144 | cell[:, 0] *= 2 / depth_hr.shape[-2] 145 | cell[:, 1] *= 2 / depth_hr.shape[-1] 146 | 147 | return { 148 | 'image': image, 149 | 'lr_image': image_lr, 150 | 'lr': depth_lr, 151 | 'hr': hr_pixel, 152 | 'hr_depth': depth_hr, 153 | 'lr_pixel': lr_pixel, 154 | 'hr_coord': hr_coord, 155 | 'min': depth_min, 156 | 'max': depth_max, 157 | 'cell': cell, 158 | 'idx': idx, 159 | } 160 | 161 | else: 162 | return { 163 | 'image': image, 164 | 'lr': depth_lr, 165 | 
'hr': depth_hr, 166 | 'min': depth_min, 167 | 'max': depth_max, 168 | 'idx': idx, 169 | } 170 | 171 | 172 | def __len__(self): 173 | return self.size 174 | 175 | 176 | if __name__ == '__main__': 177 | seed_everything(0) 178 | print('===== test direct bicubic upsampling =====') 179 | for method in ['bicubic']: 180 | for scale in [4, 8, 16]: 181 | print(f'[INFO] scale = {scale}, method = {method}') 182 | d = NoisyMiddleburyDataset(root='./data/noisy_depth/middlebury', split='test', pre_upsample=True, augment=False, scale=scale, downsample=method, noisy=True) 183 | rmses = [] 184 | for i in tqdm.trange(len(d)): 185 | x = d[i] 186 | lr = ((x['lr'].numpy() * (x['max'] - x['min'])) + x['min']) 187 | hr = ((x['hr'].numpy() * (x['max'] - x['min'])) + x['min']) 188 | rmse = np.sqrt(np.mean(np.power(lr - hr, 2))) 189 | rmses.append(rmse) 190 | print(rmse) -------------------------------------------------------------------------------- /datasets/nyu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.utils.data import Dataset, DataLoader 5 | 6 | import numpy as np 7 | import cv2 8 | import matplotlib.pyplot as plt 9 | 10 | import os 11 | import glob 12 | import random 13 | from PIL import Image 14 | import tqdm 15 | 16 | from utils import make_coord, add_noise 17 | 18 | def to_pixel_samples(depth): 19 | """ Convert the image to coord-RGB pairs. 20 | depth: Tensor, (1, H, W) 21 | """ 22 | coord = make_coord(depth.shape[-2:], flatten=True) # [H*W, 2] 23 | pixel = depth.view(-1, 1) # [H*W, 1] 24 | return coord, pixel 25 | 26 | class NYUDataset(Dataset): 27 | def __init__(self, root='/data3/tang/nyu_labeled', split='train', scale=8, augment=True, downsample='bicubic', pre_upsample=False, to_pixel=False, sample_q=None, input_size=None, noisy=False): 28 | super().__init__() 29 | self.root = root 30 | self.split = split 31 | self.scale = scale 32 | self.augment = augment 33 | self.downsample = downsample 34 | self.pre_upsample = pre_upsample 35 | self.to_pixel = to_pixel 36 | self.sample_q = sample_q 37 | self.input_size = input_size 38 | self.noisy = noisy 39 | 40 | # use the first 1000 data as training split 41 | if self.split == 'train': 42 | self.size = 1000 43 | else: 44 | self.size = 449 45 | 46 | def __getitem__(self, idx): 47 | if self.split != 'train': 48 | idx += 1000 49 | 50 | image_file = os.path.join(self.root, 'RGB', f'{idx}.jpg') 51 | depth_file = os.path.join(self.root, 'Depth', f'{idx}.npy') 52 | 53 | image = cv2.imread(image_file) # [H, W, 3] 54 | depth_hr = np.load(depth_file) # [H, W] 55 | 56 | # crop after rescale 57 | if self.input_size is not None: 58 | x0 = random.randint(0, image.shape[0] - self.input_size) 59 | y0 = random.randint(0, image.shape[1] - self.input_size) 60 | image = image[x0:x0+self.input_size, y0:y0+self.input_size] 61 | depth_hr = depth_hr[x0:x0+self.input_size, y0:y0+self.input_size] 62 | 63 | 64 | h, w = image.shape[:2] 65 | 66 | if self.downsample == 'bicubic': 67 | depth_lr = np.array(Image.fromarray(depth_hr).resize((w//self.scale, h//self.scale), Image.BICUBIC)) # bicubic, RMSE=7.13 68 | image_lr = np.array(Image.fromarray(image).resize((w//self.scale, h//self.scale), Image.BICUBIC)) # bicubic, RMSE=7.13 69 | #depth_lr = cv2.resize(depth_hr, (w//self.scale, h//self.scale), interpolation=cv2.INTER_CUBIC) # RMSE=8.03, cv2.resize is different from Image.resize. 
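# (the RMSE figures quoted in these comments appear to come from the direct-upsampling
#  sanity check in the __main__ block at the bottom of this file)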
70 | elif self.downsample == 'nearest-right-bottom': 71 | depth_lr = depth_hr[(self.scale - 1)::self.scale, (self.scale - 1)::self.scale] # right-bottom, RMSE=14.22, finally reproduced it... 72 | image_lr = image[(self.scale - 1)::self.scale, (self.scale - 1)::self.scale] # right-bottom, RMSE=14.22, finally reproduced it... 73 | elif self.downsample == 'nearest-center': 74 | depth_lr = np.array(Image.fromarray(depth_hr).resize((w//self.scale, h//self.scale), Image.NEAREST)) # center (if even, prefer right-bottom), RMSE=8.21 75 | image_lr = np.array(Image.fromarray(image).resize((w//self.scale, h//self.scale), Image.NEAREST)) # center (if even, prefer right-bottom), RMSE=8.21 76 | elif self.downsample == 'nearest-left-top': 77 | depth_lr = depth_hr[::self.scale, ::self.scale] # left-top, RMSE=13.94 78 | image_lr = image[::self.scale, ::self.scale] # left-top, RMSE=13.94 79 | else: 80 | raise NotImplementedError 81 | 82 | if self.noisy: 83 | depth_lr = add_noise(depth_lr, sigma=0.04, inv=False) 84 | 85 | # normalize 86 | depth_min = depth_hr.min() 87 | depth_max = depth_hr.max() 88 | depth_hr = (depth_hr - depth_min) / (depth_max - depth_min) 89 | depth_lr = (depth_lr - depth_min) / (depth_max - depth_min) 90 | 91 | image = image.astype(np.float32).transpose(2,0,1) / 255 92 | image_lr = image_lr.astype(np.float32).transpose(2,0,1) / 255 # [3, H, W] 93 | 94 | image = (image - np.array([0.485, 0.456, 0.406]).reshape(3,1,1)) / np.array([0.229, 0.224, 0.225]).reshape(3,1,1) 95 | image_lr = (image_lr - np.array([0.485, 0.456, 0.406]).reshape(3,1,1)) / np.array([0.229, 0.224, 0.225]).reshape(3,1,1) 96 | 97 | # follow DKN, use bicubic upsampling of PIL 98 | depth_lr_up = np.array(Image.fromarray(depth_lr).resize((w, h), Image.BICUBIC)) 99 | 100 | if self.pre_upsample: 101 | depth_lr = depth_lr_up 102 | 103 | # to tensor 104 | image = torch.from_numpy(image).float() 105 | image_lr = torch.from_numpy(image_lr).float() 106 | depth_hr = torch.from_numpy(depth_hr).unsqueeze(0).float() 107 | depth_lr = torch.from_numpy(depth_lr).unsqueeze(0).float() 108 | depth_lr_up = torch.from_numpy(depth_lr_up).unsqueeze(0).float() 109 | 110 | # transform 111 | if self.augment: 112 | hflip = random.random() < 0.5 113 | vflip = random.random() < 0.5 114 | 115 | def augment(x): 116 | if hflip: 117 | x = x.flip(-2) 118 | if vflip: 119 | x = x.flip(-1) 120 | return x 121 | 122 | image = augment(image) 123 | image_lr = augment(image_lr) 124 | depth_hr = augment(depth_hr) 125 | depth_lr = augment(depth_lr) 126 | depth_lr_up = augment(depth_lr_up) 127 | 128 | image = image.contiguous() 129 | image_lr = image_lr.contiguous() 130 | depth_hr = depth_hr.contiguous() 131 | depth_lr = depth_lr.contiguous() 132 | depth_lr_up = depth_lr_up.contiguous() 133 | 134 | # to pixel 135 | if self.to_pixel: 136 | 137 | hr_coord, hr_pixel = to_pixel_samples(depth_hr) 138 | 139 | lr_pixel = depth_lr_up.view(-1, 1) 140 | 141 | if self.sample_q is not None: 142 | sample_lst = np.random.choice(len(hr_coord), self.sample_q, replace=False) 143 | hr_coord = hr_coord[sample_lst] 144 | hr_pixel = hr_pixel[sample_lst] 145 | lr_pixel = lr_pixel[sample_lst] 146 | 147 | cell = torch.ones_like(hr_coord) 148 | cell[:, 0] *= 2 / depth_hr.shape[-2] 149 | cell[:, 1] *= 2 / depth_hr.shape[-1] 150 | 151 | return { 152 | 'image': image, 153 | 'lr_image': image_lr, 154 | 'lr': depth_lr, 155 | 'hr': hr_pixel, 156 | 'hr_depth': depth_hr, 157 | 'lr_pixel': lr_pixel, 158 | 'hr_coord': hr_coord, 159 | 'min': depth_min * 100, 160 | 'max': depth_max * 100, 161 | 'cell': 
cell, 162 | 'idx': idx, 163 | } 164 | 165 | else: 166 | return { 167 | 'image': image, 168 | 'lr': depth_lr, 169 | 'hr': depth_hr, 170 | 'min': depth_min * 100, 171 | 'max': depth_max * 100, 172 | 'idx': idx, 173 | } 174 | 175 | def __len__(self): 176 | return self.size 177 | 178 | 179 | if __name__ == '__main__': 180 | print('===== test direct bicubic upsampling =====') 181 | for method in ['bicubic']: 182 | for scale in [4, 8, 16]: 183 | print(f'[INFO] scale = {scale}, method = {method}') 184 | d = NYUDataset(root='/data3/tang/nyu_labeled', split='test', pre_upsample=True, augment=False, scale=scale, downsample=method, noisy=False) 185 | #d = NYUDataset(root='/data3/tang/nyu_labeled', split='test', pre_upsample=True, augment=False, scale=scale, downsample=method, noisy=True) 186 | rmses = [] 187 | for i in tqdm.trange(len(d)): 188 | x = d[i] 189 | lr = ((x['lr'].numpy() * (x['max'] - x['min'])) + x['min']) 190 | hr = ((x['hr'].numpy() * (x['max'] - x['min'])) + x['min']) 191 | rmse = np.sqrt(np.mean(np.power(lr - hr, 2))) 192 | rmses.append(rmse) 193 | print('RMSE = ', np.mean(rmses)) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | 7 | from utils import * 8 | from datasets import * 9 | from models import * 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--name', type=str, default='jiif') 13 | parser.add_argument('--model', type=str, default='JIIF') 14 | parser.add_argument('--loss', type=str, default='L1') 15 | parser.add_argument('--seed', type=int, default=0) 16 | parser.add_argument('--dataset', type=str, default='NYU') 17 | parser.add_argument('--data_root', type=str, default='./data/nyu_labeled/') 18 | parser.add_argument('--train_batch', type=int, default=1) 19 | parser.add_argument('--test_batch', type=int, default=1) 20 | parser.add_argument('--num_workers', type=int, default=8) 21 | parser.add_argument('--epoch', default=100, type=int, help='max epoch') 22 | parser.add_argument('--eval_interval', default=10, type=int, help='eval interval') 23 | parser.add_argument('--checkpoint', default='scratch', type=str, help='checkpoint to use') 24 | parser.add_argument('--scale', default=8, type=int, help='scale') 25 | parser.add_argument('--interpolation', default='bicubic', type=str, help='interpolation method to generate lr depth') 26 | parser.add_argument('--lr', default=0.0001, type=float, help='learning rate') 27 | parser.add_argument('--lr_step', default=40, type=float, help='learning rate decay step') 28 | parser.add_argument('--lr_gamma', default=0.2, type=float, help='learning rate decay gamma') 29 | parser.add_argument('--input_size', default=None, type=int, help='crop size for hr image') 30 | parser.add_argument('--sample_q', default=30720, type=int, help='sampled pixels per hr depth') 31 | parser.add_argument('--noisy', action='store_true', help='add noise to train dataset') 32 | parser.add_argument('--test', action='store_true', help='test mode') 33 | parser.add_argument('--report_per_image', action='store_true', help='report RMSE of each image') 34 | parser.add_argument('--save', action='store_true', help='save results') 35 | parser.add_argument('--batched_eval', action='store_true', help='batched evaluation to avoid OOM for large image resolution') 36 | 37 | args = parser.parse_args() 38 | 39 | 
seed_everything(args.seed) 40 | 41 | # model 42 | if args.model == 'DKN': 43 | model = DKN(kernel_size=3, filter_size=15, residual=True) 44 | elif args.model == 'FDKN': 45 | model = FDKN(kernel_size=3, filter_size=15, residual=True) 46 | elif args.model == 'DJF': 47 | model = DJF(residual=True) 48 | elif args.model == 'JIIF': 49 | model = JIIF(args, 128, 128) 50 | else: 51 | raise NotImplementedError(f'Model {args.model} not found') 52 | 53 | # loss 54 | if args.loss == 'L1': 55 | criterion = nn.L1Loss() 56 | elif args.loss == 'L2': 57 | criterion = nn.MSELoss() 58 | else: 59 | raise NotImplementedError(f'Loss {args.loss} not found') 60 | 61 | # dataset 62 | if args.dataset == 'NYU': 63 | dataset = NYUDataset 64 | elif args.dataset == 'Lu': 65 | dataset = LuDataset 66 | elif args.dataset == 'Middlebury': 67 | dataset = MiddleburyDataset 68 | elif args.dataset == 'NoisyMiddlebury': 69 | dataset = NoisyMiddleburyDataset 70 | else: 71 | raise NotImplementedError(f'Dataset {args.loss} not found') 72 | 73 | if args.model in ['JIIF']: 74 | if not args.test: 75 | train_dataset = dataset(root=args.data_root, split='train', scale=args.scale, downsample=args.interpolation, augment=True, to_pixel=True, sample_q=args.sample_q, input_size=args.input_size, noisy=args.noisy) 76 | test_dataset = dataset(root=args.data_root, split='test', scale=args.scale, downsample=args.interpolation, augment=False, to_pixel=True, sample_q=None) # full image 77 | elif args.model in ['DJF', 'DKN', 'FDKN']: 78 | if not args.test: 79 | train_dataset = dataset(root=args.data_root, split='train', scale=args.scale, downsample=args.interpolation, augment=True, pre_upsample=True, input_size=args.input_size, noisy=args.noisy) 80 | test_dataset = dataset(root=args.data_root, split='test', scale=args.scale, downsample=args.interpolation, augment=False, pre_upsample=True) 81 | else: 82 | raise NotImplementedError(f'Dataset for model type {args.model} not found') 83 | 84 | # dataloader 85 | if not args.test: 86 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch, pin_memory=True, drop_last=False, shuffle=True, num_workers=args.num_workers) 87 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.test_batch, pin_memory=True, drop_last=False, shuffle=False, num_workers=args.num_workers) 88 | 89 | # trainer 90 | if not args.test: 91 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 92 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step, gamma=args.lr_gamma) 93 | trainer = Trainer(args, args.name, model, objective=criterion, optimizer=optimizer, lr_scheduler=scheduler, metrics=[RMSEMeter(args)], device='cuda', use_checkpoint=args.checkpoint, eval_interval=args.eval_interval) 94 | else: 95 | trainer = Trainer(args, args.name, model, objective=criterion, metrics=[RMSEMeter(args)], device='cuda', use_checkpoint=args.checkpoint) 96 | 97 | # main 98 | if not args.test: 99 | trainer.train(train_loader, test_loader, args.epoch) 100 | 101 | if args.save: 102 | # save results (doesn't need GT) 103 | trainer.test(test_loader) 104 | else: 105 | # evaluate (needs GT) 106 | trainer.evaluate(test_loader) 107 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .dkn import DKN, FDKN 2 | from .djf import DJF 3 | from .jiif import JIIF -------------------------------------------------------------------------------- /models/djf.py: 
-------------------------------------------------------------------------------- 1 | # modified from https://github.com/ZQPei/deep_joint_filter 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class BaseNetwork(nn.Module): 8 | def __init__(self): 9 | super(BaseNetwork, self).__init__() 10 | 11 | def init_weights(self, init_type='normal', gain=0.02): 12 | ''' 13 | initialize network's weights 14 | init_type: normal | xavier | kaiming | orthogonal 15 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 16 | ''' 17 | 18 | def init_func(m): 19 | classname = m.__class__.__name__ 20 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 21 | if init_type == 'normal': 22 | nn.init.normal_(m.weight.data, 0.0, gain) 23 | elif init_type == 'xavier': 24 | nn.init.xavier_normal_(m.weight.data, gain=gain) 25 | elif init_type == 'kaiming': 26 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 27 | elif init_type == 'orthogonal': 28 | nn.init.orthogonal_(m.weight.data, gain=gain) 29 | 30 | if hasattr(m, 'bias') and m.bias is not None: 31 | nn.init.constant_(m.bias.data, 0.0) 32 | 33 | elif classname.find('BatchNorm2d') != -1: 34 | nn.init.normal_(m.weight.data, 1.0, gain) 35 | nn.init.constant_(m.bias.data, 0.0) 36 | 37 | self.apply(init_func) 38 | 39 | 40 | class CNN(BaseNetwork): 41 | def __init__(self, num_conv=3, c_in=1, channel=[96,48,1], kernel_size=[9,1,5], stride=[1,1,1], padding=[2,2,2]): 42 | super(CNN, self).__init__() 43 | 44 | layers = [] 45 | for i in range(num_conv): 46 | layers += [nn.Conv2d(c_in if i == 0 else channel[i-1], channel[i], kernel_size[i], stride[i], padding[i], bias=True)] 47 | if i != num_conv-1: 48 | layers += [nn.ReLU(inplace=True)] 49 | 50 | self.feature = nn.Sequential(*layers) 51 | 52 | self.init_weights() 53 | 54 | 55 | def forward(self, x): 56 | fmap = self.feature(x) 57 | return fmap 58 | 59 | 60 | # called DJFR if residual = True 61 | class DJF(BaseNetwork): 62 | def __init__(self, init_weights=True, residual=True): 63 | super().__init__() 64 | self.residual = residual 65 | 66 | self.cnn_t = CNN(c_in=1, channel=[96, 48, 1]) 67 | self.cnn_g = CNN(c_in=3, channel=[96, 48, 1]) 68 | self.cnn_f = CNN(c_in=2, channel=[64, 32, 1]) 69 | 70 | if init_weights: 71 | self.init_weights() 72 | 73 | 74 | def forward(self, data): 75 | 76 | image = data['image'] 77 | depth = data['lr'] 78 | 79 | fmap1 = self.cnn_t(depth) 80 | fmap2 = self.cnn_g(image) 81 | 82 | output = self.cnn_f(torch.cat([fmap1, fmap2], dim=1)) 83 | 84 | if self.residual: 85 | output = output + depth 86 | 87 | return output 88 | -------------------------------------------------------------------------------- /models/dkn.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/cvlab-yonsei/dkn 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | 7 | def grid_generator(k, r, n): 8 | grid_x, grid_y = torch.meshgrid([torch.linspace(k//2, k//2+r-1, steps=r), 9 | torch.linspace(k//2, k//2+r-1, steps=r)]) 10 | grid = torch.stack([grid_x,grid_y],2).view(r,r,2) 11 | 12 | return grid.unsqueeze(0).repeat(n,1,1,1).cuda() 13 | 14 | 15 | class Kernel_DKN(nn.Module): 16 | def __init__(self, input_channel, kernel_size): 17 | super(Kernel_DKN, self).__init__() 18 | self.conv1 = nn.Conv2d(input_channel, 32, 7) 19 | self.conv1_bn = nn.BatchNorm2d(32) 20 | self.conv2 = nn.Conv2d(32, 32, 2, 
stride=(2,2)) 21 | self.conv3 = nn.Conv2d(32, 64, 5) 22 | self.conv3_bn = nn.BatchNorm2d(64) 23 | self.conv4 = nn.Conv2d(64, 64, 2, stride=(2,2)) 24 | self.conv5 = nn.Conv2d(64, 128, 5) 25 | self.conv5_bn = nn.BatchNorm2d(128) 26 | self.conv6 = nn.Conv2d(128, 128, 3) 27 | self.conv7 = nn.Conv2d(128, 128, 3) 28 | 29 | self.conv_weight = nn.Conv2d(128, kernel_size**2, 1) 30 | self.conv_offset = nn.Conv2d(128, 2*kernel_size**2, 1) 31 | 32 | def forward(self, x): 33 | x = F.relu(self.conv1_bn(self.conv1(x))) 34 | x = F.relu(self.conv2(x)) 35 | x = F.relu(self.conv3_bn(self.conv3(x))) 36 | x = F.relu(self.conv4(x)) 37 | x = F.relu(self.conv5_bn(self.conv5(x))) 38 | x = F.relu(self.conv6(x)) 39 | x = F.relu(self.conv7(x)) 40 | 41 | offset = self.conv_offset(x) 42 | weight = torch.sigmoid(self.conv_weight(x)) 43 | 44 | return weight, offset 45 | 46 | class DKN(nn.Module): 47 | def __init__(self, kernel_size, filter_size, residual=True): 48 | super(DKN, self).__init__() 49 | self.ImageKernel = Kernel_DKN(input_channel=3, kernel_size=kernel_size) 50 | self.DepthKernel = Kernel_DKN(input_channel=1, kernel_size=kernel_size) 51 | self.residual = residual 52 | self.kernel_size = kernel_size 53 | self.filter_size = filter_size 54 | 55 | def forward(self, data): 56 | 57 | image = data['image'] 58 | depth = data['lr'] 59 | 60 | ### DKN assumes depth is normalized to [0, 1] and resized to full resolution. 61 | b, _, h, w = image.shape 62 | 63 | weight, offset = self._shift_and_stitch(image, depth) 64 | 65 | k = self.filter_size 66 | r = self.kernel_size 67 | hw = h*w 68 | 69 | # weighted average 70 | # (b, 2*r**2, h, w) -> (b*hw, r, r, 2) 71 | offset = offset.permute(0,2,3,1).contiguous().view(b*hw, r,r, 2) 72 | # (b, r**2, h, w) -> (b*hw, r**2, 1) 73 | weight = weight.permute(0,2,3,1).contiguous().view(b*hw, r*r, 1) 74 | 75 | # (b*hw, r, r, 2) 76 | grid = grid_generator(k, r, b*hw) 77 | 78 | coord = grid + offset 79 | coord = (coord / k * 2) -1 80 | 81 | # (b, k**2, hw) -> (b*hw, 1, k, k) 82 | depth_col = F.unfold(depth, k, padding=k//2).permute(0,2,1).contiguous().view(b*hw, 1, k,k) 83 | 84 | # (b*hw, 1, k, k), (b*hw, r, r, 2) => (b*hw, 1, r^2) 85 | depth_sampled = F.grid_sample(depth_col, coord, align_corners=False).view(b*hw, 1, -1) 86 | 87 | # (b*w*h, 1, r^2) x (b*w*h, r^2, 1) => (b, 1, h,w) 88 | out = torch.bmm(depth_sampled, weight).view(b, 1, h,w) 89 | 90 | if self.residual: 91 | out += depth 92 | 93 | return out 94 | 95 | def _infer(self, image, depth): 96 | 97 | imkernel, imoffset = self.ImageKernel(image) 98 | depthkernel, depthoffset = self.DepthKernel(depth) 99 | 100 | weight = imkernel * depthkernel 101 | offset = imoffset * depthoffset 102 | 103 | if self.residual: 104 | weight -= torch.mean(weight, 1).unsqueeze(1).expand_as(weight) 105 | else: 106 | weight /= torch.sum(weight, 1).unsqueeze(1).expand_as(weight) 107 | 108 | return weight, offset 109 | 110 | def _shift_and_stitch(self, image, depth): 111 | 112 | offset = torch.zeros((image.size(0), 2*self.kernel_size**2, image.size(2), image.size(3)), 113 | dtype=image.dtype, layout=image.layout, device=image.device) 114 | weight = torch.zeros((image.size(0), self.kernel_size**2, image.size(2), image.size(3)), 115 | dtype=image.dtype, layout=image.layout, device=image.device) 116 | 117 | for i in range(4): 118 | for j in range(4): 119 | 120 | m = nn.ZeroPad2d((25-j,22+j,25-i,22+i)) 121 | m = nn.ZeroPad2d((25-j,22+j,25-i,22+i)) 122 | 123 | img_shift = m(image) 124 | depth_shift = m(depth) 125 | 126 | w, o = self._infer(img_shift, depth_shift) 
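                # the two stride-2 convs in Kernel_DKN reduce resolution by 4x, so each shifted copy
                # of the input yields predictions for one of the 16 sub-grids; the interleaved
                # assignments below stitch them back into full-resolution weight/offset maps.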
127 | 128 | weight[:,:,i::4,j::4] = w 129 | offset[:,:,i::4,j::4] = o 130 | 131 | return weight, offset 132 | 133 | 134 | def resample_data(input, s): 135 | """ 136 | input: torch.floatTensor (N, C, H, W) 137 | s: int (resample factor) 138 | """ 139 | 140 | assert( not input.size(2)%s and not input.size(3)%s) 141 | 142 | if input.size(1) == 3: 143 | # bgr2gray (same as opencv conversion matrix) 144 | input = (0.299 * input[:,2] + 0.587 * input[:,1] + 0.114 * input[:,0]).unsqueeze(1) 145 | 146 | out = torch.cat([input[:,:,i::s,j::s] for i in range(s) for j in range(s)], dim=1) 147 | 148 | """ 149 | out: torch.floatTensor (N, s**2, H/s, W/s) 150 | """ 151 | return out 152 | 153 | 154 | class Kernel_FDKN(nn.Module): 155 | def __init__(self, input_channel, kernel_size, factor=4): 156 | super(Kernel_FDKN, self).__init__() 157 | self.conv1 = nn.Conv2d(input_channel, 32, 3, padding=1) 158 | self.conv1_bn = nn.BatchNorm2d(32) 159 | self.conv2 = nn.Conv2d(32, 32, 3, padding=1) 160 | self.conv3 = nn.Conv2d(32, 64, 3, padding=1) 161 | self.conv3_bn = nn.BatchNorm2d(64) 162 | self.conv4 = nn.Conv2d(64, 64, 3, padding=1) 163 | self.conv5 = nn.Conv2d(64, 128, 3, padding=1) 164 | self.conv5_bn = nn.BatchNorm2d(128) 165 | self.conv6 = nn.Conv2d(128, 128, 3, padding=1) 166 | 167 | self.conv_weight = nn.Conv2d(128, kernel_size**2*(factor)**2, 1) 168 | self.conv_offset = nn.Conv2d(128, 2*kernel_size**2*(factor)**2, 1) 169 | 170 | def forward(self, x): 171 | x = F.relu(self.conv1_bn(self.conv1(x))) 172 | x = F.relu(self.conv2(x)) 173 | x = F.relu(self.conv3_bn(self.conv3(x))) 174 | x = F.relu(self.conv4(x)) 175 | x = F.relu(self.conv5_bn(self.conv5(x))) 176 | x = F.relu(self.conv6(x)) 177 | 178 | offset = self.conv_offset(x) 179 | weight = torch.sigmoid(self.conv_weight(x)) 180 | 181 | return weight, offset 182 | 183 | 184 | class FDKN(nn.Module): 185 | def __init__(self, kernel_size, filter_size, residual=True): 186 | super(FDKN, self).__init__() 187 | self.factor = 4 # resample factor 188 | self.ImageKernel = Kernel_FDKN(input_channel=16, kernel_size=kernel_size, factor=self.factor) 189 | self.DepthKernel = Kernel_FDKN(input_channel=16, kernel_size=kernel_size, factor=self.factor) 190 | self.residual = residual 191 | self.kernel_size = kernel_size 192 | self.filter_size = filter_size 193 | 194 | def forward(self, data): 195 | 196 | image = data['image'] 197 | depth = data['lr'] 198 | 199 | ### DKN assumes depth is normalized to [0, 1] and resized to full resolution. 
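        # FDKN replaces DKN's shift-and-stitch loop with a space-to-depth trick: the inputs are
        # resampled into 16 channels at 1/4 resolution, the kernels/offsets are predicted with
        # 16x more output channels, and PixelShuffle(4) rearranges them back to full resolution.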
200 | b, _, h, w = image.shape 201 | 202 | re_im = resample_data(image, self.factor) 203 | re_dp = resample_data(depth, self.factor) 204 | 205 | imkernel, imoffset = self.ImageKernel(re_im) 206 | depthkernel, depthoffset = self.DepthKernel(re_dp) 207 | 208 | weight = imkernel * depthkernel 209 | offset = imoffset * depthoffset 210 | 211 | ps = nn.PixelShuffle(4) 212 | weight = ps(weight) 213 | offset = ps(offset) 214 | 215 | if self.residual: 216 | weight -= torch.mean(weight, 1).unsqueeze(1).expand_as(weight) 217 | else: 218 | weight /= torch.sum(weight, 1).unsqueeze(1).expand_as(weight) 219 | 220 | b, h, w = image.size(0), image.size(2), image.size(3) 221 | k = self.filter_size 222 | r = self.kernel_size 223 | hw = h*w 224 | 225 | # weighted average 226 | # (b, 2*r**2, h, w) -> (b*hw, r, r, 2) 227 | offset = offset.permute(0,2,3,1).contiguous().view(b*hw, r,r, 2) 228 | # (b, r**2, h, w) -> (b*hw, r**2, 1) 229 | weight = weight.permute(0,2,3,1).contiguous().view(b*hw, r*r, 1) 230 | 231 | # (b*hw, r, r, 2) 232 | grid = grid_generator(k, r, b*hw) 233 | coord = grid + offset 234 | coord = (coord / k * 2) -1 235 | 236 | # (b, k**2, hw) -> (b*hw, 1, k, k) 237 | depth_col = F.unfold(depth, k, padding=k//2).permute(0,2,1).contiguous().view(b*hw, 1, k,k) 238 | 239 | # (b*hw, 1, k, k), (b*hw, r, r, 2) => (b*hw, 1, r^2) 240 | depth_sampled = F.grid_sample(depth_col, coord, align_corners=False).view(b*hw, 1, -1) 241 | 242 | # (b*w*h, 1, r^2) x (b*w*h, r^2, 1) => (b, 1, h, w) 243 | out = torch.bmm(depth_sampled, weight).view(b, 1, h,w) 244 | 245 | if self.residual: 246 | out += depth 247 | 248 | return out 249 | -------------------------------------------------------------------------------- /models/edsr.py: -------------------------------------------------------------------------------- 1 | # modified from: https://github.com/thstkdgus35/EDSR-PyTorch 2 | 3 | import math 4 | from argparse import Namespace 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 12 | return nn.Conv2d( 13 | in_channels, out_channels, kernel_size, 14 | padding=(kernel_size//2), bias=bias) 15 | 16 | class MeanShift(nn.Conv2d): 17 | def __init__( 18 | self, rgb_range, 19 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 20 | 21 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 22 | std = torch.Tensor(rgb_std) 23 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 24 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 25 | for p in self.parameters(): 26 | p.requires_grad = False 27 | 28 | class ResBlock(nn.Module): 29 | def __init__( 30 | self, conv, n_feats, kernel_size, 31 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 32 | 33 | super(ResBlock, self).__init__() 34 | m = [] 35 | for i in range(2): 36 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 37 | if bn: 38 | m.append(nn.BatchNorm2d(n_feats)) 39 | if i == 0: 40 | m.append(act) 41 | 42 | self.body = nn.Sequential(*m) 43 | self.res_scale = res_scale 44 | 45 | def forward(self, x): 46 | res = self.body(x).mul(self.res_scale) 47 | res += x 48 | 49 | return res 50 | 51 | class Upsampler(nn.Sequential): 52 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 53 | 54 | m = [] 55 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
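            # power-of-two scales are handled by stacking log2(scale) successive x2 PixelShuffle stages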
56 | for _ in range(int(math.log(scale, 2))): 57 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 58 | m.append(nn.PixelShuffle(2)) 59 | if bn: 60 | m.append(nn.BatchNorm2d(n_feats)) 61 | if act == 'relu': 62 | m.append(nn.ReLU(True)) 63 | elif act == 'prelu': 64 | m.append(nn.PReLU(n_feats)) 65 | 66 | elif scale == 3: 67 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 68 | m.append(nn.PixelShuffle(3)) 69 | if bn: 70 | m.append(nn.BatchNorm2d(n_feats)) 71 | if act == 'relu': 72 | m.append(nn.ReLU(True)) 73 | elif act == 'prelu': 74 | m.append(nn.PReLU(n_feats)) 75 | else: 76 | raise NotImplementedError 77 | 78 | super(Upsampler, self).__init__(*m) 79 | 80 | 81 | url = { 82 | 'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt', 83 | 'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt', 84 | 'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt', 85 | 'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt', 86 | 'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt', 87 | 'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt' 88 | } 89 | 90 | class EDSR(nn.Module): 91 | def __init__(self, args, conv=default_conv): 92 | super(EDSR, self).__init__() 93 | self.args = args 94 | n_resblocks = args.n_resblocks 95 | n_feats = args.n_feats 96 | kernel_size = 3 97 | scale = args.scale[0] 98 | act = nn.ReLU(True) 99 | url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale) 100 | if url_name in url: 101 | self.url = url[url_name] 102 | else: 103 | self.url = None 104 | self.sub_mean = MeanShift(args.rgb_range) 105 | self.add_mean = MeanShift(args.rgb_range, sign=1) 106 | 107 | # define head module 108 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 109 | 110 | # define body module 111 | m_body = [ 112 | ResBlock( 113 | conv, n_feats, kernel_size, act=act, res_scale=args.res_scale 114 | ) for _ in range(n_resblocks) 115 | ] 116 | m_body.append(conv(n_feats, n_feats, kernel_size)) 117 | 118 | self.head = nn.Sequential(*m_head) 119 | self.body = nn.Sequential(*m_body) 120 | 121 | if args.no_upsampling: 122 | self.out_dim = n_feats 123 | else: 124 | self.out_dim = args.n_colors 125 | # define tail module 126 | m_tail = [ 127 | Upsampler(conv, scale, n_feats, act=False), 128 | conv(n_feats, args.n_colors, kernel_size) 129 | ] 130 | self.tail = nn.Sequential(*m_tail) 131 | 132 | def forward(self, x): 133 | #x = self.sub_mean(x) 134 | x = self.head(x) 135 | 136 | res = self.body(x) 137 | res += x 138 | 139 | if self.args.no_upsampling: 140 | x = res 141 | else: 142 | x = self.tail(res) 143 | #x = self.add_mean(x) 144 | return x 145 | 146 | def load_state_dict(self, state_dict, strict=True): 147 | own_state = self.state_dict() 148 | for name, param in state_dict.items(): 149 | if name in own_state: 150 | if isinstance(param, nn.Parameter): 151 | param = param.data 152 | try: 153 | own_state[name].copy_(param) 154 | except Exception: 155 | if name.find('tail') == -1: 156 | raise RuntimeError('While copying the parameter named {}, ' 157 | 'whose dimensions in the model are {} and ' 158 | 'whose dimensions in the checkpoint are {}.' 
159 | .format(name, own_state[name].size(), param.size())) 160 | elif strict: 161 | if name.find('tail') == -1: 162 | raise KeyError('unexpected key "{}" in state_dict' 163 | .format(name)) 164 | 165 | 166 | def make_edsr_baseline(n_resblocks=16, n_feats=64, res_scale=1, n_colors=1, 167 | scale=2, no_upsampling=True, rgb_range=1): 168 | args = Namespace() 169 | args.n_resblocks = n_resblocks 170 | args.n_feats = n_feats 171 | args.res_scale = res_scale 172 | 173 | args.scale = [scale] 174 | args.no_upsampling = no_upsampling 175 | 176 | args.rgb_range = rgb_range 177 | args.n_colors = n_colors 178 | return EDSR(args) 179 | 180 | 181 | def make_edsr(n_resblocks=32, n_feats=256, res_scale=0.1, n_colors=1, 182 | scale=2, no_upsampling=True, rgb_range=1): 183 | args = Namespace() 184 | args.n_resblocks = n_resblocks 185 | args.n_feats = n_feats 186 | args.res_scale = res_scale 187 | 188 | args.scale = [scale] 189 | args.no_upsampling = no_upsampling 190 | 191 | args.rgb_range = rgb_range 192 | args.n_colors = n_colors 193 | return EDSR(args) 194 | -------------------------------------------------------------------------------- /models/jiif.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from utils import make_coord 8 | from models.edsr import make_edsr_baseline 9 | 10 | class MLP(nn.Module): 11 | def __init__(self, in_dim, out_dim, hidden_list): 12 | super().__init__() 13 | layers = [] 14 | lastv = in_dim 15 | for hidden in hidden_list: 16 | layers.append(nn.Linear(lastv, hidden)) 17 | layers.append(nn.ReLU()) 18 | lastv = hidden 19 | layers.append(nn.Linear(lastv, out_dim)) 20 | self.layers = nn.Sequential(*layers) 21 | 22 | def forward(self, x): 23 | x = self.layers(x) 24 | return x 25 | 26 | 27 | class JIIF(nn.Module): 28 | 29 | def __init__(self, args, feat_dim=128, guide_dim=128, mlp_dim=[1024,512,256,128]): 30 | super().__init__() 31 | self.args = args 32 | self.feat_dim = feat_dim 33 | self.guide_dim = guide_dim 34 | self.mlp_dim = mlp_dim 35 | 36 | self.image_encoder = make_edsr_baseline(n_feats=self.guide_dim, n_colors=3) 37 | self.depth_encoder = make_edsr_baseline(n_feats=self.feat_dim, n_colors=1) 38 | 39 | imnet_in_dim = self.feat_dim + self.guide_dim * 2 + 2 40 | 41 | self.imnet = MLP(imnet_in_dim, out_dim=2, hidden_list=self.mlp_dim) 42 | 43 | def query(self, feat, coord, hr_guide, lr_guide, image): 44 | 45 | # feat: [B, C, h, w] 46 | # coord: [B, N, 2], N <= H * W 47 | 48 | b, c, h, w = feat.shape # lr 49 | B, N, _ = coord.shape 50 | 51 | # LR centers' coords 52 | feat_coord = make_coord((h, w), flatten=False).to(feat.device).permute(2, 0, 1).unsqueeze(0).expand(b, 2, h, w) 53 | 54 | q_guide_hr = F.grid_sample(hr_guide, coord.flip(-1).unsqueeze(1), mode='nearest', align_corners=False)[:, :, 0, :].permute(0, 2, 1) # [B, N, C] 55 | 56 | rx = 1 / h 57 | ry = 1 / w 58 | 59 | preds = [] 60 | 61 | k = 0 62 | for vx in [-1, 1]: 63 | for vy in [-1, 1]: 64 | coord_ = coord.clone() 65 | 66 | coord_[:, :, 0] += (vx) * rx 67 | coord_[:, :, 1] += (vy) * ry 68 | k += 1 69 | 70 | # feat: [B, c, h, w], coord_: [B, N, 2] --> [B, 1, N, 2], out: [B, c, 1, N] --> [B, c, N] --> [B, N, c] 71 | q_feat = F.grid_sample(feat, coord_.flip(-1).unsqueeze(1), mode='nearest', align_corners=False)[:, :, 0, :].permute(0, 2, 1) # [B, N, c] 72 | q_coord = F.grid_sample(feat_coord, coord_.flip(-1).unsqueeze(1), mode='nearest', align_corners=False)[:, :, 
0, :].permute(0, 2, 1) # [B, N, 2] 73 | 74 | rel_coord = coord - q_coord 75 | rel_coord[:, :, 0] *= h 76 | rel_coord[:, :, 1] *= w 77 | 78 | q_guide_lr = F.grid_sample(lr_guide, coord_.flip(-1).unsqueeze(1), mode='nearest', align_corners=False)[:, :, 0, :].permute(0, 2, 1) # [B, N, C] 79 | q_guide = torch.cat([q_guide_hr, q_guide_hr - q_guide_lr], dim=-1) 80 | 81 | inp = torch.cat([q_feat, q_guide, rel_coord], dim=-1) 82 | 83 | pred = self.imnet(inp.view(B * N, -1)).view(B, N, -1) # [B, N, 2] 84 | preds.append(pred) 85 | 86 | preds = torch.stack(preds, dim=-1) # [B, N, 2, kk] 87 | weight = F.softmax(preds[:,:,1,:], dim=-1) 88 | 89 | ret = (preds[:,:,0,:] * weight).sum(-1, keepdim=True) 90 | 91 | return ret 92 | 93 | def forward(self, data): 94 | image, depth, coord, res, lr_image = data['image'], data['lr'], data['hr_coord'], data['lr_pixel'], data['lr_image'] 95 | 96 | hr_guide = self.image_encoder(image) 97 | lr_guide = self.image_encoder(lr_image) 98 | 99 | feat = self.depth_encoder(depth) 100 | 101 | if self.training or not self.args.batched_eval: 102 | res = res + self.query(feat, coord, hr_guide, lr_guide, data['hr_depth'].repeat(1,3,1,1)) 103 | 104 | # batched evaluation to avoid OOM 105 | else: 106 | N = coord.shape[1] # coord ~ [B, N, 2] 107 | n = 30720 108 | tmp = [] 109 | for start in range(0, N, n): 110 | end = min(N, start + n) 111 | ans = self.query(feat, coord[:, start:end], hr_guide, lr_guide, data['hr_depth'].repeat(1,3,1,1)) # [B, N, 1] 112 | tmp.append(ans) 113 | res = res + torch.cat(tmp, dim=1) 114 | 115 | return res 116 | 117 | 118 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Joint Implicit Image Function for Guided Depth Super-Resolution 2 | 3 | This repository contains the code for: 4 | 5 | > [Joint Implicit Image Function for Guided Depth Super-Resolution](https://arxiv.org/abs/2107.08717) 6 | > Jiaxiang Tang, Xiaokang Chen, Gang Zeng 7 | > ACM MM 2021 8 | 9 | 10 | 11 | 12 | ![model](assets/model.png) 13 | 14 | 15 | 16 | ### Installation 17 | 18 | Environments: 19 | * Python >= 3.6 20 | * PyTorch >= 1.6.0 21 | * tensorboardX 22 | * tqdm, opencv-python, Pillow 23 | * [NVIDIA apex](https://github.com/NVIDIA/apex) (python-only build is ok.) 24 | 25 | 26 | 27 | ### Data preparation 28 | 29 | Please see [data/prepare_data.md](data/prepare_data.md) for the details. 30 | 31 | 32 | 33 | ### Training 34 | You can use the provided scripts (`scripts/train*`) to train models. 35 | 36 | For example: 37 | 38 | ```bash 39 | # train JIIF with scale = 8 on the NYU dataset. 40 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main.py \ 41 | --name jiif_8 --model JIIF --scale 8 \ 42 | --sample_q 30720 --input_size 256 --train_batch 1 \ 43 | --epoch 200 --eval_interval 10 \ 44 | --lr 0.0001 --lr_step 60 --lr_gamma 0.2 45 | ``` 46 | 47 | 48 | 49 | ### Testing 50 | 51 | To test the performance of the models on different datasets, you can use the provided scripts (`scripts/test*`).
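The denoise test scripts (`scripts/test_denoise_*`) additionally pass `--batched_eval`, which makes `models/jiif.py` evaluate the query coordinates in chunks of 30720 to avoid running out of GPU memory at large scales, and `--report_per_image`, which prints the RMSE of each test image. For instance, the scale-8 command from `scripts/test_denoise_jiif.sh`, reflowed here for readability:

```bash
# test the best checkpoint on the noisy Middlebury images with scale = 8
OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main.py \
    --test --checkpoint best \
    --name denoise_jiif_8 --model JIIF \
    --dataset NoisyMiddlebury --scale 8 --interpolation bicubic \
    --data_root ./data/noisy_depth/middlebury \
    --batched_eval --report_per_image
```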
52 | 53 | For example, for the standard (noise-free) evaluation: 54 | 55 | ```bash 56 | # test the best checkpoint on the Middlebury dataset with scale = 8 57 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main.py \ 58 | --test --checkpoint best \ 59 | --name jiif_8 --model JIIF \ 60 | --dataset Middlebury --scale 8 --data_root ./data/depth_enhance/01_Middlebury_Dataset 61 | ``` 62 | 63 | 64 | 65 | ### Pretrained models and reproducing results 66 | 67 | We provide the pretrained models [here](https://drive.google.com/drive/folders/1qU669OhhGcIgxYtj-1J6APZdUKQOZ4H2?usp=sharing). 68 | 69 | To test the performance of the pretrained models, please download the corresponding models and put them under the `pretrained` folder. Then you can use `scripts/test_jiif_pretrained.sh` and `scripts/test_denoise_jiif_pretrained.sh` to reproduce the results reported in our paper. 70 | 71 | 72 | 73 | ### Citation 74 | 75 | If you find the code useful for your research, please use the following `BibTeX` entry: 76 | ``` 77 | @article{tang2021joint, 78 | title = {Joint Implicit Image Function for Guided Depth Super-Resolution}, 79 | author = {Jiaxiang Tang and Xiaokang Chen and Gang Zeng}, 80 | year = 2021, 81 | journal = {arXiv preprint arXiv:2107.08717} 82 | } 83 | ``` 84 | 85 | 86 | 87 | ### Acknowledgment 88 | 89 | The model implementation is based on [liif](https://github.com/yinboc/liif). -------------------------------------------------------------------------------- /scripts/test_denoise_jiif.sh: -------------------------------------------------------------------------------- 1 | set -euxo pipefail 2 | 3 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main.py --test --checkpoint best --name denoise_jiif_4 --model JIIF --dataset NoisyMiddlebury --scale 4 --interpolation bicubic --data_root ./data/noisy_depth/middlebury --batched_eval --report_per_image 4 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main.py --test --checkpoint best --name denoise_jiif_8 --model JIIF --dataset NoisyMiddlebury --scale 8 --interpolation bicubic --data_root ./data/noisy_depth/middlebury --batched_eval --report_per_image 5 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main.py --test --checkpoint best --name denoise_jiif_16 --model JIIF --dataset NoisyMiddlebury --scale 16 --interpolation bicubic --data_root ./data/noisy_depth/middlebury --batched_eval --report_per_image -------------------------------------------------------------------------------- /scripts/test_denoise_jiif_pretrained.sh: -------------------------------------------------------------------------------- 1 | set -euxo pipefail 2 | 3 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=0 python main.py --test --checkpoint ./pretrained/denoise_jiif_4.pth --name denoise_jiif_4 --model JIIF --dataset NoisyMiddlebury --scale 4 --interpolation bicubic --data_root ./data/noisy_depth/middlebury --batched_eval --report_per_image 4 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=0 python main.py --test --checkpoint ./pretrained/denoise_jiif_8.pth --name denoise_jiif_8 --model JIIF --dataset NoisyMiddlebury --scale 8 --interpolation bicubic --data_root ./data/noisy_depth/middlebury --batched_eval --report_per_image 5 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=0 python main.py --test --checkpoint ./pretrained/denoise_jiif_16.pth --name denoise_jiif_16 --model JIIF --dataset NoisyMiddlebury --scale 16 --interpolation bicubic --data_root ./data/noisy_depth/middlebury --batched_eval --report_per_image 6 | 7 | -------------------------------------------------------------------------------- /scripts/test_jiif.sh:
-------------------------------------------------------------------------------- 1 | set -euxo pipefail 2 | 3 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main.py --test --checkpoint best --name jiif_4 --model JIIF --dataset Middlebury --scale 4 --interpolation bicubic --data_root ./data/depth_enhance/01_Middlebury_Dataset 4 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main.py --test --checkpoint best --name jiif_8 --model JIIF --dataset Middlebury --scale 8 --interpolation bicubic --data_root ./data/depth_enhance/01_Middlebury_Dataset 5 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main.py --test --checkpoint best --name jiif_16 --model JIIF --dataset Middlebury --scale 16 --interpolation bicubic --data_root ./data/depth_enhance/01_Middlebury_Dataset 6 | 7 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main.py --test --checkpoint best --name jiif_4 --model JIIF --dataset Lu --scale 4 --interpolation bicubic --data_root ./data/depth_enhance/03_RGBD_Dataset 8 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main.py --test --checkpoint best --name jiif_8 --model JIIF --dataset Lu --scale 8 --interpolation bicubic --data_root ./data/depth_enhance/03_RGBD_Dataset 9 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main.py --test --checkpoint best --name jiif_16 --model JIIF --dataset Lu --scale 16 --interpolation bicubic --data_root ./data/depth_enhance/03_RGBD_Dataset 10 | 11 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main.py --test --checkpoint best --name jiif_4 --model JIIF --dataset NYU --scale 4 --interpolation bicubic --data_root ./data/nyu_labeled 12 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main.py --test --checkpoint best --name jiif_8 --model JIIF --dataset NYU --scale 8 --interpolation bicubic --data_root ./data/nyu_labeled 13 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main.py --test --checkpoint best --name jiif_16 --model JIIF --dataset NYU --scale 16 --interpolation bicubic --data_root ./data/nyu_labeled 14 | -------------------------------------------------------------------------------- /scripts/test_jiif_pretrained.sh: -------------------------------------------------------------------------------- 1 | set -euxo pipefail 2 | 3 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=0 python main.py --test --checkpoint ./pretrained/jiif_4.pth --name jiif_4 --model JIIF --dataset Middlebury --scale 4 --interpolation bicubic --data_root ./data/depth_enhance/01_Middlebury_Dataset 4 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=0 python main.py --test --checkpoint ./pretrained/jiif_8.pth --name jiif_8 --model JIIF --dataset Middlebury --scale 8 --interpolation bicubic --data_root ./data/depth_enhance/01_Middlebury_Dataset 5 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=0 python main.py --test --checkpoint ./pretrained/jiif_16.pth --name jiif_16 --model JIIF --dataset Middlebury --scale 16 --interpolation bicubic --data_root ./data/depth_enhance/01_Middlebury_Dataset 6 | 7 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=0 python main.py --test --checkpoint ./pretrained/jiif_4.pth --name jiif_4 --model JIIF --dataset Lu --scale 4 --interpolation bicubic --data_root ./data/depth_enhance/03_RGBD_Dataset 8 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=0 python main.py --test --checkpoint ./pretrained/jiif_8.pth --name jiif_8 --model JIIF --dataset Lu --scale 8 --interpolation bicubic --data_root ./data/depth_enhance/03_RGBD_Dataset 9 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=0 python main.py --test --checkpoint ./pretrained/jiif_16.pth --name jiif_16 --model JIIF --dataset Lu --scale 
16 --interpolation bicubic --data_root ./data/depth_enhance/03_RGBD_Dataset 10 | 11 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=0 python main.py --test --checkpoint ./pretrained/jiif_4.pth --name jiif_4 --model JIIF --dataset NYU --scale 4 --interpolation bicubic --data_root ./data/nyu_labeled 12 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=0 python main.py --test --checkpoint ./pretrained/jiif_8.pth --name jiif_8 --model JIIF --dataset NYU --scale 8 --interpolation bicubic --data_root ./data/nyu_labeled 13 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=0 python main.py --test --checkpoint ./pretrained/jiif_16.pth --name jiif_16 --model JIIF --dataset NYU --scale 16 --interpolation bicubic --data_root ./data/nyu_labeled -------------------------------------------------------------------------------- /scripts/train_denoise_jiif.sh: -------------------------------------------------------------------------------- 1 | set -euxo pipefail 2 | 3 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main.py --name denoise_jiif_4 --model JIIF --scale 4 --noisy --sample_q 30720 --input_size 256 --train_batch 1 --epoch 200 --eval_interval 10 --lr 0.0001 --lr_step 60 --lr_gamma 0.2 4 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main.py --name denoise_jiif_8 --model JIIF --scale 8 --noisy --sample_q 30720 --input_size 256 --train_batch 1 --epoch 200 --eval_interval 10 --lr 0.0001 --lr_step 60 --lr_gamma 0.2 5 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main.py --name denoise_jiif_16 --model JIIF --scale 16 --noisy --sample_q 30720 --input_size 256 --train_batch 1 --epoch 200 --eval_interval 10 --lr 0.0001 --lr_step 60 --lr_gamma 0.2 -------------------------------------------------------------------------------- /scripts/train_jiif.sh: -------------------------------------------------------------------------------- 1 | set -euxo pipefail 2 | 3 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main.py --name jiif_4 --model JIIF --scale 4 --sample_q 30720 --input_size 256 --train_batch 1 --epoch 200 --eval_interval 10 --lr 0.0001 --lr_step 60 --lr_gamma 0.2 4 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main.py --name jiif_8 --model JIIF --scale 8 --sample_q 30720 --input_size 256 --train_batch 1 --epoch 200 --eval_interval 10 --lr 0.0001 --lr_step 60 --lr_gamma 0.2 5 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main.py --name jiif_16 --model JIIF --scale 16 --sample_q 30720 --input_size 256 --train_batch 1 --epoch 200 --eval_interval 10 --lr 0.0001 --lr_step 60 --lr_gamma 0.2 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import tqdm 4 | import random 5 | import tensorboardX 6 | 7 | import numpy as np 8 | import time 9 | import matplotlib.pyplot as plt 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | import torch.nn.functional as F 15 | 16 | from apex import amp 17 | 18 | from PIL import Image 19 | 20 | # reference: icgNoiseLocalvar (https://github.com/griegler/primal-dual-networks/blob/master/common/icgcunn/IcgNoise.cu) 21 | def add_noise(x, k=1, sigma=651, inv=True): 22 | # x: [H, W, 1] 23 | noise = sigma * np.random.randn(*x.shape) 24 | if inv: 25 | noise = noise / (x + 1e-5) 26 | else: 27 | noise = noise * x 28 | x = x + k * noise 29 | return x 30 | 31 | def make_coord(shape, ranges=None, flatten=True): 32 | """ Make coordinates at grid centers. 33 | ranged in [-1, 1] 34 | e.g. 
35 | shape = [2] get (-0.5, 0.5) 36 | shape = [3] get (-0.67, 0, 0.67) 37 | """ 38 | coord_seqs = [] 39 | for i, n in enumerate(shape): 40 | if ranges is None: 41 | v0, v1 = -1, 1 42 | else: 43 | v0, v1 = ranges[i] 44 | r = (v1 - v0) / (2 * n) 45 | seq = v0 + r + (2 * r) * torch.arange(n).float() 46 | coord_seqs.append(seq) 47 | ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1) # [H, W, 2] 48 | if flatten: 49 | ret = ret.view(-1, ret.shape[-1]) # [H*W, 2] 50 | return ret 51 | 52 | def to_pixel_samples(depth): 53 | """ Convert the image to coord-RGB pairs. 54 | depth: Tensor, (1, H, W) 55 | """ 56 | coord = make_coord(depth.shape[-2:], flatten=True) # [H*W, 2] 57 | pixel = depth.view(-1, 1) # [H*W, 1] 58 | return coord, pixel 59 | 60 | def seed_everything(seed): 61 | random.seed(seed) 62 | os.environ['PYTHONHASHSEED'] = str(seed) 63 | np.random.seed(seed) 64 | torch.manual_seed(seed) 65 | torch.cuda.manual_seed(seed) 66 | torch.backends.cudnn.deterministic = True 67 | torch.backends.cudnn.benchmark = True 68 | 69 | 70 | def visualize_2d(x, batched=False, renormalize=False): 71 | # x: [B, 3, H, W] or [B, 1, H, W] or [B, H, W] 72 | import matplotlib.pyplot as plt 73 | import numpy as np 74 | import torch 75 | 76 | if batched: 77 | x = x[0] 78 | 79 | if isinstance(x, torch.Tensor): 80 | x = x.detach().cpu().numpy() 81 | 82 | if len(x.shape) == 3: 83 | if x.shape[0] == 3: 84 | x = x.transpose(1, 2, 0) # to channel last 85 | elif x.shape[0] == 1: 86 | x = x[0] # to grey 87 | 88 | print(f'[VISUALIZER] {x.shape}, {x.min()} ~ {x.max()}') 89 | 90 | x = x.astype(np.float32) 91 | 92 | if len(x.shape) == 3: 93 | x = (x - x.min(axis=0, keepdims=True)) / (x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8) 94 | 95 | plt.matshow(x) 96 | plt.show() 97 | 98 | 99 | class RMSEMeter: 100 | def __init__(self, args): 101 | self.args = args 102 | self.V = 0 103 | self.N = 0 104 | 105 | def clear(self): 106 | self.V = 0 107 | self.N = 0 108 | 109 | def prepare_inputs(self, *inputs): 110 | outputs = [] 111 | for i, inp in enumerate(inputs): 112 | if torch.is_tensor(inp): 113 | inp = inp.detach().cpu().numpy() 114 | outputs.append(inp) 115 | 116 | return outputs 117 | 118 | def update(self, data, preds, truths, eval=False): 119 | preds, truths = self.prepare_inputs(preds, truths) # [B, 1, H, W] 120 | 121 | if eval: 122 | B, C, H, W = data['image'].shape 123 | preds = preds.reshape(B, 1, H, W) 124 | truths = truths.reshape(B, 1, H, W) 125 | 126 | # clip borders (reference: https://github.com/cvlab-yonsei/dkn/issues/1) 127 | preds = preds[:, :, 6:-6, 6:-6] 128 | truths = truths[:, :, 6:-6, 6:-6] 129 | 130 | # rmse 131 | rmse = np.sqrt(np.mean(np.power(preds - truths, 2))) 132 | 133 | # to report per-image rmse 134 | if self.args.report_per_image: 135 | print('rmse = ', rmse) 136 | 137 | self.V += rmse 138 | self.N += 1 139 | 140 | def measure(self): 141 | return self.V / self.N 142 | 143 | def write(self, writer, global_step, prefix=""): 144 | writer.add_scalar(os.path.join(prefix, "rmse"), self.measure(), global_step) 145 | 146 | def report(self): 147 | return f'RMSE = {self.measure():.6f}' 148 | 149 | 150 | class Trainer(object): 151 | def __init__(self, 152 | args, 153 | name, # name of this experiment 154 | model, # network 155 | objective=None, # loss function, if None, assume inline implementation in train_step 156 | optimizer=None, # optimizer 157 | lr_scheduler=None, # scheduler 158 | metrics=[], # metrics for evaluation, if None, use val_loss to measure performance, else use the first 
metric. 159 | local_rank=0, # which GPU am I 160 | world_size=1, # total num of GPUs 161 | device=None, # device to use, usually setting to None is OK. (auto choose device) 162 | mute=False, # whether to mute all print 163 | opt_level='O0', # amp optimize level 164 | eval_interval=1, # eval once every $ epoch 165 | max_keep_ckpt=1, # max num of saved ckpts in disk 166 | workspace='workspace', # workspace to save logs & ckpts 167 | best_mode='min', # the smaller/larger result, the better 168 | use_loss_as_metric=False, # use loss as the first metirc 169 | use_checkpoint="latest", # which ckpt to use at init time 170 | use_tensorboardX=True, # whether to use tensorboard for logging 171 | scheduler_update_every_step=False, # whether to call scheduler.step() after every train step 172 | ): 173 | 174 | self.args = args 175 | self.name = name 176 | self.mute = mute 177 | self.model = model 178 | self.objective = objective 179 | self.optimizer = optimizer 180 | self.lr_scheduler = lr_scheduler 181 | self.metrics = metrics 182 | self.local_rank = local_rank 183 | self.world_size = world_size 184 | self.workspace = workspace 185 | self.opt_level = opt_level 186 | self.best_mode = best_mode 187 | self.use_loss_as_metric = use_loss_as_metric 188 | self.max_keep_ckpt = max_keep_ckpt 189 | self.eval_interval = eval_interval 190 | self.use_checkpoint = use_checkpoint 191 | self.use_tensorboardX = use_tensorboardX 192 | self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S") 193 | self.scheduler_update_every_step = scheduler_update_every_step 194 | self.device = device if device is not None else torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu') 195 | 196 | self.model.to(self.device) 197 | if isinstance(self.objective, nn.Module): 198 | self.objective.to(self.device) 199 | 200 | if optimizer is None: 201 | self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4) # naive adam 202 | 203 | if lr_scheduler is None: 204 | self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1) # fake scheduler 205 | 206 | self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level=self.opt_level, verbosity=0) 207 | 208 | # variable init 209 | self.epoch = 1 210 | self.global_step = 0 211 | self.local_step = 0 212 | self.stats = { 213 | "loss": [], 214 | "valid_loss": [], 215 | "results": [], # metrics[0], or valid_loss 216 | "checkpoints": [], # record path of saved ckpt, to automatically remove old ckpt 217 | "best_result": None, 218 | } 219 | 220 | # auto fix 221 | if len(metrics) == 0 or self.use_loss_as_metric: 222 | self.best_mode = 'min' 223 | 224 | # workspace prepare 225 | self.log_ptr = None 226 | if self.workspace is not None: 227 | os.makedirs(self.workspace, exist_ok=True) 228 | self.log_path = os.path.join(workspace, f"log_{self.name}.txt") 229 | self.log_ptr = open(self.log_path, "a+") 230 | 231 | self.ckpt_path = os.path.join(self.workspace, 'checkpoints') 232 | self.best_path = f"{self.ckpt_path}/{self.name}.pth.tar" 233 | os.makedirs(self.ckpt_path, exist_ok=True) 234 | 235 | self.log(f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {self.workspace}') 236 | self.log(f'[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}') 237 | 238 | if self.workspace is not None: 239 | if self.use_checkpoint == "scratch": 240 | self.log("[INFO] Model randomly initialized ...") 241 | elif self.use_checkpoint == "latest": 242 | self.log("[INFO] Loading latest checkpoint 
...") 243 | self.load_checkpoint() 244 | elif self.use_checkpoint == "best": 245 | if os.path.exists(self.best_path): 246 | self.log("[INFO] Loading best checkpoint ...") 247 | self.load_checkpoint(self.best_path) 248 | else: 249 | self.log(f"[INFO] {self.best_path} not found, loading latest ...") 250 | self.load_checkpoint() 251 | else: # path to ckpt 252 | self.log(f"[INFO] Loading {self.use_checkpoint} ...") 253 | self.load_checkpoint(self.use_checkpoint) 254 | 255 | def __del__(self): 256 | if self.log_ptr: 257 | self.log_ptr.close() 258 | 259 | def log(self, *args): 260 | if self.local_rank == 0: 261 | if not self.mute: 262 | print(*args) 263 | if self.log_ptr: 264 | print(*args, file=self.log_ptr) 265 | 266 | ### ------------------------------ 267 | 268 | def train_step(self, data): 269 | gt = data['hr'] 270 | pred = self.model(data) 271 | 272 | loss = self.objective(pred, gt) 273 | 274 | # rescale 275 | pred = pred * (data['max'] - data['min']) + data['min'] 276 | gt = gt * (data['max'] - data['min']) + data['min'] 277 | 278 | return pred, gt, loss 279 | 280 | def eval_step(self, data): 281 | return self.train_step(data) 282 | 283 | def test_step(self, data): 284 | B, C, H, W = data['image'].shape 285 | pred = self.model(data) 286 | pred = pred * (data['max'] - data['min']) + data['min'] 287 | pred = pred.reshape(B, 1, H, W) 288 | 289 | #visualize_2d(data['image'], batched=True) 290 | #visualize_2d(data['lr'], batched=True) 291 | #visualize_2d(pred, batched=True) 292 | 293 | return pred 294 | 295 | ### ------------------------------ 296 | 297 | def train(self, train_loader, valid_loader, max_epochs): 298 | if self.use_tensorboardX and self.local_rank == 0: 299 | self.writer = tensorboardX.SummaryWriter(os.path.join(self.workspace, "run", self.name)) 300 | 301 | for epoch in range(self.epoch, max_epochs + 1): 302 | self.epoch = epoch 303 | self.train_one_epoch(train_loader) 304 | 305 | if self.workspace is not None and self.local_rank == 0: 306 | self.save_checkpoint(full=True, best=False) 307 | 308 | if self.epoch % self.eval_interval == 0: 309 | self.evaluate_one_epoch(valid_loader) 310 | self.save_checkpoint(full=False, best=True) 311 | 312 | if self.use_tensorboardX and self.local_rank == 0: 313 | self.writer.close() 314 | 315 | def evaluate(self, loader): 316 | #if os.path.exists(self.best_path): 317 | # self.load_checkpoint(self.best_path) 318 | #else: 319 | # self.load_checkpoint() 320 | self.use_tensorboardX, use_tensorboardX = False, self.use_tensorboardX 321 | self.evaluate_one_epoch(loader) 322 | self.use_tensorboardX = use_tensorboardX 323 | 324 | def test(self, loader, save_path=None): 325 | if save_path is None: 326 | save_path = os.path.join(self.workspace, 'results', f'{self.name}_{self.args.dataset}_{self.args.scale}') 327 | os.makedirs(save_path, exist_ok=True) 328 | 329 | self.log(f"==> Start Test, save results to {save_path}") 330 | 331 | 332 | pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') 333 | self.model.eval() 334 | with torch.no_grad(): 335 | for data in loader: 336 | 337 | data = self.prepare_data(data) 338 | preds = self.test_step(data) 339 | 340 | preds = preds.detach().cpu().numpy() # [B, 1, H, W] 341 | 342 | for b in range(preds.shape[0]): 343 | idx = data['idx'][b] 344 | if not isinstance(idx, str): 345 | idx = str(idx.item()) 346 | pred = preds[b][0] 347 | plt.imsave(os.path.join(save_path, f'{idx}.png'), pred, cmap='plasma') 348 | 349 | 
pbar.update(loader.batch_size) 350 | 351 | self.log(f"==> Finished Test.") 352 | 353 | def prepare_data(self, data): 354 | if isinstance(data, list): 355 | for i, v in enumerate(data): 356 | if isinstance(v, np.ndarray): 357 | data[i] = torch.from_numpy(v).to(self.device) 358 | if torch.is_tensor(v): 359 | data[i] = v.to(self.device) 360 | elif isinstance(data, dict): 361 | for k, v in data.items(): 362 | if isinstance(v, np.ndarray): 363 | data[k] = torch.from_numpy(v).to(self.device) 364 | if torch.is_tensor(v): 365 | data[k] = v.to(self.device) 366 | elif isinstance(data, np.ndarray): 367 | data = torch.from_numpy(data).to(self.device) 368 | else: # is_tensor 369 | data = data.to(self.device) 370 | 371 | return data 372 | 373 | def train_one_epoch(self, loader): 374 | self.log(f"==> Start Training Epoch {self.epoch}, lr={self.optimizer.param_groups[0]['lr']} ...") 375 | 376 | total_loss = [] 377 | if self.local_rank == 0: 378 | for metric in self.metrics: 379 | metric.clear() 380 | 381 | self.model.train() 382 | 383 | if self.local_rank == 0: 384 | pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') 385 | 386 | self.local_step = 0 387 | 388 | for data in loader: 389 | 390 | self.local_step += 1 391 | self.global_step += 1 392 | 393 | data = self.prepare_data(data) 394 | preds, truths, loss = self.train_step(data) 395 | 396 | with amp.scale_loss(loss, self.optimizer) as scaled_loss: 397 | scaled_loss.backward() 398 | 399 | self.optimizer.step() 400 | self.optimizer.zero_grad() 401 | 402 | if self.scheduler_update_every_step: 403 | self.lr_scheduler.step() 404 | 405 | total_loss.append(loss.item()) 406 | if self.local_rank == 0: 407 | for metric in self.metrics: 408 | metric.update(data, preds, truths) 409 | 410 | if self.use_tensorboardX: 411 | self.writer.add_scalar("train/loss", loss.item(), self.global_step) 412 | self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]['lr'], self.global_step) 413 | 414 | if self.scheduler_update_every_step: 415 | pbar.set_description(f"loss={total_loss[-1]:.4f}, lr={self.optimizer.param_groups[0]['lr']}") 416 | else: 417 | pbar.set_description(f'loss={total_loss[-1]:.4f}') 418 | pbar.update(loader.batch_size * self.world_size) 419 | 420 | average_loss = np.mean(total_loss) 421 | self.stats["loss"].append(average_loss) 422 | 423 | if self.local_rank == 0: 424 | pbar.close() 425 | for metric in self.metrics: 426 | self.log(metric.report()) 427 | if self.use_tensorboardX: 428 | metric.write(self.writer, self.epoch, prefix="train") 429 | metric.clear() 430 | 431 | if not self.scheduler_update_every_step: 432 | if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): 433 | self.lr_scheduler.step(average_loss) 434 | else: 435 | self.lr_scheduler.step() 436 | 437 | self.log(f"==> Finished Epoch {self.epoch}, average_loss={average_loss:.4f}") 438 | 439 | 440 | def evaluate_one_epoch(self, loader): 441 | self.log(f"++> Evaluate at epoch {self.epoch} ...") 442 | 443 | total_loss = [] 444 | if self.local_rank == 0: 445 | for metric in self.metrics: 446 | metric.clear() 447 | 448 | self.model.eval() 449 | 450 | if self.local_rank == 0: 451 | pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') 452 | 453 | with torch.no_grad(): 454 | self.local_step = 0 455 | for data in loader: 456 | self.local_step += 1 457 | 458 | data = 
self.prepare_data(data) 459 | preds, truths, loss = self.eval_step(data) 460 | 461 | total_loss.append(loss.item()) 462 | if self.local_rank == 0: 463 | for metric in self.metrics: 464 | metric.update(data, preds, truths, eval=True) 465 | 466 | pbar.set_description(f'loss={total_loss[-1]:.4f}') 467 | pbar.update(loader.batch_size * self.world_size) 468 | 469 | average_loss = np.mean(total_loss) 470 | self.stats["valid_loss"].append(average_loss) 471 | 472 | if self.local_rank == 0: 473 | pbar.close() 474 | if not self.use_loss_as_metric and len(self.metrics) > 0: 475 | result = self.metrics[0].measure() 476 | self.stats["results"].append(result if self.best_mode == 'min' else - result) # if max mode, use -result 477 | else: 478 | self.stats["results"].append(average_loss) # if no metric, choose best by min loss 479 | 480 | for metric in self.metrics: 481 | self.log(metric.report()) 482 | if self.use_tensorboardX: 483 | metric.write(self.writer, self.epoch, prefix="evaluate") 484 | metric.clear() 485 | 486 | self.log(f"++> Evaluate epoch {self.epoch} Finished, average_loss={average_loss:.4f}") 487 | 488 | def save_checkpoint(self, full=False, best=False): 489 | 490 | state = { 491 | 'epoch': self.epoch, 492 | 'stats': self.stats, 493 | 'model': self.model.state_dict(), 494 | } 495 | 496 | if full: 497 | state['amp'] = amp.state_dict() 498 | state['optimizer'] = self.optimizer.state_dict() 499 | state['lr_scheduler'] = self.lr_scheduler.state_dict() 500 | 501 | if not best: 502 | 503 | file_path = f"{self.ckpt_path}/{self.name}_ep{self.epoch:04d}.pth.tar" 504 | 505 | self.stats["checkpoints"].append(file_path) 506 | 507 | if len(self.stats["checkpoints"]) > self.max_keep_ckpt: 508 | old_ckpt = self.stats["checkpoints"].pop(0) 509 | if os.path.exists(old_ckpt): 510 | os.remove(old_ckpt) 511 | 512 | torch.save(state, file_path) 513 | 514 | else: 515 | if len(self.stats["results"]) > 0: 516 | if self.stats["best_result"] is None or self.stats["results"][-1] < self.stats["best_result"]: 517 | self.log(f"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}") 518 | self.stats["best_result"] = self.stats["results"][-1] 519 | torch.save(state, self.best_path) 520 | else: 521 | self.log(f"[INFO] no evaluated results found, skip saving best checkpoint.") 522 | 523 | def load_checkpoint(self, checkpoint=None): 524 | if checkpoint is None: 525 | checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/{self.name}_ep*.pth.tar')) 526 | if checkpoint_list: 527 | checkpoint = checkpoint_list[-1] 528 | else: 529 | self.log("[INFO] No checkpoint found, model randomly initialized.") 530 | return 531 | 532 | checkpoint_dict = torch.load(checkpoint, map_location=self.device) 533 | 534 | if 'model' not in checkpoint_dict: 535 | self.model.load_state_dict(checkpoint_dict) 536 | return 537 | 538 | self.model.load_state_dict(checkpoint_dict['model']) 539 | 540 | self.stats = checkpoint_dict['stats'] 541 | self.epoch = checkpoint_dict['epoch'] 542 | 543 | if self.optimizer and 'optimizer' in checkpoint_dict: 544 | try: 545 | self.optimizer.load_state_dict(checkpoint_dict['optimizer']) 546 | self.log("[INFO] loaded optimizer.") 547 | except: 548 | self.log("[WARN] Failed to load optimizer. Skipped.") 549 | 550 | if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict: 551 | try: 552 | self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler']) 553 | self.log("[INFO] loaded scheduler.") 554 | except: 555 | self.log("[WARN] Failed to load scheduler. 
Skipped.") 556 | 557 | if 'amp' in checkpoint_dict: 558 | amp.load_state_dict(checkpoint_dict['amp']) 559 | self.log("[INFO] loaded amp.") --------------------------------------------------------------------------------