├── .gitignore ├── README.md ├── data_processing ├── __init__.py ├── camera_pipeline.py └── synthetic_burst_generation.py ├── datasets ├── __init__.py ├── burstsr_dataset.py ├── burstsr_test_dataset.py ├── synthetic_burst_test_set.py ├── synthetic_burst_train_set.py ├── synthetic_burst_val_set.py └── zurich_raw2rgb_dataset.py ├── figs └── ts.png ├── loss ├── Charbonnier.py ├── __init__.py ├── adversarial.py ├── discriminator.py ├── filter.py ├── hist_entropy.py ├── mssim.py └── vgg.py ├── main.py ├── model ├── DCNv2 │ ├── DCNv2.egg-info │ │ ├── PKG-INFO │ │ ├── SOURCES.txt │ │ ├── dependency_links.txt │ │ └── top_level.txt │ ├── LICENSE │ ├── README.md │ ├── __init__.py │ ├── __pycache__ │ │ └── dcn_v2.cpython-37.pyc │ ├── build │ │ ├── lib.linux-x86_64-3.7 │ │ │ └── _ext.cpython-37m-x86_64-linux-gnu.so │ │ └── temp.linux-x86_64-3.7 │ │ │ └── data │ │ │ └── work │ │ │ └── pylibs │ │ │ └── DCNv2-pytorch_1.6 │ │ │ └── src │ │ │ ├── cpu │ │ │ ├── dcn_v2_cpu.o │ │ │ ├── dcn_v2_im2col_cpu.o │ │ │ └── dcn_v2_psroi_pooling_cpu.o │ │ │ ├── cuda │ │ │ ├── dcn_v2_cuda.o │ │ │ ├── dcn_v2_im2col_cuda.o │ │ │ └── dcn_v2_psroi_pooling_cuda.o │ │ │ └── vision.o │ ├── dcn_v2.py │ ├── dist │ │ └── DCNv2-0.1-py3.7-linux-x86_64.egg │ ├── files.txt │ ├── make.sh │ ├── setup.py │ ├── src │ │ ├── cpu │ │ │ ├── dcn_v2_cpu.cpp │ │ │ ├── dcn_v2_im2col_cpu.cpp │ │ │ ├── dcn_v2_im2col_cpu.h │ │ │ ├── dcn_v2_psroi_pooling_cpu.cpp │ │ │ └── vision.h │ │ ├── cuda │ │ │ ├── dcn_v2_cuda.cu │ │ │ ├── dcn_v2_im2col_cuda.cu │ │ │ ├── dcn_v2_im2col_cuda.h │ │ │ ├── dcn_v2_psroi_pooling_cuda.cu │ │ │ └── vision.h │ │ ├── dcn_v2.h │ │ └── vision.cpp │ └── test.py ├── __init__.py ├── arch_util.py ├── common.py ├── ebsr.py ├── non_local │ ├── network.py │ ├── non_local_concatenation.py │ ├── non_local_cross_dot_product.py │ ├── non_local_dot_product.py │ ├── non_local_embedded_gaussian.py │ └── non_local_gaussian.py └── utils │ ├── interp_methods.py │ ├── psconv.py │ └── resize_right.py ├── option.py ├── pwcnet ├── LICENSE ├── README.md ├── __init__.py ├── comparison │ ├── comparison.gif │ ├── comparison.py │ ├── official - caffe.png │ └── this - pytorch.png ├── correlation │ ├── README.md │ ├── __pycache__ │ │ └── correlation.cpython-37.pyc │ └── correlation.py ├── download.bash ├── images │ ├── README.md │ ├── first.png │ └── second.png ├── out.flo ├── pwcnet.py ├── requirements.txt └── run.py ├── requirements.txt ├── scripts ├── __init__.py ├── cal_mean_std.py ├── demo.sh ├── download_burstsr_dataset.py ├── evaluate.sh ├── evaluate_burstsr_val.py ├── save_results_synburst_val.py ├── test_burstsr_dataset.py └── test_synthetic_bursts.py ├── test.py ├── test_real.py ├── trainer.py ├── utility.py └── utils ├── __init__.py ├── data_format_utils.py ├── debayer.py ├── interp_methods.py ├── metrics.py ├── postprocessing_functions.py ├── resize_right.py ├── spatial_color_alignment.py ├── stn.py └── warp.py /.gitignore: -------------------------------------------------------------------------------- 1 | demo.sh 2 | /checkpoints 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EBSR: Feature Enhanced Burst Super-Resolution With Deformable Alignment (CVPRW 2021) 2 | 3 | 4 | ### Update !!! 5 | - **2022.04.22** 🎉🎉🎉 We won the 1st place in NTIRE 2022 BurstSR Challenge again [[Paper]](https://arxiv.org/abs/2204.08332)[[Code]](https://github.com/Algolzw/BSRT). 
6 | - **2022.01.22** We updated the code to support real-track testing and provided the model weights [here](https://drive.google.com/file/d/1Zz21YwNtiKZCjerrZsdvcWyubqTJBwaD/view?usp=sharing). 7 | - **2021** We now support single-GPU training and provide the pretrained model [here](https://drive.google.com/file/d/1_WA2chhITIsCj6qImcEM2lD6c-iJsRpy/view?usp=sharing). 8 | 9 | 10 | 11 |
12 | <img src="figs/ts.png"> 13 |
14 | 15 | This repository is an official PyTorch implementation of the paper **"EBSR: Feature Enhanced Burst Super-Resolution With Deformable Alignment"** from CVPRW 2021. EBSR won 1st place in the real track of the NTIRE 2021 Burst Super-Resolution Challenge (and 2nd place in the synthetic track). 16 | 17 | ## Dependencies 18 | - OS: Ubuntu 18.04 19 | - Python: Python 3.7 20 | - NVIDIA: 21 | - CUDA: 10.1 22 | - cuDNN: 7.6.1 23 | - Other requirements are listed in `requirements.txt` 24 | 25 | ## Quick Start 26 | 1. Create a conda virtual environment and activate it: 27 | ```bash 28 | conda create -n pytorch_1.6 python=3.7 29 | source activate pytorch_1.6 30 | ``` 31 | 2. Install PyTorch and torchvision following the official instructions: 32 | ```bash 33 | conda install pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=10.1 -c pytorch 34 | ``` 35 | 3. Install the build requirements: 36 | ```bash 37 | pip3 install -r requirements.txt 38 | ``` 39 | 4. Install apex to use DistributedDataParallel, following [NVIDIA apex](https://github.com/NVIDIA/apex) (optional): 40 | ```bash 41 | git clone https://github.com/NVIDIA/apex 42 | cd apex 43 | pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 44 | ``` 45 | 5. Install DCN: 46 | ```bash 47 | cd model/DCNv2 48 | python3 setup.py build develop # build 49 | python3 test.py # run examples and check 50 | ``` 51 | ## Training 52 | ```bash 53 | # Modify the root path of the training dataset, the model, etc. 54 | # This command assumes more than 1 GPU 55 | python main.py --n_GPUs 4 --lr 0.0002 --decay 200-400 --save ebsr --model EBSR --fp16 --lrcn --non_local --n_feats 128 --n_resblocks 8 --n_resgroups 5 --batch_size 16 --burst_size 14 --patch_size 256 --scale 4 --loss 1*L1 56 | ``` 57 | ## Test 58 | ```bash 59 | # Modify the path of the test dataset and the path of the trained model 60 | python test.py --root /data/dataset/ntire21/burstsr/synthetic/syn_burst_val --model EBSR --lrcn --non_local --n_feats 128 --n_resblocks 8 --n_resgroups 5 --burst_size 14 --scale 4 --pre_train ./checkpoints/EBSRbest_epoch.pth 61 | ``` 62 | or test on the validation dataset: 63 | ```bash 64 | python main.py --n_GPUs 1 --test_only --model EBSR --lrcn --non_local --n_feats 128 --n_resblocks 8 --n_resgroups 5 --burst_size 14 --scale 4 --pre_train ./checkpoints/EBSRbest_epoch.pth 65 | ``` 66 | ### Real track evaluation 67 | You may need to download the pretrained PWC-Net model to the pwcnet directory ([here](https://drive.google.com/file/d/1dD6vB9QN3qwmOBi3AGKzJbbSojwDDlgV/view?usp=sharing)). 68 | 69 | ```bash 70 | python test_real.py --n_GPUs 1 --model EBSR --lrcn --non_local --n_feats 128 --n_resblocks 8 --n_resgroups 5 --burst_size 14 --scale 4 --pre_train ./checkpoints/BBSR_realbest_epoch.pth --root burstsr_validation_dataset... 71 | 72 | ``` 73 | 74 | ## Citations 75 | If EBSR helps your research or work, please consider citing it. 76 | The following is a BibTeX reference.
77 | 78 | ``` 79 | @InProceedings{Luo_2021_CVPR, 80 | author = {Luo, Ziwei and Yu, Lei and Mo, Xuan and Li, Youwei and Jia, Lanpeng and Fan, Haoqiang and Sun, Jian and Liu, Shuaicheng}, 81 | title = {EBSR: Feature Enhanced Burst Super-Resolution With Deformable Alignment}, 82 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, 83 | month = {June}, 84 | year = {2021}, 85 | pages = {471-478} 86 | } 87 | ``` 88 | 89 | ## Contact 90 | email: [ziwei.ro@gmail.com, yl_yjsy@163.com] 91 | -------------------------------------------------------------------------------- /data_processing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/data_processing/__init__.py -------------------------------------------------------------------------------- /data_processing/camera_pipeline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import math 4 | import cv2 as cv 5 | import numpy as np 6 | import utils.data_format_utils as df_utils 7 | """ Based on http://timothybrooks.com/tech/unprocessing 8 | Functions for forward and inverse camera pipeline. All functions input a torch float tensor of shape (c, h, w). 9 | Additionally, some also support batch operations, i.e. inputs of shape (b, c, h, w) 10 | """ 11 | 12 | 13 | def random_ccm(): 14 | """Generates random RGB -> Camera color correction matrices.""" 15 | # Takes a random convex combination of XYZ -> Camera CCMs. 16 | xyz2cams = [[[1.0234, -0.2969, -0.2266], 17 | [-0.5625, 1.6328, -0.0469], 18 | [-0.0703, 0.2188, 0.6406]], 19 | [[0.4913, -0.0541, -0.0202], 20 | [-0.613, 1.3513, 0.2906], 21 | [-0.1564, 0.2151, 0.7183]], 22 | [[0.838, -0.263, -0.0639], 23 | [-0.2887, 1.0725, 0.2496], 24 | [-0.0627, 0.1427, 0.5438]], 25 | [[0.6596, -0.2079, -0.0562], 26 | [-0.4782, 1.3016, 0.1933], 27 | [-0.097, 0.1581, 0.5181]]] 28 | 29 | num_ccms = len(xyz2cams) 30 | xyz2cams = torch.tensor(xyz2cams) 31 | 32 | weights = torch.FloatTensor(num_ccms, 1, 1).uniform_(0.0, 1.0) 33 | weights_sum = weights.sum() 34 | xyz2cam = (xyz2cams * weights).sum(dim=0) / weights_sum 35 | 36 | # Multiplies with RGB -> XYZ to get RGB -> Camera CCM. 37 | rgb2xyz = torch.tensor([[0.4124564, 0.3575761, 0.1804375], 38 | [0.2126729, 0.7151522, 0.0721750], 39 | [0.0193339, 0.1191920, 0.9503041]]) 40 | rgb2cam = torch.mm(xyz2cam, rgb2xyz) 41 | 42 | # Normalizes each row. 43 | rgb2cam = rgb2cam / rgb2cam.sum(dim=-1, keepdims=True) 44 | return rgb2cam 45 | 46 | 47 | def random_gains(): 48 | """Generates random gains for brightening and white balance.""" 49 | # RGB gain represents brightening. 50 | rgb_gain = 1.0 / random.gauss(mu=0.8, sigma=0.1) 51 | 52 | # Red and blue gains represent white balance. 53 | red_gain = random.uniform(1.9, 2.4) 54 | blue_gain = random.uniform(1.5, 1.9) 55 | return rgb_gain, red_gain, blue_gain 56 | 57 | 58 | def apply_smoothstep(image): 59 | """Apply global tone mapping curve.""" 60 | image_out = 3 * image**2 - 2 * image**3 61 | return image_out 62 | 63 | 64 | def invert_smoothstep(image): 65 | """Approximately inverts a global tone mapping curve.""" 66 | image = image.clamp(0.0, 1.0) 67 | return 0.5 - torch.sin(torch.asin(1.0 - 2.0 * image) / 3.0) 68 | 69 | 70 | def gamma_expansion(image): 71 | """Converts from gamma to linear space.""" 72 | # Clamps to prevent numerical instability of gradients near zero. 
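# The 2.2 exponent approximates the inverse sRGB transfer curve (gamma -> linear);
# clamping away zero/negative values keeps the fractional power and its gradient well-defined.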
73 | return image.clamp(1e-8) ** 2.2 74 | 75 | 76 | def gamma_compression(image): 77 | """Converts from linear to gamma space.""" 78 | # Clamps to prevent numerical instability of gradients near zero. 79 | return image.clamp(1e-8) ** (1.0 / 2.2) 80 | 81 | 82 | def apply_ccm(image, ccm): 83 | """Applies a color correction matrix.""" 84 | assert image.dim() == 3 and image.shape[0] == 3 85 | 86 | shape = image.shape 87 | image = image.view(3, -1) 88 | ccm = ccm.to(image.device).type_as(image) 89 | 90 | image = torch.mm(ccm, image) 91 | 92 | return image.view(shape) 93 | 94 | 95 | def apply_gains(image, rgb_gain, red_gain, blue_gain): 96 | """Applies white balance and brightening gains, clamping saturated pixels.""" 97 | assert image.dim() == 3 and image.shape[0] in [3, 4] 98 | 99 | if image.shape[0] == 3: 100 | gains = torch.tensor([red_gain, 1.0, blue_gain]) * rgb_gain 101 | else: 102 | gains = torch.tensor([red_gain, 1.0, 1.0, blue_gain]) * rgb_gain 103 | gains = gains.view(-1, 1, 1) 104 | gains = gains.to(image.device).type_as(image) 105 | 106 | return (image * gains).clamp(0.0, 1.0) 107 | 108 | 109 | def safe_invert_gains(image, rgb_gain, red_gain, blue_gain): 110 | """Inverts gains while safely handling saturated pixels.""" 111 | assert image.dim() == 3 and image.shape[0] == 3 112 | 113 | gains = torch.tensor([1.0 / red_gain, 1.0, 1.0 / blue_gain]) / rgb_gain 114 | gains = gains.view(-1, 1, 1) 115 | 116 | # Prevents dimming of saturated pixels by smoothly masking gains near white. 117 | gray = image.mean(dim=0, keepdims=True) 118 | inflection = 0.9 119 | mask = ((gray - inflection).clamp(0.0) / (1.0 - inflection)) ** 2.0 120 | 121 | safe_gains = torch.max(mask + (1.0 - mask) * gains, gains) 122 | return image * safe_gains 123 | 124 | 125 | def mosaic(image, mode='rggb'): 126 | """Extracts RGGB Bayer planes from an RGB image.""" 127 | shape = image.shape 128 | if image.dim() == 3: 129 | image = image.unsqueeze(0) 130 | 131 | if mode == 'rggb': 132 | red = image[:, 0, 0::2, 0::2] 133 | green_red = image[:, 1, 0::2, 1::2] 134 | green_blue = image[:, 1, 1::2, 0::2] 135 | blue = image[:, 2, 1::2, 1::2] 136 | image = torch.stack((red, green_red, green_blue, blue), dim=1) 137 | elif mode == 'grbg': 138 | green_red = image[:, 1, 0::2, 0::2] 139 | red = image[:, 0, 0::2, 1::2] 140 | blue = image[:, 2, 0::2, 1::2] 141 | green_blue = image[:, 1, 1::2, 1::2] 142 | 143 | image = torch.stack((green_red, red, blue, green_blue), dim=1) 144 | 145 | if len(shape) == 3: 146 | return image.view((4, shape[-2] // 2, shape[-1] // 2)) 147 | else: 148 | return image.view((-1, 4, shape[-2] // 2, shape[-1] // 2)) 149 | 150 | 151 | def demosaic(image): 152 | assert isinstance(image, torch.Tensor) 153 | image = image.clamp(0.0, 1.0) * 255 154 | 155 | if image.dim() == 4: 156 | num_images = image.shape[0] # batch size (the original image.dim() here always evaluated to 4) 157 | batch_input = True 158 | else: 159 | num_images = 1 160 | batch_input = False 161 | image = image.unsqueeze(0) 162 | 163 | # Generate single channel input for opencv 164 | im_sc = torch.zeros((num_images, image.shape[-2] * 2, image.shape[-1] * 2, 1)) 165 | im_sc[:, ::2, ::2, 0] = image[:, 0, :, :] 166 | im_sc[:, ::2, 1::2, 0] = image[:, 1, :, :] 167 | im_sc[:, 1::2, ::2, 0] = image[:, 2, :, :] 168 | im_sc[:, 1::2, 1::2, 0] = image[:, 3, :, :] 169 | 170 | im_sc = im_sc.numpy().astype(np.uint8) 171 | 172 | out = [] 173 | 174 | for im in im_sc: 175 | # cv.imwrite('frames/tmp.png', im) 176 | im_dem_np = cv.cvtColor(im, cv.COLOR_BAYER_BG2RGB)#_VNG) 177 | 178 | # Convert to torch image 179 | im_t = df_utils.npimage_to_torch(im_dem_np,
input_bgr=False) 180 | out.append(im_t) 181 | 182 | if batch_input: 183 | return torch.stack(out, dim=0) 184 | else: 185 | return out[0] 186 | 187 | 188 | def random_noise_levels(): 189 | """Generates random noise levels from a log-log linear distribution.""" 190 | log_min_shot_noise = math.log(0.0001) 191 | log_max_shot_noise = math.log(0.012) 192 | log_shot_noise = random.uniform(log_min_shot_noise, log_max_shot_noise) 193 | shot_noise = math.exp(log_shot_noise) 194 | 195 | line = lambda x: 2.18 * x + 1.20 196 | log_read_noise = line(log_shot_noise) + random.gauss(mu=0.0, sigma=0.26) 197 | read_noise = math.exp(log_read_noise) 198 | return shot_noise, read_noise 199 | 200 | 201 | def add_noise(image, shot_noise=0.01, read_noise=0.0005): 202 | """Adds random shot (proportional to image) and read (independent) noise.""" 203 | variance = image * shot_noise + read_noise 204 | noise = torch.FloatTensor(image.shape).normal_().to(image.device)*variance.sqrt() 205 | return image + noise 206 | 207 | 208 | def process_linear_image_rgb(image, meta_info, return_np=False): 209 | image = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain']) 210 | image = apply_ccm(image, meta_info['cam2rgb']) 211 | 212 | if meta_info['gamma']: 213 | image = gamma_compression(image) 214 | 215 | if meta_info['smoothstep']: 216 | image = apply_smoothstep(image) 217 | 218 | image = image.clamp(0.0, 1.0) 219 | 220 | if return_np: 221 | image = df_utils.torch_to_npimage(image) 222 | return image 223 | 224 | 225 | def process_linear_image_raw(image, meta_info): 226 | image = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain']) 227 | image = demosaic(image) 228 | image = apply_ccm(image, meta_info['cam2rgb']) 229 | 230 | if meta_info['gamma']: 231 | image = gamma_compression(image) 232 | 233 | if meta_info['smoothstep']: 234 | image = apply_smoothstep(image) 235 | return image.clamp(0.0, 1.0) 236 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/burstsr_test_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | import random 5 | from .burstsr_dataset import SamsungRAWImage, flatten_raw_image, pack_raw_image 6 | 7 | 8 | class BurstSRDataset(torch.utils.data.Dataset): 9 | """ Real-world burst super-resolution dataset. """ 10 | def __init__(self, root, burst_size=8, crop_sz=80, center_crop=False, random_flip=False, split='test'): 11 | """ 12 | args: 13 | root : path of the root directory 14 | burst_size : Burst size. Maximum allowed burst size is 14. 15 | crop_sz: Size of the extracted crop. Maximum allowed crop size is 80. 16 | center_crop: If True, extract a centered crop; otherwise, extract a random crop.
17 | random_flip: Whether to apply random horizontal and vertical flips 18 | split: Can only be 'test' 19 | """ 20 | assert burst_size <= 14, 'burst_size must be less than or equal to 14' 21 | assert crop_sz <= 80, 'crop_sz must be less than or equal to 80' 22 | assert split in ['test'] 23 | 24 | root = root + '/' + split 25 | super().__init__() 26 | 27 | self.burst_size = burst_size 28 | self.crop_sz = crop_sz 29 | self.split = split 30 | self.center_crop = center_crop 31 | self.random_flip = random_flip 32 | 33 | self.root = root 34 | 35 | self.substract_black_level = True 36 | self.white_balance = False 37 | 38 | self.burst_list = self._get_burst_list() 39 | 40 | def _get_burst_list(self): 41 | burst_list = sorted(os.listdir('{}'.format(self.root))) 42 | 43 | return burst_list 44 | 45 | def get_burst_info(self, burst_id): 46 | burst_info = {'burst_size': 14, 'burst_name': self.burst_list[burst_id]} 47 | return burst_info 48 | 49 | def _get_raw_image(self, burst_id, im_id): 50 | raw_image = SamsungRAWImage.load('{}/{}/samsung_{:02d}'.format(self.root, self.burst_list[burst_id], im_id)) 51 | return raw_image 52 | 53 | def get_burst(self, burst_id, im_ids, info=None): 54 | frames = [self._get_raw_image(burst_id, i) for i in im_ids] 55 | 56 | if info is None: 57 | info = self.get_burst_info(burst_id) 58 | 59 | return frames, info 60 | 61 | def _sample_images(self): 62 | burst_size = 14 63 | 64 | ids = random.sample(range(1, burst_size), k=self.burst_size - 1) 65 | ids = [0, ] + ids 66 | return ids 67 | 68 | def __len__(self): 69 | return len(self.burst_list) 70 | 71 | def __getitem__(self, index): 72 | # Sample the images in the burst, in case a burst_size < 14 is used. 73 | im_ids = self._sample_images() 74 | 75 | # Read the burst images along with HR ground truth 76 | frames, meta_info = self.get_burst(index, im_ids) 77 | 78 | # Extract crop if needed 79 | if frames[0].shape()[-1] != self.crop_sz: 80 | if getattr(self, 'center_crop', False): 81 | r1 = (frames[0].shape()[-2] - self.crop_sz) // 2 82 | c1 = (frames[0].shape()[-1] - self.crop_sz) // 2 83 | else: 84 | r1 = random.randint(0, frames[0].shape()[-2] - self.crop_sz) 85 | c1 = random.randint(0, frames[0].shape()[-1] - self.crop_sz) 86 | r2 = r1 + self.crop_sz 87 | c2 = c1 + self.crop_sz 88 | 89 | frames = [im.get_crop(r1, r2, c1, c2) for im in frames] 90 | 91 | # Load the RAW image data 92 | burst_image_data = [im.get_image_data(normalize=True, substract_black_level=self.substract_black_level, 93 | white_balance=self.white_balance) for im in frames] 94 | 95 | if self.random_flip: 96 | burst_image_data = [flatten_raw_image(im) for im in burst_image_data] 97 | 98 | pad = [0, 0, 0, 0] 99 | if random.random() > 0.5: 100 | burst_image_data = [im.flip([1, ])[:, 1:-1].contiguous() for im in burst_image_data] 101 | pad[1] = 1 102 | 103 | if random.random() > 0.5: 104 | burst_image_data = [im.flip([0, ])[1:-1, :].contiguous() for im in burst_image_data] 105 | pad[3] = 1 106 | 107 | burst_image_data = [pack_raw_image(im) for im in burst_image_data] 108 | burst_image_data = [F.pad(im.unsqueeze(0), pad, mode='replicate').squeeze(0) for im in burst_image_data] 109 | 110 | burst_image_meta_info = frames[0].get_all_meta_data() 111 | 112 | burst_image_meta_info['black_level_subtracted'] = self.substract_black_level 113 | burst_image_meta_info['while_balance_applied'] = self.white_balance # key name kept as-is for compatibility 114 | burst_image_meta_info['norm_factor'] = frames[0].norm_factor 115 | 116 | burst = torch.stack(burst_image_data, dim=0) 117 | 118 | burst_exposure =
frames[0].get_exposure_time() 119 | 120 | burst_f_number = frames[0].get_f_number() 121 | 122 | burst_iso = frames[0].get_iso() 123 | 124 | burst_image_meta_info['exposure'] = burst_exposure 125 | burst_image_meta_info['f_number'] = burst_f_number 126 | burst_image_meta_info['iso'] = burst_iso 127 | 128 | burst = burst.float() 129 | 130 | meta_info_burst = burst_image_meta_info 131 | 132 | for k, v in meta_info_burst.items(): 133 | if isinstance(v, (list, tuple)): 134 | meta_info_burst[k] = torch.tensor(v) 135 | 136 | return burst, meta_info_burst -------------------------------------------------------------------------------- /datasets/synthetic_burst_test_set.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import numpy as np 4 | 5 | 6 | class SyntheticBurstVal(torch.utils.data.Dataset): 7 | """ Synthetic burst test set. The bursts have been generated using the same synthetic pipeline as 8 | employed in the SyntheticBurst dataset. 9 | """ 10 | def __init__(self, root): 11 | self.root = root 12 | self.burst_list = list(range(500)) 13 | self.burst_size = 14 14 | 15 | def __len__(self): 16 | return len(self.burst_list) 17 | 18 | def _read_burst_image(self, index, image_id): 19 | im = cv2.imread('{}/{:04d}/im_raw_{:02d}.png'.format(self.root, index, image_id), cv2.IMREAD_UNCHANGED) 20 | im_t = torch.from_numpy(im.astype(np.float32)).permute(2, 0, 1).float() / (2**14) 21 | return im_t 22 | 23 | def __getitem__(self, index): 24 | """ Generates a synthetic burst 25 | args: 26 | index: Index of the burst 27 | 28 | returns: 29 | burst: LR RAW burst, a torch tensor of shape 30 | [14, 4, 48, 48] 31 | The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaick. 32 | seq_name: Name of the burst sequence 33 | """ 34 | burst_name = '{:04d}'.format(index) 35 | burst = [self._read_burst_image(index, i) for i in range(self.burst_size)] 36 | burst = torch.stack(burst, 0) 37 | 38 | return burst, burst_name -------------------------------------------------------------------------------- /datasets/synthetic_burst_train_set.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | from data_processing.synthetic_burst_generation import rgb2rawburst, random_crop #syn_burst_utils 5 | import torchvision.transforms as tfm 6 | 7 | 8 | class SyntheticBurst(torch.utils.data.Dataset): 9 | """ Synthetic burst dataset for joint denoising, demosaicking, and super-resolution. RAW burst sequences are 10 | synthetically generated on the fly as follows. First, a single image is loaded from the base_dataset. The sampled 11 | image is converted to linear sensor space using the inverse camera pipeline employed in [1]. A burst 12 | sequence is then generated by adding random translations and rotations to the converted image. The generated burst 13 | is then mosaicked and corrupted by random noise to obtain the RAW burst.
14 | 15 | [1] Unprocessing Images for Learned Raw Denoising, Brooks, Tim and Mildenhall, Ben and Xue, Tianfan and Chen, 16 | Jiawen and Sharlet, Dillon and Barron, Jonathan T, CVPR 2019 17 | """ 18 | def __init__(self, base_dataset, burst_size=8, crop_sz=384, transform=tfm.ToTensor()): 19 | self.base_dataset = base_dataset 20 | 21 | self.burst_size = burst_size 22 | self.crop_sz = crop_sz 23 | self.transform = transform 24 | 25 | self.downsample_factor = 4 26 | self.burst_transformation_params = {'max_translation': 24.0, 27 | 'max_rotation': 1.0, 28 | 'max_shear': 0.0, 29 | 'max_scale': 0.0, 30 | 'border_crop': 24} 31 | 32 | self.image_processing_params = {'random_ccm': True, 'random_gains': True, 'smoothstep': True, 33 | 'gamma': True, 34 | 'add_noise': True} 35 | self.interpolation_type = 'bilinear' 36 | 37 | def __len__(self): 38 | return len(self.base_dataset) 39 | 40 | def __getitem__(self, index): 41 | """ Generates a synthetic burst 42 | args: 43 | index: Index of the image in the base_dataset used to generate the burst 44 | 45 | returns: 46 | burst: Generated LR RAW burst, a torch tensor of shape 47 | [burst_size, 4, self.crop_sz / (2*self.downsample_factor), self.crop_sz / (2*self.downsample_factor)] 48 | The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaick. 49 | The extra factor 2 in the denominator (2*self.downsample_factor) corresponds to the mosaicking 50 | operation. 51 | 52 | frame_gt: The HR RGB ground truth in the linear sensor space, a torch tensor of shape 53 | [3, self.crop_sz, self.crop_sz] 54 | 55 | flow_vectors: The ground truth flow vectors between a burst image and the base image (i.e. the first image in the burst). 56 | The flow_vectors can be used to warp the burst images to the base frame, using the 'warp' 57 | function in utils.warp package. 58 | flow_vectors is torch tensor of shape 59 | [burst_size, 2, self.crop_sz / self.downsample_factor, self.crop_sz / self.downsample_factor]. 60 | Note that the flow_vectors are in the LR RGB space, before mosaicking. Hence it has twice 61 | the number of rows and columns, compared to the output burst. 62 | 63 | NOTE: The flow_vectors are only available during training for the purpose of using any 64 | auxiliary losses if needed. The flow_vectors will NOT be provided for the bursts in the 65 | test set 66 | 67 | meta_info: A dictionary containing the parameters used to generate the synthetic burst. 68 | """ 69 | frame = self.base_dataset[index] 70 | 71 | # Augmentation, e.g. 
convert to tensor 72 | if self.transform is not None: 73 | # frame = Image.fromarray(frame) 74 | frame = self.transform(frame) 75 | 76 | # Extract a random crop from the image 77 | crop_sz = self.crop_sz + 2 * self.burst_transformation_params.get('border_crop', 0) 78 | frame_crop = random_crop(frame, crop_sz) 79 | 80 | # Generate RAW burst 81 | burst, frame_gt, burst_rgb, flow_vectors, meta_info = rgb2rawburst(frame_crop, 82 | self.burst_size, 83 | self.downsample_factor, 84 | burst_transformation_params=self.burst_transformation_params, 85 | image_processing_params=self.image_processing_params, 86 | interpolation_type=self.interpolation_type 87 | ) 88 | 89 | if self.burst_transformation_params.get('border_crop') is not None: 90 | border_crop = self.burst_transformation_params.get('border_crop') 91 | frame_gt = frame_gt[:, border_crop:-border_crop, border_crop:-border_crop] 92 | 93 | return burst, frame_gt, flow_vectors, meta_info 94 | -------------------------------------------------------------------------------- /datasets/synthetic_burst_val_set.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import numpy as np 4 | 5 | 6 | class SyntheticBurstVal(torch.utils.data.Dataset): 7 | """ Synthetic burst validation set. The validation burst have been generated using the same synthetic pipeline as 8 | employed in SyntheticBurst dataset. 9 | """ 10 | def __init__(self, root): 11 | self.root = root 12 | self.burst_list = list(range(100)) 13 | self.burst_size = 14 14 | 15 | def __len__(self): 16 | return len(self.burst_list) 17 | 18 | def _read_burst_image(self, index, image_id): 19 | im = cv2.imread('{}/{:04d}/im_raw_{:02d}.png'.format(self.root, index, image_id), cv2.IMREAD_UNCHANGED) 20 | im_t = torch.from_numpy(im.astype(np.float32)).permute(2, 0, 1).float() / (2**14) 21 | return im_t 22 | 23 | def __getitem__(self, index): 24 | """ Generates a synthetic burst 25 | args: 26 | index: Index of the burst 27 | 28 | returns: 29 | burst: LR RAW burst, a torch tensor of shape 30 | [14, 4, 48, 48] 31 | The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaick. 32 | seq_name: Name of the burst sequence 33 | """ 34 | burst_name = '{:04d}'.format(index) 35 | burst = [self._read_burst_image(index, i) for i in range(self.burst_size)] 36 | burst = torch.stack(burst, 0) 37 | 38 | return burst, burst_name 39 | -------------------------------------------------------------------------------- /datasets/zurich_raw2rgb_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from cv2 import imread 4 | 5 | 6 | class ZurichRAW2RGB(torch.utils.data.Dataset): 7 | """ Canon RGB images from the "Zurich RAW to RGB mapping" dataset. You can download the full 8 | dataset (22 GB) from http://people.ee.ethz.ch/~ihnatova/pynet.html#dataset. 
Alternatively, you can only download the 9 | Canon RGB images (5.5 GB) from https://data.vision.ee.ethz.ch/bhatg/zurich-raw-to-rgb.zip 10 | """ 11 | def __init__(self, root, split='train'): 12 | super().__init__() 13 | 14 | if split in ['train', 'test']: 15 | self.img_pth = os.path.join(root, split, 'canon') 16 | else: 17 | raise Exception('Unknown split {}'.format(split)) 18 | 19 | self.image_list = self._get_image_list(split) 20 | 21 | def _get_image_list(self, split): 22 | if split == 'train': 23 | image_list = ['{:d}.jpg'.format(i) for i in range(46839)] 24 | elif split == 'test': 25 | image_list = ['{:d}.jpg'.format(i) for i in range(1204)] 26 | else: 27 | raise Exception 28 | 29 | return image_list 30 | 31 | def _get_image(self, im_id): 32 | path = os.path.join(self.img_pth, self.image_list[im_id]) 33 | img = imread(path) 34 | return img 35 | 36 | def get_image(self, im_id): 37 | frame = self._get_image(im_id) 38 | 39 | return frame 40 | 41 | def __len__(self): 42 | return len(self.image_list) 43 | 44 | def __getitem__(self, index): 45 | frame = self._get_image(index) 46 | 47 | return frame 48 | -------------------------------------------------------------------------------- /figs/ts.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/figs/ts.png -------------------------------------------------------------------------------- /loss/Charbonnier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class CharbonnierLoss(nn.Module): 6 | """L1 charbonnier loss.""" 7 | 8 | def __init__(self, epsilon=1e-3, reduce=True): 9 | super(CharbonnierLoss, self).__init__() 10 | self.eps = epsilon * epsilon 11 | self.reduce = reduce 12 | 13 | def forward(self, X, Y): 14 | diff = torch.add(X, -Y) 15 | error = torch.sqrt(diff * diff + self.eps) 16 | if self.reduce: 17 | loss = torch.mean(error) 18 | else: 19 | loss = error 20 | return loss -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | import matplotlib.pyplot as plt 7 | 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | class Loss(nn.modules.loss._Loss): 15 | def __init__(self, args, ckp): 16 | super(Loss, self).__init__() 17 | if args.local_rank == 0: 18 | print('Preparing loss function:') 19 | 20 | self.n_GPUs = args.n_GPUs 21 | self.loss = [] 22 | self.loss_module = nn.ModuleList() 23 | for loss in args.loss.split('+'): 24 | weight, loss_type = loss.split('*') 25 | if loss_type == 'MSE': 26 | loss_function = nn.MSELoss() 27 | elif loss_type == 'L1': 28 | loss_function = nn.L1Loss() 29 | elif loss_type.find('VGG') >= 0: 30 | module = import_module('loss.vgg') 31 | loss_function = getattr(module, 'VGG')( 32 | loss_type[3:], 33 | rgb_range=args.rgb_range 34 | ) 35 | elif loss_type.find('GAN') >= 0: 36 | module = import_module('loss.adversarial') 37 | loss_function = getattr(module, 'Adversarial')( 38 | args, 39 | loss_type 40 | ) 41 | elif loss_type == 'FILTER': 42 | module = import_module('loss.filter') 43 | loss_function = getattr(module, 'Filter')(args) 44 | elif loss_type == 'SSIM': 45 | module = 
import_module('loss.mssim') 46 | loss_function = getattr(module, 'SSIM')(args) 47 | elif loss_type == 'MSSSIM': 48 | module = import_module('loss.mssim') 49 | loss_function = getattr(module, 'MSSSIM')(args) 50 | 51 | self.loss.append({ 52 | 'type': loss_type, 53 | 'weight': float(weight), 54 | 'function': loss_function} 55 | ) 56 | if loss_type.find('GAN') >= 0: 57 | self.loss.append({'type': 'DIS', 'weight': 1, 'function': None}) 58 | 59 | if len(self.loss) > 1: 60 | self.loss.append({'type': 'Total', 'weight': 0, 'function': None}) 61 | 62 | for l in self.loss: 63 | if l['function'] is not None: 64 | if args.local_rank == 0: 65 | print('{:.3f} * {}'.format(l['weight'], l['type'])) 66 | self.loss_module.append(l['function']) 67 | 68 | self.log = torch.Tensor() 69 | 70 | device = torch.device('cpu' if args.cpu else 'cuda') 71 | self.loss_module.to(device) 72 | if args.precision == 'half': self.loss_module.half() 73 | if not args.cpu and args.n_GPUs > 1: 74 | self.loss_module = nn.DataParallel( 75 | self.loss_module, range(args.n_GPUs) 76 | ) 77 | 78 | if args.load != '': self.load(ckp.dir, cpu=args.cpu) 79 | 80 | def forward(self, sr, hr): 81 | losses = [] 82 | for i, l in enumerate(self.loss): 83 | if l['function'] is not None: 84 | loss = l['function'](sr, hr) 85 | effective_loss = l['weight'] * loss 86 | losses.append(effective_loss) 87 | self.log[-1, i] += effective_loss.item() 88 | elif l['type'] == 'DIS': 89 | self.log[-1, i] += self.loss[i - 1]['function'].loss 90 | 91 | loss_sum = sum(losses) 92 | if len(self.loss) > 1: 93 | self.log[-1, -1] += loss_sum.item() 94 | 95 | return loss_sum 96 | 97 | def step(self): 98 | for l in self.get_loss_module(): 99 | if hasattr(l, 'scheduler'): 100 | l.scheduler.step() 101 | 102 | def start_log(self): 103 | self.log = torch.cat((self.log, torch.zeros(1, len(self.loss)))) 104 | 105 | def end_log(self, n_batches): 106 | self.log[-1].div_(n_batches) 107 | 108 | def display_loss(self, batch): 109 | n_samples = batch + 1 110 | log = [] 111 | for l, c in zip(self.loss, self.log[-1]): 112 | log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples)) 113 | 114 | return ''.join(log) 115 | 116 | def plot_loss(self, apath, epoch): 117 | axis = np.linspace(1, epoch, epoch) 118 | for i, l in enumerate(self.loss): 119 | label = '{} Loss'.format(l['type']) 120 | fig = plt.figure() 121 | plt.title(label) 122 | plt.plot(axis, self.log[:, i].numpy(), label=label) 123 | plt.legend() 124 | plt.xlabel('Epochs') 125 | plt.ylabel('Loss') 126 | plt.grid(True) 127 | plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type']))) 128 | plt.close(fig) 129 | 130 | def get_loss_module(self): 131 | if self.n_GPUs == 1: 132 | return self.loss_module 133 | else: 134 | return self.loss_module.module 135 | 136 | def save(self, apath): 137 | torch.save(self.state_dict(), os.path.join(apath, 'loss.pt')) 138 | torch.save(self.log, os.path.join(apath, 'loss_log.pt')) 139 | 140 | def load(self, apath, cpu=False): 141 | if cpu: 142 | kwargs = {'map_location': lambda storage, loc: storage} 143 | else: 144 | kwargs = {} 145 | 146 | self.load_state_dict(torch.load( 147 | os.path.join(apath, 'loss.pt'), 148 | **kwargs 149 | )) 150 | self.log = torch.load(os.path.join(apath, 'loss_log.pt')) 151 | for l in self.get_loss_module(): 152 | if hasattr(l, 'scheduler'): 153 | for _ in range(len(self.log)): l.scheduler.step() 154 | 155 | -------------------------------------------------------------------------------- /loss/adversarial.py: 
-------------------------------------------------------------------------------- 1 | import utility 2 | from types import SimpleNamespace 3 | 4 | from model import common 5 | from loss import discriminator 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | 12 | class Adversarial(nn.Module): 13 | def __init__(self, args, gan_type): 14 | super(Adversarial, self).__init__() 15 | self.gan_type = gan_type 16 | self.gan_k = args.gan_k 17 | self.dis = discriminator.Discriminator(args) 18 | # if gan_type == 'WGAN_GP': 19 | if True: 20 | # see https://arxiv.org/pdf/1704.00028.pdf pp.4 21 | optim_dict = { 22 | 'optimizer': 'ADAM', 23 | 'betas': (0.5, 0.9), 24 | 'epsilon': 1e-8, 25 | 'lr': 1e-5, 26 | 'weight_decay': args.weight_decay, 27 | 'decay': args.decay, 28 | 'gamma': args.gamma 29 | } 30 | optim_args = SimpleNamespace(**optim_dict) 31 | else: 32 | optim_args = args 33 | 34 | self.optimizer = utility.make_optimizer(optim_args, self.dis) 35 | 36 | def forward(self, fake, real): 37 | # updating discriminator... 38 | self.loss = 0 39 | fake_detach = fake.detach() # do not backpropagate through G 40 | for _ in range(self.gan_k): 41 | self.optimizer.zero_grad() 42 | # d: B x 1 tensor 43 | d_fake = self.dis(fake_detach) 44 | d_real = self.dis(real) 45 | retain_graph = False 46 | if self.gan_type in ['GAN', 'SNGAN']: 47 | loss_d = self.bce(d_real, d_fake) 48 | elif self.gan_type.find('WGAN') >= 0: 49 | loss_d = (d_fake - d_real).mean() 50 | if self.gan_type.find('GP') >= 0: 51 | epsilon = torch.rand_like(fake).view(-1, 1, 1, 1) 52 | hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon) 53 | hat.requires_grad = True 54 | d_hat = self.dis(hat) 55 | gradients = torch.autograd.grad( 56 | outputs=d_hat.sum(), inputs=hat, 57 | retain_graph=True, create_graph=True, only_inputs=True 58 | )[0] 59 | gradients = gradients.view(gradients.size(0), -1) 60 | gradient_norm = gradients.norm(2, dim=1) 61 | gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean() 62 | loss_d += gradient_penalty 63 | # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks 64 | elif self.gan_type == 'RGAN': 65 | better_real = d_real - d_fake.mean(dim=0, keepdim=True) 66 | better_fake = d_fake - d_real.mean(dim=0, keepdim=True) 67 | loss_d = self.bce(better_real, better_fake) 68 | retain_graph = True 69 | 70 | # Discriminator update 71 | self.loss += loss_d.item() 72 | loss_d.backward(retain_graph=retain_graph) 73 | self.optimizer.step() 74 | 75 | if self.gan_type == 'WGAN': 76 | for p in self.dis.parameters(): 77 | p.data.clamp_(-1, 1) 78 | 79 | self.loss /= self.gan_k 80 | 81 | # updating generator... 
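# Note: unlike the discriminator steps above, the fake batch is NOT detached here,
# so gradients flow through the discriminator score back into the generator.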
82 | d_fake_bp = self.dis(fake) # for backpropagation, use fake as it is 83 | if self.gan_type in ['GAN', 'SNGAN']: 84 | label_real = torch.ones_like(d_fake_bp) 85 | loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real) 86 | elif self.gan_type.find('WGAN') >= 0: 87 | loss_g = -d_fake_bp.mean() 88 | elif self.gan_type == 'RGAN': 89 | better_real = d_real.detach() - d_fake_bp.mean(dim=0, keepdim=True) 90 | better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True).detach() 91 | loss_g = self.bce(better_fake, better_real) 92 | 93 | # Generator loss 94 | return loss_g 95 | 96 | def state_dict(self, *args, **kwargs): 97 | state_discriminator = self.dis.state_dict(*args, **kwargs) 98 | state_optimizer = self.optimizer.state_dict() 99 | 100 | return dict(**state_discriminator, **state_optimizer) 101 | 102 | def bce(self, real, fake): 103 | label_real = torch.ones_like(real) 104 | label_fake = torch.zeros_like(fake) 105 | bce_real = F.binary_cross_entropy_with_logits(real, label_real) 106 | bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake) 107 | bce_loss = bce_real + bce_fake 108 | return bce_loss 109 | 110 | # Some references 111 | # https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py 112 | # OR 113 | # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py 114 | -------------------------------------------------------------------------------- /loss/discriminator.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | from torch.nn.utils import spectral_norm # needed by the 'SNGAN' branch below 5 | class Discriminator(nn.Module): 6 | ''' 7 | output is not normalized 8 | ''' 9 | def __init__(self, args, gan_type='GAN'): 10 | super(Discriminator, self).__init__() 11 | 12 | in_channels = args.n_colors 13 | out_channels = 32 14 | depth = 6 15 | 16 | def _block(_in_channels, _out_channels, stride=1): 17 | 18 | Conv = nn.Conv2d( 19 | _in_channels, 20 | _out_channels, 21 | 3, 22 | padding=1, 23 | stride=stride, 24 | bias=False 25 | ) 26 | 27 | if gan_type == 'SNGAN': 28 | return nn.Sequential( 29 | spectral_norm(Conv), 30 | nn.BatchNorm2d(_out_channels), 31 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 32 | ) 33 | else: 34 | return nn.Sequential( 35 | Conv, 36 | nn.BatchNorm2d(_out_channels), 37 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 38 | ) 39 | 40 | m_features = [_block(in_channels, out_channels)] 41 | for i in range(depth): 42 | in_channels = out_channels 43 | # if i % 2 == 1: 44 | # stride = 1 45 | # out_channels *= 2 46 | # else: 47 | out_channels *= 2 48 | stride = 2 49 | m_features.append(_block(in_channels, out_channels, stride=stride)) 50 | 51 | patch_size = args.patch_size // 2**(depth-1) 52 | 53 | # print(out_channels, patch_size) 54 | 55 | m_classifier = [ 56 | nn.Flatten(), 57 | nn.Linear(out_channels*patch_size**2, 512), 58 | nn.LeakyReLU(0.2, True), 59 | nn.Linear(512, 1) 60 | ] 61 | 62 | self.features = nn.Sequential(*m_features) 63 | self.classifier = nn.Sequential(*m_classifier) 64 | 65 | def forward(self, x): 66 | features = self.features(x) 67 | # print(features.shape) 68 | output = self.classifier(features) 69 | 70 | return output 71 | 72 | -------------------------------------------------------------------------------- /loss/filter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Filter(nn.Module): 6 | def __init__(self, args): 7 | super().__init__() 8 | self.args = args 9 | 10 |
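# 3x3 high-pass (Laplacian-style) kernel: the forward pass below filters both
# inputs with it and takes an L1 loss, penalizing differences in high-frequency detail.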
kernel = torch.tensor([[1, 4, 1], [4, -20, 4], [1, 4, 1]]) 11 | self.conv = nn.Conv2d(args.n_colors, args.n_colors, 3, 3) 12 | with torch.no_grad(): 13 | self.conv.weight.copy_(kernel.float()) 14 | self.loss = nn.L1Loss() 15 | 16 | def forward(self, x, y): 17 | preds_x = self.conv(x) 18 | preds_y = self.conv(y) 19 | 20 | return self.loss(preds_x, preds_y) 21 | -------------------------------------------------------------------------------- /loss/hist_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class HistEntropy(nn.Module): 6 | def __init__(self, args): 7 | super().__init__() 8 | self.args = args 9 | 10 | def forward(self, x): 11 | p = torch.softmax(x, dim=1) 12 | logp = torch.log_softmax(x, dim=1) 13 | 14 | entropy = (-p * logp).sum(dim=(2, 3)).mean() 15 | 16 | return entropy 17 | -------------------------------------------------------------------------------- /loss/mssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from math import exp 4 | import numpy as np 5 | 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | 12 | def create_window(window_size, channel=1): 13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 15 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 16 | return window 17 | 18 | 19 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): 20 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). 
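# When val_range is not given, the dynamic range L is inferred from the data below;
# L enters the SSIM stability constants C1 = (0.01*L)**2 and C2 = (0.03*L)**2.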
21 | if val_range is None: 22 | if torch.max(img1) > 128: 23 | max_val = 255 24 | else: 25 | max_val = 1 26 | 27 | if torch.min(img1) < -0.5: 28 | min_val = -1 29 | else: 30 | min_val = 0 31 | L = max_val - min_val 32 | else: 33 | L = val_range 34 | 35 | padd = 0 36 | (_, channel, height, width) = img1.size() 37 | if window is None: 38 | real_size = min(window_size, height, width) 39 | window = create_window(real_size, channel=channel).to(img1.device) 40 | 41 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel) 42 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel) 43 | 44 | mu1_sq = mu1.pow(2) 45 | mu2_sq = mu2.pow(2) 46 | mu1_mu2 = mu1 * mu2 47 | 48 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq 49 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq 50 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2 51 | 52 | C1 = (0.01 * L) ** 2 53 | C2 = (0.03 * L) ** 2 54 | 55 | v1 = 2.0 * sigma12 + C2 56 | v2 = sigma1_sq + sigma2_sq + C2 57 | cs = torch.mean(v1 / v2) # contrast sensitivity 58 | 59 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 60 | 61 | if size_average: 62 | ret = ssim_map.mean() 63 | else: 64 | ret = ssim_map.mean(1).mean(1).mean(1) 65 | 66 | if full: 67 | return ret, cs 68 | return ret 69 | 70 | 71 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=None): 72 | device = img1.device 73 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device) 74 | levels = weights.size()[0] 75 | ssims = [] 76 | mcs = [] 77 | for _ in range(levels): 78 | sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range) 79 | 80 | # Relu normalize (not compliant with original definition) 81 | if normalize == "relu": 82 | ssims.append(torch.relu(sim)) 83 | mcs.append(torch.relu(cs)) 84 | else: 85 | ssims.append(sim) 86 | mcs.append(cs) 87 | 88 | img1 = F.avg_pool2d(img1, (2, 2)) 89 | img2 = F.avg_pool2d(img2, (2, 2)) 90 | 91 | ssims = torch.stack(ssims) 92 | mcs = torch.stack(mcs) 93 | 94 | # Simple normalize (not compliant with original definition) 95 | # TODO: remove support for normalize == True (kept for backward support) 96 | if normalize == "simple" or normalize == True: 97 | ssims = (ssims + 1) / 2 98 | mcs = (mcs + 1) / 2 99 | 100 | pow1 = mcs ** weights 101 | pow2 = ssims ** weights 102 | 103 | # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/ 104 | output = torch.prod(pow1[:-1] * pow2[-1]) 105 | return output 106 | 107 | 108 | # Classes to re-use window 109 | class SSIM(torch.nn.Module): 110 | def __init__(self, window_size=11, size_average=True, val_range=None): 111 | super(SSIM, self).__init__() 112 | self.window_size = window_size 113 | self.size_average = size_average 114 | self.val_range = val_range 115 | 116 | # Assume 1 channel for SSIM 117 | self.channel = 1 118 | self.window = create_window(window_size) 119 | 120 | def forward(self, img1, img2): 121 | (_, channel, _, _) = img1.size() 122 | 123 | if channel == self.channel and self.window.dtype == img1.dtype: 124 | window = self.window 125 | else: 126 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) 127 | self.window = window 128 | self.channel = channel 129 | 130 | return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) 131 | 132 | class MSSSIM(torch.nn.Module): 133 | def 
__init__(self, window_size=11, size_average=True, channel=3): 134 | super(MSSSIM, self).__init__() 135 | self.window_size = window_size 136 | self.size_average = size_average 137 | self.channel = channel 138 | 139 | def forward(self, img1, img2): 140 | # TODO: store window between calls if possible 141 | return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average) -------------------------------------------------------------------------------- /loss/vgg.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | 8 | class VGG(nn.Module): 9 | def __init__(self, conv_index, rgb_range=1): 10 | super(VGG, self).__init__() 11 | vgg_features = models.vgg19(pretrained=True).features 12 | modules = [m for m in vgg_features] 13 | if conv_index.find('22') >= 0: 14 | self.vgg = nn.Sequential(*modules[:8]) 15 | elif conv_index.find('54') >= 0: 16 | self.vgg = nn.Sequential(*modules[:35]) 17 | 18 | vgg_mean = (0.485, 0.456, 0.406) 19 | vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) 20 | self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std) 21 | for p in self.parameters(): 22 | p.requires_grad = False 23 | 24 | def forward(self, sr, hr): 25 | def _forward(x): 26 | # x = self.sub_mean(x) 27 | x = self.vgg(x) 28 | return x 29 | 30 | sr = sr.repeat(1, 3, 1, 1) 31 | hr = hr.repeat(1, 3, 1, 1) 32 | 33 | vgg_sr = _forward(sr) 34 | with torch.no_grad(): 35 | vgg_hr = _forward(hr.detach()) 36 | 37 | loss = F.mse_loss(vgg_sr, vgg_hr) 38 | 39 | return loss 40 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | from torch.utils.data import DataLoader 5 | from torchvision import transforms as T 6 | 7 | import utility 8 | import model 9 | import loss 10 | from option import args 11 | from trainer import Trainer 12 | from datasets.synthetic_burst_train_set import SyntheticBurst 13 | from datasets.zurich_raw2rgb_dataset import ZurichRAW2RGB 14 | import torch.multiprocessing as mp 15 | import torch.backends.cudnn as cudnn 16 | import torch.utils.data.distributed 17 | 18 | try: 19 | import apex 20 | from apex.parallel import DistributedDataParallel as DDP 21 | from apex.fp16_utils import * 22 | from apex import amp, optimizers 23 | from apex.multi_tensor_apply import multi_tensor_applier 24 | except ImportError: 25 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") 26 | 27 | 28 | def init_seeds(seed=0, cuda_deterministic=True): 29 | random.seed(seed) 30 | np.random.seed(seed) 31 | torch.manual_seed(seed) 32 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html 33 | if cuda_deterministic: # slower, more reproducible 34 | cudnn.deterministic = True 35 | cudnn.benchmark = False 36 | else: # faster, less reproducible 37 | cudnn.deterministic = False 38 | cudnn.benchmark = True 39 | 40 | 41 | checkpoint = utility.checkpoint(args) 42 | 43 | 44 | def main(): 45 | if args.n_GPUs > 1: 46 | mp.spawn(main_worker, nprocs=args.n_GPUs, args=(args.n_GPUs, args)) 47 | else: 48 | main_worker(0, args.n_GPUs, args) 49 | 50 | 51 | def main_worker(local_rank, nprocs, args): 52 | if checkpoint.ok: 53 | args.local_rank = local_rank 54 | if 
nprocs > 1: 55 | init_seeds(local_rank+1) 56 | cudnn.benchmark = True 57 | utility.setup(local_rank, nprocs) 58 | torch.cuda.set_device(args.local_rank) 59 | 60 | batch_size = int(args.batch_size / nprocs) 61 | train_zurich_raw2rgb = ZurichRAW2RGB(root=args.root, split='train') 62 | train_data = SyntheticBurst(train_zurich_raw2rgb, burst_size=args.burst_size, crop_sz=args.patch_size) 63 | 64 | valid_zurich_raw2rgb = ZurichRAW2RGB(root=args.root, split='test') 65 | valid_data = SyntheticBurst(valid_zurich_raw2rgb, burst_size=args.burst_size, crop_sz=384) 66 | 67 | if nprocs > 1: 68 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_data) 69 | valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_data, shuffle=False) 70 | train_loader = DataLoader(dataset=train_data, batch_size=batch_size, num_workers=8, 71 | pin_memory=True, drop_last=True, sampler=train_sampler) 72 | valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, num_workers=4, 73 | pin_memory=True, drop_last=True, sampler=valid_sampler) 74 | else: 75 | train_sampler = None 76 | train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, num_workers=8, 77 | shuffle=True, pin_memory=True, drop_last=True) # args.cpus 78 | valid_loader = DataLoader(dataset=valid_data, batch_size=args.batch_size, num_workers=4, shuffle=False, 79 | pin_memory=True, drop_last=True) # args.cpus 80 | 81 | 82 | _model = model.Model(args, checkpoint) 83 | _loss = loss.Loss(args, checkpoint) if not args.test_only else None 84 | t = Trainer(args, train_loader, train_sampler, valid_loader, _model, _loss, checkpoint) 85 | while not t.terminate(): 86 | t.train() 87 | 88 | del _model 89 | del _loss 90 | del train_loader 91 | del valid_loader 92 | 93 | utility.cleanup() 94 | 95 | checkpoint.done() 96 | 97 | 98 | if __name__ == '__main__': 99 | main() 100 | -------------------------------------------------------------------------------- /model/DCNv2/DCNv2.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 1.0 2 | Name: DCNv2 3 | Version: 0.1 4 | Summary: deformable convolutional networks 5 | Home-page: https://github.com/charlesshang/DCNv2 6 | Author: charlesshang 7 | Author-email: UNKNOWN 8 | License: UNKNOWN 9 | Description: UNKNOWN 10 | Platform: UNKNOWN 11 | -------------------------------------------------------------------------------- /model/DCNv2/DCNv2.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | README.md 2 | setup.py 3 | /data/work/pylibs/DCNv2-pytorch_1.6/src/vision.cpp 4 | /data/work/pylibs/DCNv2-pytorch_1.6/src/cpu/dcn_v2_cpu.cpp 5 | /data/work/pylibs/DCNv2-pytorch_1.6/src/cpu/dcn_v2_im2col_cpu.cpp 6 | /data/work/pylibs/DCNv2-pytorch_1.6/src/cpu/dcn_v2_psroi_pooling_cpu.cpp 7 | /data/work/pylibs/DCNv2-pytorch_1.6/src/cuda/dcn_v2_cuda.cu 8 | /data/work/pylibs/DCNv2-pytorch_1.6/src/cuda/dcn_v2_im2col_cuda.cu 9 | /data/work/pylibs/DCNv2-pytorch_1.6/src/cuda/dcn_v2_psroi_pooling_cuda.cu 10 | DCNv2.egg-info/PKG-INFO 11 | DCNv2.egg-info/SOURCES.txt 12 | DCNv2.egg-info/dependency_links.txt 13 | DCNv2.egg-info/top_level.txt -------------------------------------------------------------------------------- /model/DCNv2/DCNv2.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /model/DCNv2/DCNv2.egg-info/top_level.txt: 
-------------------------------------------------------------------------------- 1 | _ext 2 | -------------------------------------------------------------------------------- /model/DCNv2/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, Charles Shang 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /model/DCNv2/README.md: -------------------------------------------------------------------------------- 1 | ## Deformable Convolutional Networks V2 with PyTorch 1.0 2 | 3 | ### Build 4 | ```bash 5 | ./make.sh # build 6 | python test.py # run examples and gradient check 7 | ``` 8 | 9 | ### An Example 10 | - deformable conv 11 | ```python 12 | from dcn_v2 import DCN 13 | input = torch.randn(2, 64, 128, 128).cuda() 14 | # wrap all things (offset and mask) in DCN 15 | dcn = DCN(64, 64, kernel_size=(3,3), stride=1, padding=1, deformable_groups=2).cuda() 16 | output = dcn(input) 17 | print(output.shape) 18 | ``` 19 | - deformable roi pooling 20 | ```python 21 | from dcn_v2 import DCNPooling 22 | input = torch.randn(2, 32, 64, 64).cuda() 23 | batch_inds = torch.randint(2, (20, 1)).cuda().float() 24 | x = torch.randint(256, (20, 1)).cuda().float() 25 | y = torch.randint(256, (20, 1)).cuda().float() 26 | w = torch.randint(64, (20, 1)).cuda().float() 27 | h = torch.randint(64, (20, 1)).cuda().float() 28 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) 29 | 30 | # modulated deformable pooling (V2) 31 | # wrap all things (offset and mask) in DCNPooling 32 | dpooling = DCNPooling(spatial_scale=1.0 / 4, 33 | pooled_size=7, 34 | output_dim=32, 35 | no_trans=False, 36 | group_size=1, 37 | trans_std=0.1).cuda() 38 | 39 | dout = dpooling(input, rois) 40 | ``` 41 | ### Note 42 | The master branch is for PyTorch 1.0 (new ATen API); you can switch back to PyTorch 0.4 with: 43 | ```bash 44 | git checkout pytorch_0.4 45 | ``` 46 | 47 | ### Known Issues: 48 | 49 | - [x] Gradient check w.r.t. offset (solved) 50 | - [ ] Backward is not reentrant (minor) 51 | 52 | This is an adaptation of the official [Deformable-ConvNets](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op). 53 | 54 | I have run the gradient check many times with DOUBLE type. Every tensor **except offset** passes. 55 | However, when I set the offset to 0.5, it passes. I'm still wondering what causes this problem. Is it because of some 56 | non-differentiable points? 57 | 58 | Update: all gradient checks pass with double precision. 59 | 60 | Another issue is that it raises `RuntimeError: Backward is not reentrant`. However, the error is very small (`<1e-7` for 61 | float, `<1e-15` for double), 62 | so it may not be a serious problem (?) 63 | 64 | Please post an issue or PR if you have any comments.
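For reference, the double-precision gradient check discussed above can be reproduced with a minimal sketch along these lines (assuming the extension has been built and `dcn_v2.DCN` is importable; the shapes and tolerances are illustrative, and `test.py` already runs the full suite):

```python
import torch
from torch.autograd import gradcheck
from dcn_v2 import DCN

# gradcheck compares analytic gradients against finite differences,
# so double precision is essential to keep the numerical error small.
dcn = DCN(4, 4, kernel_size=(3, 3), stride=1, padding=1, deformable_groups=1).cuda().double()
x = torch.randn(1, 4, 8, 8, dtype=torch.double, device='cuda', requires_grad=True)
print(gradcheck(dcn, (x,), eps=1e-6, atol=1e-4))  # prints True if all gradients match
```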
65 | -------------------------------------------------------------------------------- /model/DCNv2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/model/DCNv2/__init__.py -------------------------------------------------------------------------------- /model/DCNv2/__pycache__/dcn_v2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/model/DCNv2/__pycache__/dcn_v2.cpython-37.pyc -------------------------------------------------------------------------------- /model/DCNv2/build/lib.linux-x86_64-3.7/_ext.cpython-37m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/model/DCNv2/build/lib.linux-x86_64-3.7/_ext.cpython-37m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cpu/dcn_v2_cpu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cpu/dcn_v2_cpu.o -------------------------------------------------------------------------------- /model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cpu/dcn_v2_im2col_cpu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cpu/dcn_v2_im2col_cpu.o -------------------------------------------------------------------------------- /model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cpu/dcn_v2_psroi_pooling_cpu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cpu/dcn_v2_psroi_pooling_cpu.o -------------------------------------------------------------------------------- /model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cuda/dcn_v2_cuda.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cuda/dcn_v2_cuda.o -------------------------------------------------------------------------------- /model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cuda/dcn_v2_im2col_cuda.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cuda/dcn_v2_im2col_cuda.o -------------------------------------------------------------------------------- /model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cuda/dcn_v2_psroi_pooling_cuda.o: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cuda/dcn_v2_psroi_pooling_cuda.o -------------------------------------------------------------------------------- /model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/vision.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/vision.o -------------------------------------------------------------------------------- /model/DCNv2/dist/DCNv2-0.1-py3.7-linux-x86_64.egg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/model/DCNv2/dist/DCNv2-0.1-py3.7-linux-x86_64.egg -------------------------------------------------------------------------------- /model/DCNv2/files.txt: -------------------------------------------------------------------------------- 1 | /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/_ext.cpython-37m-x86_64-linux-gnu.so 2 | /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/_ext.py 3 | /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/PKG-INFO 4 | /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/SOURCES.txt 5 | /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/dependency_links.txt 6 | /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/native_libs.txt 7 | /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/not-zip-safe 8 | /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/top_level.txt 9 | /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/__pycache__/_ext.cpython-37.pyc 10 | -------------------------------------------------------------------------------- /model/DCNv2/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python setup.py build develop 3 | -------------------------------------------------------------------------------- /model/DCNv2/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import glob 4 | import os 5 | 6 | import torch 7 | from setuptools import find_packages, setup 8 | from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension 9 | 10 | requirements = ["torch", "torchvision"] 11 | 12 | 13 | def get_extensions(): 14 | this_dir = os.path.dirname(os.path.abspath(__file__)) 15 | extensions_dir = os.path.join(this_dir, "src") 16 | 17 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 18 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 19 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 20 | 21 | os.environ["CC"] = "g++" 22 | sources = main_file + source_cpu 23 | extension = CppExtension 24 | extra_compile_args = {"cxx": []} 25 | define_macros = [] 26 | 27 | if True: 28 | extension = 
CUDAExtension 29 | sources += source_cuda 30 | define_macros += [("WITH_CUDA", None)] 31 | extra_compile_args["nvcc"] = [ 32 | "-DCUDA_HAS_FP16=1", 33 | "-D__CUDA_NO_HALF_OPERATORS__", 34 | "-D__CUDA_NO_HALF_CONVERSIONS__", 35 | "-D__CUDA_NO_HALF2_OPERATORS__", 36 | ] 37 | else: 38 | # raise NotImplementedError('Cuda is not available') 39 | pass 40 | 41 | sources = [os.path.join(extensions_dir, s) for s in sources] 42 | include_dirs = [extensions_dir] 43 | ext_modules = [ 44 | extension( 45 | "_ext", 46 | sources, 47 | include_dirs=include_dirs, 48 | define_macros=define_macros, 49 | extra_compile_args=extra_compile_args, 50 | ) 51 | ] 52 | return ext_modules 53 | 54 | 55 | setup( 56 | name="DCNv2", 57 | version="0.1", 58 | author="charlesshang", 59 | url="https://github.com/charlesshang/DCNv2", 60 | description="deformable convolutional networks", 61 | packages=find_packages(exclude=("configs", "tests")), 62 | # install_requires=requirements, 63 | ext_modules=get_extensions(), 64 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 65 | ) 66 | -------------------------------------------------------------------------------- /model/DCNv2/src/cpu/dcn_v2_im2col_cpu.h: -------------------------------------------------------------------------------- 1 | 2 | /*! 3 | ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** 4 | * 5 | * COPYRIGHT 6 | * 7 | * All contributions by the University of California: 8 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents) 9 | * All rights reserved. 10 | * 11 | * All other contributions: 12 | * Copyright (c) 2014-2017, the respective contributors 13 | * All rights reserved. 14 | * 15 | * Caffe uses a shared copyright model: each contributor holds copyright over 16 | * their contributions to Caffe. The project versioning records all such 17 | * contribution and copyright details. If a contributor wants to further mark 18 | * their specific copyright on a particular contribution, they should indicate 19 | * their copyright solely in the commit message of the change when it is 20 | * committed. 21 | * 22 | * LICENSE 23 | * 24 | * Redistribution and use in source and binary forms, with or without 25 | * modification, are permitted provided that the following conditions are met: 26 | * 27 | * 1. Redistributions of source code must retain the above copyright notice, this 28 | * list of conditions and the following disclaimer. 29 | * 2. Redistributions in binary form must reproduce the above copyright notice, 30 | * this list of conditions and the following disclaimer in the documentation 31 | * and/or other materials provided with the distribution. 32 | * 33 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 34 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 35 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 36 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 37 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 38 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 39 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 40 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 41 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 42 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
43 | * 44 | * CONTRIBUTION AGREEMENT 45 | * 46 | * By contributing to the BVLC/caffe repository through pull-request, comment, 47 | * or otherwise, the contributor releases their content to the 48 | * license and copyright terms herein. 49 | * 50 | ***************** END Caffe Copyright Notice and Disclaimer ******************** 51 | * 52 | * Copyright (c) 2018 Microsoft 53 | * Licensed under The MIT License [see LICENSE for details] 54 | * \file modulated_deformable_im2col.h 55 | * \brief Function definitions of converting an image to 56 | * column matrix based on kernel, padding, dilation, and offset. 57 | * These functions are mainly used in deformable convolution operators. 58 | * \ref: https://arxiv.org/abs/1811.11168 59 | * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu 60 | */ 61 | 62 | /***************** Adapted by Charles Shang *********************/ 63 | // modified from the CUDA version for CPU use by Daniel K. Suhendro 64 | 65 | #ifndef DCN_V2_IM2COL_CPU 66 | #define DCN_V2_IM2COL_CPU 67 | 68 | #ifdef __cplusplus 69 | extern "C" 70 | { 71 | #endif 72 | 73 | void modulated_deformable_im2col_cpu(const float *data_im, const float *data_offset, const float *data_mask, 74 | const int batch_size, const int channels, const int height_im, const int width_im, 75 | const int height_col, const int width_col, const int kernel_h, const int kernel_w, 76 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 77 | const int dilation_h, const int dilation_w, 78 | const int deformable_group, float *data_col); 79 | 80 | void modulated_deformable_col2im_cpu(const float *data_col, const float *data_offset, const float *data_mask, 81 | const int batch_size, const int channels, const int height_im, const int width_im, 82 | const int height_col, const int width_col, const int kernel_h, const int kernel_w, 83 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 84 | const int dilation_h, const int dilation_w, 85 | const int deformable_group, float *grad_im); 86 | 87 | void modulated_deformable_col2im_coord_cpu(const float *data_col, const float *data_im, const float *data_offset, const float *data_mask, 88 | const int batch_size, const int channels, const int height_im, const int width_im, 89 | const int height_col, const int width_col, const int kernel_h, const int kernel_w, 90 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 91 | const int dilation_h, const int dilation_w, 92 | const int deformable_group, 93 | float *grad_offset, float *grad_mask); 94 | 95 | #ifdef __cplusplus 96 | } 97 | #endif 98 | 99 | #endif -------------------------------------------------------------------------------- /model/DCNv2/src/cpu/vision.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <torch/extension.h> 3 | 4 | at::Tensor 5 | dcn_v2_cpu_forward(const at::Tensor &input, 6 | const at::Tensor &weight, 7 | const at::Tensor &bias, 8 | const at::Tensor &offset, 9 | const at::Tensor &mask, 10 | const int kernel_h, 11 | const int kernel_w, 12 | const int stride_h, 13 | const int stride_w, 14 | const int pad_h, 15 | const int pad_w, 16 | const int dilation_h, 17 | const int dilation_w, 18 | const int deformable_group); 19 | 20 | std::vector<at::Tensor> 21 | dcn_v2_cpu_backward(const at::Tensor &input, 22 | const at::Tensor &weight, 23 | const at::Tensor &bias, 24 | const at::Tensor &offset, 25 | const at::Tensor &mask, 26 | const at::Tensor &grad_output, 27 | int kernel_h, int kernel_w, 28 | int stride_h, int 
stride_w, 29 | int pad_h, int pad_w, 30 | int dilation_h, int dilation_w, 31 | int deformable_group); 32 | 33 | 34 | std::tuple<at::Tensor, at::Tensor> 35 | dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input, 36 | const at::Tensor &bbox, 37 | const at::Tensor &trans, 38 | const int no_trans, 39 | const float spatial_scale, 40 | const int output_dim, 41 | const int group_size, 42 | const int pooled_size, 43 | const int part_size, 44 | const int sample_per_part, 45 | const float trans_std); 46 | 47 | std::tuple<at::Tensor, at::Tensor> 48 | dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad, 49 | const at::Tensor &input, 50 | const at::Tensor &bbox, 51 | const at::Tensor &trans, 52 | const at::Tensor &top_count, 53 | const int no_trans, 54 | const float spatial_scale, 55 | const int output_dim, 56 | const int group_size, 57 | const int pooled_size, 58 | const int part_size, 59 | const int sample_per_part, 60 | const float trans_std); -------------------------------------------------------------------------------- /model/DCNv2/src/cuda/dcn_v2_im2col_cuda.h: -------------------------------------------------------------------------------- 1 | 2 | /*! 3 | ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** 4 | * 5 | * COPYRIGHT 6 | * 7 | * All contributions by the University of California: 8 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents) 9 | * All rights reserved. 10 | * 11 | * All other contributions: 12 | * Copyright (c) 2014-2017, the respective contributors 13 | * All rights reserved. 14 | * 15 | * Caffe uses a shared copyright model: each contributor holds copyright over 16 | * their contributions to Caffe. The project versioning records all such 17 | * contribution and copyright details. If a contributor wants to further mark 18 | * their specific copyright on a particular contribution, they should indicate 19 | * their copyright solely in the commit message of the change when it is 20 | * committed. 21 | * 22 | * LICENSE 23 | * 24 | * Redistribution and use in source and binary forms, with or without 25 | * modification, are permitted provided that the following conditions are met: 26 | * 27 | * 1. Redistributions of source code must retain the above copyright notice, this 28 | * list of conditions and the following disclaimer. 29 | * 2. Redistributions in binary form must reproduce the above copyright notice, 30 | * this list of conditions and the following disclaimer in the documentation 31 | * and/or other materials provided with the distribution. 32 | * 33 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 34 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 35 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 36 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 37 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 38 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 39 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 40 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 41 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 42 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
43 | * 44 | * CONTRIBUTION AGREEMENT 45 | * 46 | * By contributing to the BVLC/caffe repository through pull-request, comment, 47 | * or otherwise, the contributor releases their content to the 48 | * license and copyright terms herein. 49 | * 50 | ***************** END Caffe Copyright Notice and Disclaimer ******************** 51 | * 52 | * Copyright (c) 2018 Microsoft 53 | * Licensed under The MIT License [see LICENSE for details] 54 | * \file modulated_deformable_im2col.h 55 | * \brief Function definitions of converting an image to 56 | * column matrix based on kernel, padding, dilation, and offset. 57 | * These functions are mainly used in deformable convolution operators. 58 | * \ref: https://arxiv.org/abs/1811.11168 59 | * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu 60 | */ 61 | 62 | /***************** Adapted by Charles Shang *********************/ 63 | 64 | #ifndef DCN_V2_IM2COL_CUDA 65 | #define DCN_V2_IM2COL_CUDA 66 | 67 | #ifdef __cplusplus 68 | extern "C" 69 | { 70 | #endif 71 | 72 | void modulated_deformable_im2col_cuda(cudaStream_t stream, 73 | const float *data_im, const float *data_offset, const float *data_mask, 74 | const int batch_size, const int channels, const int height_im, const int width_im, 75 | const int height_col, const int width_col, const int kernel_h, const int kernel_w, 76 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 77 | const int dilation_h, const int dilation_w, 78 | const int deformable_group, float *data_col); 79 | 80 | void modulated_deformable_col2im_cuda(cudaStream_t stream, 81 | const float *data_col, const float *data_offset, const float *data_mask, 82 | const int batch_size, const int channels, const int height_im, const int width_im, 83 | const int height_col, const int width_col, const int kernel_h, const int kernel_w, 84 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 85 | const int dilation_h, const int dilation_w, 86 | const int deformable_group, float *grad_im); 87 | 88 | void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, 89 | const float *data_col, const float *data_im, const float *data_offset, const float *data_mask, 90 | const int batch_size, const int channels, const int height_im, const int width_im, 91 | const int height_col, const int width_col, const int kernel_h, const int kernel_w, 92 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 93 | const int dilation_h, const int dilation_w, 94 | const int deformable_group, 95 | float *grad_offset, float *grad_mask); 96 | 97 | #ifdef __cplusplus 98 | } 99 | #endif 100 | 101 | #endif -------------------------------------------------------------------------------- /model/DCNv2/src/cuda/vision.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <torch/extension.h> 3 | #include <ATen/ATen.h> 4 | at::Tensor 5 | dcn_v2_cuda_forward(const at::Tensor &input, 6 | const at::Tensor &weight, 7 | const at::Tensor &bias, 8 | const at::Tensor &offset, 9 | const at::Tensor &mask, 10 | const int kernel_h, 11 | const int kernel_w, 12 | const int stride_h, 13 | const int stride_w, 14 | const int pad_h, 15 | const int pad_w, 16 | const int dilation_h, 17 | const int dilation_w, 18 | const int deformable_group); 19 | 20 | std::vector<at::Tensor> 21 | dcn_v2_cuda_backward(const at::Tensor &input, 22 | const at::Tensor &weight, 23 | const at::Tensor &bias, 24 | const at::Tensor &offset, 25 | const at::Tensor &mask, 26 | const at::Tensor &grad_output, 27 | int kernel_h, int 
kernel_w, 28 | int stride_h, int stride_w, 29 | int pad_h, int pad_w, 30 | int dilation_h, int dilation_w, 31 | int deformable_group); 32 | 33 | 34 | std::tuple<at::Tensor, at::Tensor> 35 | dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input, 36 | const at::Tensor &bbox, 37 | const at::Tensor &trans, 38 | const int no_trans, 39 | const float spatial_scale, 40 | const int output_dim, 41 | const int group_size, 42 | const int pooled_size, 43 | const int part_size, 44 | const int sample_per_part, 45 | const float trans_std); 46 | 47 | std::tuple<at::Tensor, at::Tensor> 48 | dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad, 49 | const at::Tensor &input, 50 | const at::Tensor &bbox, 51 | const at::Tensor &trans, 52 | const at::Tensor &top_count, 53 | const int no_trans, 54 | const float spatial_scale, 55 | const int output_dim, 56 | const int group_size, 57 | const int pooled_size, 58 | const int part_size, 59 | const int sample_per_part, 60 | const float trans_std); -------------------------------------------------------------------------------- /model/DCNv2/src/dcn_v2.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cpu/vision.h" 4 | 5 | #ifdef WITH_CUDA 6 | #include "cuda/vision.h" 7 | #endif 8 | 9 | at::Tensor 10 | dcn_v2_forward(const at::Tensor &input, 11 | const at::Tensor &weight, 12 | const at::Tensor &bias, 13 | const at::Tensor &offset, 14 | const at::Tensor &mask, 15 | const int kernel_h, 16 | const int kernel_w, 17 | const int stride_h, 18 | const int stride_w, 19 | const int pad_h, 20 | const int pad_w, 21 | const int dilation_h, 22 | const int dilation_w, 23 | const int deformable_group) 24 | { 25 | if (input.type().is_cuda()) 26 | { 27 | #ifdef WITH_CUDA 28 | return dcn_v2_cuda_forward(input, weight, bias, offset, mask, 29 | kernel_h, kernel_w, 30 | stride_h, stride_w, 31 | pad_h, pad_w, 32 | dilation_h, dilation_w, 33 | deformable_group); 34 | #else 35 | AT_ERROR("Not compiled with GPU support"); 36 | #endif 37 | } 38 | else{ 39 | return dcn_v2_cpu_forward(input, weight, bias, offset, mask, 40 | kernel_h, kernel_w, 41 | stride_h, stride_w, 42 | pad_h, pad_w, 43 | dilation_h, dilation_w, 44 | deformable_group); 45 | } 46 | } 47 | 48 | std::vector<at::Tensor> 49 | dcn_v2_backward(const at::Tensor &input, 50 | const at::Tensor &weight, 51 | const at::Tensor &bias, 52 | const at::Tensor &offset, 53 | const at::Tensor &mask, 54 | const at::Tensor &grad_output, 55 | int kernel_h, int kernel_w, 56 | int stride_h, int stride_w, 57 | int pad_h, int pad_w, 58 | int dilation_h, int dilation_w, 59 | int deformable_group) 60 | { 61 | if (input.type().is_cuda()) 62 | { 63 | #ifdef WITH_CUDA 64 | return dcn_v2_cuda_backward(input, 65 | weight, 66 | bias, 67 | offset, 68 | mask, 69 | grad_output, 70 | kernel_h, kernel_w, 71 | stride_h, stride_w, 72 | pad_h, pad_w, 73 | dilation_h, dilation_w, 74 | deformable_group); 75 | #else 76 | AT_ERROR("Not compiled with GPU support"); 77 | #endif 78 | } 79 | else{ 80 | return dcn_v2_cpu_backward(input, 81 | weight, 82 | bias, 83 | offset, 84 | mask, 85 | grad_output, 86 | kernel_h, kernel_w, 87 | stride_h, stride_w, 88 | pad_h, pad_w, 89 | dilation_h, dilation_w, 90 | deformable_group); 91 | } 92 | } 93 | 94 | std::tuple<at::Tensor, at::Tensor> 95 | dcn_v2_psroi_pooling_forward(const at::Tensor &input, 96 | const at::Tensor &bbox, 97 | const at::Tensor &trans, 98 | const int no_trans, 99 | const float spatial_scale, 100 | const int output_dim, 101 | const int group_size, 102 | const int pooled_size, 103 | const int part_size, 104 | const int 
sample_per_part, 105 | const float trans_std) 106 | { 107 | if (input.type().is_cuda()) 108 | { 109 | #ifdef WITH_CUDA 110 | return dcn_v2_psroi_pooling_cuda_forward(input, 111 | bbox, 112 | trans, 113 | no_trans, 114 | spatial_scale, 115 | output_dim, 116 | group_size, 117 | pooled_size, 118 | part_size, 119 | sample_per_part, 120 | trans_std); 121 | #else 122 | AT_ERROR("Not compiled with GPU support"); 123 | #endif 124 | } 125 | else{ 126 | return dcn_v2_psroi_pooling_cpu_forward(input, 127 | bbox, 128 | trans, 129 | no_trans, 130 | spatial_scale, 131 | output_dim, 132 | group_size, 133 | pooled_size, 134 | part_size, 135 | sample_per_part, 136 | trans_std); 137 | } 138 | } 139 | 140 | std::tuple<at::Tensor, at::Tensor> 141 | dcn_v2_psroi_pooling_backward(const at::Tensor &out_grad, 142 | const at::Tensor &input, 143 | const at::Tensor &bbox, 144 | const at::Tensor &trans, 145 | const at::Tensor &top_count, 146 | const int no_trans, 147 | const float spatial_scale, 148 | const int output_dim, 149 | const int group_size, 150 | const int pooled_size, 151 | const int part_size, 152 | const int sample_per_part, 153 | const float trans_std) 154 | { 155 | if (input.type().is_cuda()) 156 | { 157 | #ifdef WITH_CUDA 158 | return dcn_v2_psroi_pooling_cuda_backward(out_grad, 159 | input, 160 | bbox, 161 | trans, 162 | top_count, 163 | no_trans, 164 | spatial_scale, 165 | output_dim, 166 | group_size, 167 | pooled_size, 168 | part_size, 169 | sample_per_part, 170 | trans_std); 171 | #else 172 | AT_ERROR("Not compiled with GPU support"); 173 | #endif 174 | } 175 | else{ 176 | return dcn_v2_psroi_pooling_cpu_backward(out_grad, 177 | input, 178 | bbox, 179 | trans, 180 | top_count, 181 | no_trans, 182 | spatial_scale, 183 | output_dim, 184 | group_size, 185 | pooled_size, 186 | part_size, 187 | sample_per_part, 188 | trans_std); 189 | } 190 | } -------------------------------------------------------------------------------- /model/DCNv2/src/vision.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "dcn_v2.h" 3 | 4 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 5 | m.def("dcn_v2_forward", &dcn_v2_forward, "dcn_v2_forward"); 6 | m.def("dcn_v2_backward", &dcn_v2_backward, "dcn_v2_backward"); 7 | m.def("dcn_v2_psroi_pooling_forward", &dcn_v2_psroi_pooling_forward, "dcn_v2_psroi_pooling_forward"); 8 | m.def("dcn_v2_psroi_pooling_backward", &dcn_v2_psroi_pooling_backward, "dcn_v2_psroi_pooling_backward"); 9 | } 10 | -------------------------------------------------------------------------------- /model/DCNv2/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import time 7 | import torch 8 | import torch.nn as nn 9 | from torch.autograd import gradcheck 10 | 11 | from dcn_v2 import dcn_v2_conv, DCNv2, DCN 12 | from dcn_v2 import dcn_v2_pooling, DCNv2Pooling, DCNPooling 13 | 14 | deformable_groups = 1 15 | N, inC, inH, inW = 2, 2, 4, 4 16 | outC = 2 17 | kH, kW = 3, 3 18 | 19 | 20 | def conv_identify(weight, bias): 21 | weight.data.zero_() 22 | bias.data.zero_() 23 | o, i, h, w = weight.shape 24 | y = h//2 25 | x = w//2 26 | for p in range(i): 27 | for q in range(o): 28 | if p == q: 29 | weight.data[q, p, y, x] = 1.0 30 | 31 | 32 | def check_zero_offset(): 33 | conv_offset = nn.Conv2d(inC, deformable_groups * 2 * kH * kW, 34 | kernel_size=(kH, kW), 35 | stride=(1, 1), 36 | padding=(1, 1), 37 | 
bias=True).cuda() 38 | 39 | conv_mask = nn.Conv2d(inC, deformable_groups * 1 * kH * kW, 40 | kernel_size=(kH, kW), 41 | stride=(1, 1), 42 | padding=(1, 1), 43 | bias=True).cuda() 44 | 45 | dcn_v2 = DCNv2(inC, outC, (kH, kW), 46 | stride=1, padding=1, dilation=1, 47 | deformable_groups=deformable_groups).cuda() 48 | 49 | conv_offset.weight.data.zero_() 50 | conv_offset.bias.data.zero_() 51 | conv_mask.weight.data.zero_() 52 | conv_mask.bias.data.zero_() 53 | conv_identify(dcn_v2.weight, dcn_v2.bias) 54 | 55 | input = torch.randn(N, inC, inH, inW).cuda() 56 | offset = conv_offset(input) 57 | mask = conv_mask(input) 58 | mask = torch.sigmoid(mask) 59 | output = dcn_v2(input, offset, mask) 60 | output *= 2  # zeroed mask logits give sigmoid(mask) = 0.5, so undo the 0.5 scaling 61 | d = (input - output).abs().max() 62 | if d < 1e-10: 63 | print('Zero offset passed') 64 | else: 65 | print('Zero offset failed') 66 | print(input) 67 | print(output) 68 | 69 | def check_gradient_dconv(): 70 | 71 | input = torch.rand(N, inC, inH, inW).cuda() * 0.01 72 | input.requires_grad = True 73 | 74 | offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW).cuda() * 2 75 | # offset.data.zero_() 76 | # offset.data -= 0.5 77 | offset.requires_grad = True 78 | 79 | mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW).cuda() 80 | # mask.data.zero_() 81 | mask.requires_grad = True 82 | mask = torch.sigmoid(mask) 83 | 84 | weight = torch.randn(outC, inC, kH, kW).cuda() 85 | weight.requires_grad = True 86 | 87 | bias = torch.rand(outC).cuda() 88 | bias.requires_grad = True 89 | 90 | stride = 1 91 | padding = 1 92 | dilation = 1 93 | 94 | print('check_gradient_dconv: ', 95 | gradcheck(dcn_v2_conv, (input, offset, mask, weight, bias, 96 | stride, padding, dilation, deformable_groups), 97 | eps=1e-3, atol=1e-4, rtol=1e-2)) 98 | 99 | 100 | def check_pooling_zero_offset(): 101 | 102 | input = torch.randn(2, 16, 64, 64).cuda().zero_() 103 | input[0, :, 16:26, 16:26] = 1. 104 | input[1, :, 10:20, 20:30] = 2. 
105 | rois = torch.tensor([ 106 | [0, 65, 65, 103, 103], 107 | [1, 81, 41, 119, 79], 108 | ]).cuda().float() 109 | pooling = DCNv2Pooling(spatial_scale=1.0 / 4, 110 | pooled_size=7, 111 | output_dim=16, 112 | no_trans=True, 113 | group_size=1, 114 | trans_std=0.0).cuda() 115 | 116 | out = pooling(input, rois, input.new()) 117 | s = ', '.join(['%f' % out[i, :, :, :].mean().item() 118 | for i in range(rois.shape[0])]) 119 | print(s) 120 | 121 | dpooling = DCNv2Pooling(spatial_scale=1.0 / 4, 122 | pooled_size=7, 123 | output_dim=16, 124 | no_trans=False, 125 | group_size=1, 126 | trans_std=0.0).cuda() 127 | offset = torch.randn(20, 2, 7, 7).cuda().zero_() 128 | dout = dpooling(input, rois, offset) 129 | s = ', '.join(['%f' % dout[i, :, :, :].mean().item() 130 | for i in range(rois.shape[0])]) 131 | print(s) 132 | 133 | 134 | def check_gradient_dpooling(): 135 | input = torch.randn(2, 3, 5, 5).cuda() * 0.01 136 | N = 4 137 | batch_inds = torch.randint(2, (N, 1)).cuda().float() 138 | x = torch.rand((N, 1)).cuda().float() * 15 139 | y = torch.rand((N, 1)).cuda().float() * 15 140 | w = torch.rand((N, 1)).cuda().float() * 10 141 | h = torch.rand((N, 1)).cuda().float() * 10 142 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) 143 | offset = torch.randn(N, 2, 3, 3).cuda() 144 | input.requires_grad = True 145 | offset.requires_grad = True 146 | 147 | spatial_scale = 1.0 / 4 148 | pooled_size = 3 149 | output_dim = 3 150 | no_trans = 0 151 | group_size = 1 152 | trans_std = 0.0 153 | sample_per_part = 4 154 | part_size = pooled_size 155 | 156 | print('check_gradient_dpooling:', 157 | gradcheck(dcn_v2_pooling, (input, rois, offset, 158 | spatial_scale, 159 | pooled_size, 160 | output_dim, 161 | no_trans, 162 | group_size, 163 | part_size, 164 | sample_per_part, 165 | trans_std), 166 | eps=1e-4)) 167 | 168 | 169 | def example_dconv(): 170 | input = torch.randn(2, 64, 128, 128).cuda() 171 | # wrap all things (offset and mask) in DCN 172 | dcn = DCN(64, 64, kernel_size=(3, 3), stride=1, 173 | padding=1, deformable_groups=2).cuda() 174 | # print(dcn.weight.shape, input.shape) 175 | output = dcn(input) 176 | target = output.new(*output.size()) 177 | target.data.uniform_(-0.01, 0.01) 178 | error = (target - output).mean() 179 | error.backward() 180 | print(output.shape) 181 | 182 | 183 | def example_dpooling(): 184 | input = torch.randn(2, 32, 64, 64).cuda() 185 | batch_inds = torch.randint(2, (20, 1)).cuda().float() 186 | x = torch.randint(256, (20, 1)).cuda().float() 187 | y = torch.randint(256, (20, 1)).cuda().float() 188 | w = torch.randint(64, (20, 1)).cuda().float() 189 | h = torch.randint(64, (20, 1)).cuda().float() 190 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) 191 | offset = torch.randn(20, 2, 7, 7).cuda() 192 | input.requires_grad = True 193 | offset.requires_grad = True 194 | 195 | # normal roi_align 196 | pooling = DCNv2Pooling(spatial_scale=1.0 / 4, 197 | pooled_size=7, 198 | output_dim=32, 199 | no_trans=True, 200 | group_size=1, 201 | trans_std=0.1).cuda() 202 | 203 | # deformable pooling 204 | dpooling = DCNv2Pooling(spatial_scale=1.0 / 4, 205 | pooled_size=7, 206 | output_dim=32, 207 | no_trans=False, 208 | group_size=1, 209 | trans_std=0.1).cuda() 210 | 211 | out = pooling(input, rois, offset) 212 | dout = dpooling(input, rois, offset) 213 | print(out.shape) 214 | print(dout.shape) 215 | 216 | target_out = out.new(*out.size()) 217 | target_out.data.uniform_(-0.01, 0.01) 218 | target_dout = dout.new(*dout.size()) 219 | target_dout.data.uniform_(-0.01, 0.01) 220 | e = (target_out - out).mean() 221 | e.backward() 222 | e = (target_dout - dout).mean() 223 | e.backward() 224 | 225 | 226 | def example_mdpooling(): 227 | input = torch.randn(2, 32, 64, 64).cuda() 228 | input.requires_grad = True 229 | batch_inds = torch.randint(2, (20, 1)).cuda().float() 230 | x = torch.randint(256, (20, 1)).cuda().float() 231 | y = torch.randint(256, (20, 1)).cuda().float() 232 | w = torch.randint(64, (20, 1)).cuda().float() 233 | h = torch.randint(64, (20, 1)).cuda().float() 234 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) 235 | 236 | # modulated deformable pooling (V2) 237 | dpooling = DCNPooling(spatial_scale=1.0 / 4, 238 | pooled_size=7, 239 | output_dim=32, 240 | no_trans=False, 241 | group_size=1, 242 | trans_std=0.1, 243 | deform_fc_dim=1024).cuda() 244 | 245 | dout = dpooling(input, rois) 246 | target = dout.new(*dout.size()) 247 | target.data.uniform_(-0.1, 0.1) 248 | error = (target - dout).mean() 249 | error.backward() 250 | print(dout.shape) 251 | 252 | 253 | if __name__ == '__main__': 254 | 255 | example_dconv() 256 | example_dpooling() 257 | example_mdpooling() 258 | 259 | check_pooling_zero_offset() 260 | # zero offset check 261 | if inC == outC: 262 | check_zero_offset() 263 | 264 | check_gradient_dpooling() 265 | check_gradient_dconv() 266 | # """ 267 | # ****** Note: backward is not reentrant error may not be a serious problem, 268 | # ****** since the max error is less than 1e-7, 269 | # ****** Still looking for what triggers this problem 270 | # """ 271 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel as P 7 | import torch.utils.model_zoo 8 | 9 | class Model(nn.Module): 10 | def __init__(self, args, ckp): 11 | super(Model, self).__init__() 12 | self.args = args 13 | if args.local_rank == 0: 14 | print('Making model...') 15 | 16 | self.scale = args.scale 17 | self.idx_scale = 0 18 | self.input_large = (args.model == 'VDSR') 19 | self.self_ensemble = args.self_ensemble 20 | self.chop = args.chop 21 | self.precision = args.precision 22 | self.cpu = args.cpu 23 | self.device = torch.device('cpu' if args.cpu else 'cuda:%d' % args.local_rank) 24 | self.n_GPUs = args.n_GPUs 25 | self.save_models = args.save_models 26 | 27 | module = import_module('model.' + args.model.lower()) 28 | self.model = module.make_model(args).to(self.device) 29 | 30 | if args.precision == 'half': 31 | self.model.half() 32 | 33 | self.load( 34 | ckp.get_path('model'), 35 | pre_train=args.pre_train, 36 | resume=args.resume, 37 | cpu=args.cpu 38 | ) 39 | 40 | if args.n_GPUs > 1: 41 | self.model = nn.parallel.DistributedDataParallel(self.model, 42 | device_ids=[args.local_rank], 43 | find_unused_parameters=True 44 | ) 45 | 46 | print(self.model, file=ckp.log_file) 47 | 48 | def forward(self, x, idx_scale): 49 | self.idx_scale = idx_scale 50 | if hasattr(self.model, 'set_scale'): 51 | self.model.set_scale(idx_scale) 52 | 53 | if self.training: 54 | # if self.n_GPUs > 1: 55 | return self.model(x) 56 | else: 57 | if self.chop: 58 | forward_function = self.forward_chop 59 | else: 60 | forward_function = self.model.forward 61 | 62 | if self.self_ensemble: 63 | return self.forward_x8(x, forward_function=forward_function) 64 | else: 65 | # return self.model(x) 66 | return forward_function(x) 67 | 68 | def save(self, apath, epoch, is_best=False): 69 | save_dirs = [os.path.join(apath, 'model_latest.pt')] 70 | 71 | if is_best: 72 | save_dirs.append(os.path.join(apath, 'model_best.pt')) 73 | if self.save_models: 74 | save_dirs.append( 75 | os.path.join(apath, 'model_{}.pt'.format(epoch)) 76 | ) 77 | if self.n_GPUs > 1: 78 | model = self.model.module 79 | else: 80 | model = self.model 81 | 82 | for s in save_dirs: 83 | torch.save(model.state_dict(), s)  # save the unwrapped module so checkpoints load without the DDP 'module.' prefix 84 | 85 | def load(self, apath, pre_train='', resume=-1, cpu=False): 86 | load_from = None 87 | kwargs = {} 88 | if cpu: 89 | kwargs = {'map_location': lambda storage, loc: storage} 90 | 91 | if resume == -1: 92 | load_from = torch.load( 93 | os.path.join(apath, 'model_latest.pt'), 94 | **kwargs 95 | ) 96 | elif resume == 0: 97 | if pre_train == 'download': 98 | print('Download the model') 99 | dir_model = os.path.join('..', 'models') 100 | os.makedirs(dir_model, exist_ok=True) 101 | load_from = torch.utils.model_zoo.load_url( 102 | self.model.url, 103 | model_dir=dir_model, 104 | **kwargs 105 | ) 106 | elif pre_train: 107 | print('Load the model from {}'.format(pre_train)) 108 | map_location = {'cuda:%d' % 0: 'cuda:%d' % self.args.local_rank} 109 | load_from = torch.load(pre_train, map_location=map_location) 110 | else: 111 | load_from = torch.load( 112 | os.path.join(apath, 'model_{}.pt'.format(resume)), 113 | **kwargs 114 | ) 115 | 116 | if load_from: 117 | self.model.load_state_dict(load_from) 118 | del load_from 119 | 120 | def forward_chop(self, *args, shave=10, min_size=160000): 121 | scale = 1 if self.input_large else self.scale[self.idx_scale] 122 | n_GPUs = min(self.n_GPUs, 4) 123 | # height, width 124 | h, w = args[0].size()[-2:] 125 | 126 | top = slice(0, h//2 + shave) 127 | bottom = slice(h - h//2 - shave, h) 128 | left = slice(0, w//2 + shave) 129 | right = slice(w - w//2 - shave, w) 130 | x_chops = [torch.cat([ 131 | a[..., top, left], 132 | a[..., top, right], 133 | a[..., bottom, left], 134 | a[..., bottom, right] 135 | ]) for a in args] 136 | 137 | y_chops = [] 138 | if h * w < 4 * min_size: 139 | for i in range(0, 4, n_GPUs): 140 | x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops] 141 | y = P.data_parallel(self.model, *x, range(n_GPUs)) 142 | if not isinstance(y, list): y = [y] 143 | if not y_chops: 144 | y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y] 145 | else: 146 | for y_chop, _y in zip(y_chops, y): 147 | y_chop.extend(_y.chunk(n_GPUs, dim=0)) 148 | else: 149 | for p in zip(*x_chops): 150 | 
y = self.forward_chop(*p, shave=shave, min_size=min_size) 151 | if not isinstance(y, list): y = [y] 152 | if not y_chops: 153 | y_chops = [[_y] for _y in y] 154 | else: 155 | for y_chop, _y in zip(y_chops, y): y_chop.append(_y) 156 | 157 | h *= scale 158 | w *= scale 159 | top = slice(0, h//2) 160 | bottom = slice(h - h//2, h) 161 | bottom_r = slice(h//2 - h, None) 162 | left = slice(0, w//2) 163 | right = slice(w - w//2, w) 164 | right_r = slice(w//2 - w, None) 165 | 166 | # batch size, number of color channels 167 | b, c = y_chops[0][0].size()[:-2] 168 | y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops] 169 | for y_chop, _y in zip(y_chops, y): 170 | _y[..., top, left] = y_chop[0][..., top, left] 171 | _y[..., top, right] = y_chop[1][..., top, right_r] 172 | _y[..., bottom, left] = y_chop[2][..., bottom_r, left] 173 | _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r] 174 | 175 | if len(y) == 1: y = y[0] 176 | 177 | return y 178 | 179 | def forward_x8(self, *args, forward_function=None): 180 | def _transform(v, op): 181 | if self.precision != 'single': v = v.float() 182 | 183 | v2np = v.data.cpu().numpy() 184 | if op == 'v': 185 | tfnp = v2np[:, :, :, ::-1].copy() 186 | elif op == 'h': 187 | tfnp = v2np[:, :, ::-1, :].copy() 188 | elif op == 't': 189 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 190 | 191 | ret = torch.Tensor(tfnp).to(self.device) 192 | if self.precision == 'half': ret = ret.half() 193 | 194 | return ret 195 | 196 | list_x = [] 197 | for a in args: 198 | x = [a] 199 | for tf in 'v', 'h', 't': x.extend([_transform(_x, tf) for _x in x]) 200 | 201 | list_x.append(x) 202 | 203 | list_y = [] 204 | for x in zip(*list_x): 205 | y = forward_function(*x) 206 | if not isinstance(y, list): y = [y] 207 | if not list_y: 208 | list_y = [[_y] for _y in y] 209 | else: 210 | for _list_y, _y in zip(list_y, y): _list_y.append(_y) 211 | 212 | for _list_y in list_y: 213 | for i in range(len(_list_y)): 214 | if i > 3: 215 | _list_y[i] = _transform(_list_y[i], 't') 216 | if i % 4 > 1: 217 | _list_y[i] = _transform(_list_y[i], 'h') 218 | if (i % 4) % 2 == 1: 219 | _list_y[i] = _transform(_list_y[i], 'v') 220 | 221 | y = [torch.cat(_y, dim=0).mean(dim=0, keepdim=True) for _y in list_y] 222 | if len(y) == 1: y = y[0] 223 | 224 | return y 225 | -------------------------------------------------------------------------------- /model/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 10 | return nn.Conv2d( 11 | in_channels, out_channels, kernel_size, 12 | padding=(kernel_size // 2), bias=bias) 13 | 14 | 15 | class MeanShift(nn.Conv2d): 16 | def __init__( 17 | self, rgb_range, 18 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 19 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 20 | std = torch.Tensor(rgb_std) 21 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 22 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 23 | for p in self.parameters(): 24 | p.requires_grad = False 25 | 26 | 27 | class BasicBlock(nn.Sequential): 28 | def __init__( 29 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, 30 | bn=True, act=nn.ReLU(True)): 31 | 32 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 33 | if bn: 34 | m.append(nn.BatchNorm2d(out_channels)) 35 | 
if act is not None: 36 | m.append(act) 37 | 38 | super(BasicBlock, self).__init__(*m) 39 | 40 | 41 | class ResBlock(nn.Module): 42 | def __init__( 43 | self, conv, n_feats, kernel_size, 44 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 45 | 46 | super(ResBlock, self).__init__() 47 | m = [] 48 | for i in range(2): 49 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 50 | if bn: 51 | m.append(nn.BatchNorm2d(n_feats)) 52 | if i == 0: 53 | m.append(act) 54 | 55 | self.body = nn.Sequential(*m) 56 | self.res_scale = res_scale 57 | 58 | def forward(self, x): 59 | res = self.body(x).mul(self.res_scale) 60 | res += x 61 | 62 | return res 63 | 64 | 65 | class Upsampler(nn.Sequential): 66 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 67 | 68 | m = [] 69 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 70 | for _ in range(int(math.log(scale, 2))): 71 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 72 | m.append(nn.PixelShuffle(2)) 73 | if bn: 74 | m.append(nn.BatchNorm2d(n_feats)) 75 | if act == 'relu': 76 | m.append(nn.ReLU(True)) 77 | elif act == 'prelu': 78 | m.append(nn.PReLU(n_feats)) 79 | 80 | elif scale == 3: 81 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 82 | m.append(nn.PixelShuffle(3)) 83 | if bn: 84 | m.append(nn.BatchNorm2d(n_feats)) 85 | if act == 'relu': 86 | m.append(nn.ReLU(True)) 87 | elif act == 'prelu': 88 | m.append(nn.PReLU(n_feats)) 89 | else: 90 | raise NotImplementedError 91 | 92 | super(Upsampler, self).__init__(*m) 93 | 94 | 95 | class UpOnly(nn.Sequential): 96 | def __init__(self, scale): 97 | 98 | m = [] 99 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 100 | for _ in range(int(math.log(scale, 2))): 101 | m.append(nn.PixelShuffle(2)) 102 | 103 | 104 | elif scale == 3: 105 | 106 | m.append(nn.PixelShuffle(3)) 107 | 108 | else: 109 | raise NotImplementedError 110 | 111 | super(UpOnly, self).__init__(*m) 112 | 113 | 114 | def lanczos_kernel(dx, a=3, N=None, dtype=None, device=None): 115 | ''' 116 | Generates 1D Lanczos kernels for translation and interpolation. 117 | Args: 118 | dx : float, tensor (batch_size, 1), the translation in pixels to shift an image. 119 | a : int, number of lobes in the kernel support. 120 | If N is None, then the width is the kernel support (length of all lobes), 121 | S = 2(a + ceil(dx)) + 1. 122 | N : int, width of the kernel. 123 | If smaller than S then N is set to S. 124 | Returns: 125 | k: tensor (?, ?), lanczos kernel 126 | ''' 127 | 128 | if not torch.is_tensor(dx): 129 | dx = torch.tensor(dx, dtype=dtype, device=device) 130 | 131 | if device is None: 132 | device = dx.device 133 | 134 | if dtype is None: 135 | dtype = dx.dtype 136 | 137 | D = dx.abs().ceil().int() 138 | S = 2 * (a + D) + 1 # width of kernel support 139 | 140 | S_max = S.max() if hasattr(S, 'shape') else S 141 | 142 | if (N is None) or (N < S_max): 143 | N = S 144 | 145 | Z = (N - S) // 2 # width of zeros beyond kernel support 146 | 147 | start = (-(a + D + Z)).min() 148 | end = (a + D + Z + 1).max() 149 | x = torch.arange(start, end, dtype=dtype, device=device).view(1, -1) - dx 150 | px = (np.pi * x) + 1e-3 151 | 152 | sin_px = torch.sin(px) 153 | sin_pxa = torch.sin(px / a) 154 | 155 | k = a * sin_px * sin_pxa / px ** 2 # sinc(x) masked by sinc(x/a) 156 | 157 | return k 158 | 159 | 160 | def lanczos_shift(img, shift, p=5, a=3): 161 | ''' 162 | Shifts an image by convolving it with a Lanczos kernel. 
163 | Lanczos interpolation is an approximation to ideal sinc interpolation, 164 | obtained by windowing a sinc kernel with another sinc function extending over a 165 | few of its lobes (typically a=3). 166 | 167 | Args: 168 | img : tensor (batch_size, channels, height, width), the images to be shifted 169 | shift : tensor (batch_size, 2) of translation parameters (dy, dx) 170 | p : int, padding width prior to convolution (default=5) 171 | a : int, number of lobes in the Lanczos interpolation kernel (default=3) 172 | Returns: 173 | I_s: tensor (batch_size, channels, height, width), shifted images 174 | ''' 175 | img = img.transpose(0, 1) 176 | dtype = img.dtype 177 | 178 | if len(img.shape) == 2: 179 | img = img[None, None].repeat(1, shift.shape[0], 1, 1) # batch of one image 180 | elif len(img.shape) == 3: # one image per shift 181 | assert img.shape[0] == shift.shape[0] 182 | img = img[None,] 183 | 184 | # Apply padding 185 | 186 | padder = torch.nn.ReflectionPad2d(p) # reflect pre-padding 187 | I_padded = padder(img) 188 | 189 | # Create 1D shifting kernels 190 | 191 | y_shift = shift[:, [0]] 192 | x_shift = shift[:, [1]] 193 | 194 | k_y = (lanczos_kernel(y_shift, a=a, N=None, dtype=dtype) 195 | .flip(1) # flip axis of convolution 196 | )[:, None, :, None] # expand dims to get shape (batch, channels, y_kernel, 1) 197 | k_x = (lanczos_kernel(x_shift, a=a, N=None, dtype=dtype) 198 | .flip(1) 199 | )[:, None, None, :] # shape (batch, channels, 1, x_kernel) 200 | 201 | # Apply kernels 202 | # print(I_padded.shape, k_y.shape) 203 | I_s = torch.conv1d(I_padded, 204 | groups=k_y.shape[0], 205 | weight=k_y, 206 | padding=[k_y.shape[2] // 2, 0]) # same padding 207 | I_s = torch.conv1d(I_s, 208 | groups=k_x.shape[0], 209 | weight=k_x, 210 | padding=[0, k_x.shape[3] // 2]) 211 | 212 | I_s = I_s[..., p:-p, p:-p] # remove padding 213 | 214 | # print(I_s.shape) 215 | return I_s.transpose(0, 1) # , k.squeeze() 216 | -------------------------------------------------------------------------------- /model/non_local/network.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | # from lib.non_local_concatenation import NONLocalBlock2D 3 | # from lib.non_local_gaussian import NONLocalBlock2D 4 | from lib.non_local_embedded_gaussian import NONLocalBlock2D 5 | # from lib.non_local_dot_product import NONLocalBlock2D 6 | 7 | 8 | class Network(nn.Module): 9 | def __init__(self): 10 | super(Network, self).__init__() 11 | 12 | self.conv_1 = nn.Sequential( 13 | nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1), 14 | nn.BatchNorm2d(32), 15 | nn.ReLU(), 16 | nn.MaxPool2d(2), 17 | ) 18 | 19 | self.nl_1 = NONLocalBlock2D(in_channels=32) 20 | self.conv_2 = nn.Sequential( 21 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1), 22 | nn.BatchNorm2d(64), 23 | nn.ReLU(), 24 | nn.MaxPool2d(2), 25 | ) 26 | 27 | self.nl_2 = NONLocalBlock2D(in_channels=64) 28 | self.conv_3 = nn.Sequential( 29 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), 30 | nn.BatchNorm2d(128), 31 | nn.ReLU(), 32 | nn.MaxPool2d(2), 33 | ) 34 | 35 | self.fc = nn.Sequential( 36 | nn.Linear(in_features=128*3*3, out_features=256), 37 | nn.ReLU(), 38 | nn.Dropout(0.5), 39 | 40 | nn.Linear(in_features=256, out_features=10) 41 | ) 42 | 43 | def forward(self, x): 44 | batch_size = x.size(0) 45 | 46 | feature_1 = self.conv_1(x) 47 | nl_feature_1 = self.nl_1(feature_1) 48 | 49 | feature_2 = self.conv_2(nl_feature_1) 50 
| nl_feature_2 = self.nl_2(feature_2) 51 | 52 | output = self.conv_3(nl_feature_2).view(batch_size, -1) 53 | output = self.fc(output) 54 | 55 | return output 56 | 57 | def forward_with_nl_map(self, x): 58 | batch_size = x.size(0) 59 | 60 | feature_1 = self.conv_1(x) 61 | nl_feature_1, nl_map_1 = self.nl_1(feature_1, return_nl_map=True) 62 | 63 | feature_2 = self.conv_2(nl_feature_1) 64 | nl_feature_2, nl_map_2 = self.nl_2(feature_2, return_nl_map=True) 65 | 66 | output = self.conv_3(nl_feature_2).view(batch_size, -1) 67 | output = self.fc(output) 68 | 69 | return output, [nl_map_1, nl_map_2] 70 | 71 | 72 | if __name__ == '__main__': 73 | import torch 74 | 75 | img = torch.randn(3, 1, 28, 28) 76 | net = Network() 77 | out = net(img) 78 | print(out.size()) 79 | 80 | -------------------------------------------------------------------------------- /model/non_local/non_local_concatenation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class _NonLocalBlockND(nn.Module): 7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 8 | super(_NonLocalBlockND, self).__init__() 9 | 10 | assert dimension in [1, 2, 3] 11 | 12 | self.dimension = dimension 13 | self.sub_sample = sub_sample 14 | 15 | self.in_channels = in_channels 16 | self.inter_channels = inter_channels 17 | 18 | if self.inter_channels is None: 19 | self.inter_channels = in_channels // 2 20 | if self.inter_channels == 0: 21 | self.inter_channels = 1 22 | 23 | if dimension == 3: 24 | conv_nd = nn.Conv3d 25 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 26 | bn = nn.BatchNorm3d 27 | elif dimension == 2: 28 | conv_nd = nn.Conv2d 29 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 30 | bn = nn.BatchNorm2d 31 | else: 32 | conv_nd = nn.Conv1d 33 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 34 | bn = nn.BatchNorm1d 35 | 36 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 37 | kernel_size=1, stride=1, padding=0) 38 | 39 | if bn_layer: 40 | self.W = nn.Sequential( 41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 42 | kernel_size=1, stride=1, padding=0), 43 | bn(self.in_channels) 44 | ) 45 | nn.init.constant_(self.W[1].weight, 0) 46 | nn.init.constant_(self.W[1].bias, 0) 47 | else: 48 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 49 | kernel_size=1, stride=1, padding=0) 50 | nn.init.constant_(self.W.weight, 0) 51 | nn.init.constant_(self.W.bias, 0) 52 | 53 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 54 | kernel_size=1, stride=1, padding=0) 55 | 56 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 57 | kernel_size=1, stride=1, padding=0) 58 | 59 | self.concat_project = nn.Sequential( 60 | nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False), 61 | nn.ReLU() 62 | ) 63 | 64 | if sub_sample: 65 | self.g = nn.Sequential(self.g, max_pool_layer) 66 | self.phi = nn.Sequential(self.phi, max_pool_layer) 67 | 68 | def forward(self, x, return_nl_map=False): 69 | ''' 70 | :param x: (b, c, t, h, w) 71 | :param return_nl_map: if True return z, nl_map, else only return z. 
72 | :return: 73 | ''' 74 | 75 | batch_size = x.size(0) 76 | 77 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 78 | g_x = g_x.permute(0, 2, 1) 79 | 80 | # (b, c, N, 1) 81 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1) 82 | # (b, c, 1, N) 83 | phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1) 84 | 85 | h = theta_x.size(2) 86 | w = phi_x.size(3) 87 | theta_x = theta_x.repeat(1, 1, 1, w) 88 | phi_x = phi_x.repeat(1, 1, h, 1) 89 | 90 | concat_feature = torch.cat([theta_x, phi_x], dim=1) 91 | f = self.concat_project(concat_feature) 92 | b, _, h, w = f.size() 93 | f = f.view(b, h, w) 94 | 95 | N = f.size(-1) 96 | f_div_C = f / N 97 | 98 | y = torch.matmul(f_div_C, g_x) 99 | y = y.permute(0, 2, 1).contiguous() 100 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 101 | W_y = self.W(y) 102 | z = W_y + x 103 | 104 | if return_nl_map: 105 | return z, f_div_C 106 | return z 107 | 108 | 109 | class NONLocalBlock1D(_NonLocalBlockND): 110 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 111 | super(NONLocalBlock1D, self).__init__(in_channels, 112 | inter_channels=inter_channels, 113 | dimension=1, sub_sample=sub_sample, 114 | bn_layer=bn_layer) 115 | 116 | 117 | class NONLocalBlock2D(_NonLocalBlockND): 118 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 119 | super(NONLocalBlock2D, self).__init__(in_channels, 120 | inter_channels=inter_channels, 121 | dimension=2, sub_sample=sub_sample, 122 | bn_layer=bn_layer) 123 | 124 | 125 | class NONLocalBlock3D(_NonLocalBlockND): 126 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True,): 127 | super(NONLocalBlock3D, self).__init__(in_channels, 128 | inter_channels=inter_channels, 129 | dimension=3, sub_sample=sub_sample, 130 | bn_layer=bn_layer) 131 | 132 | 133 | if __name__ == '__main__': 134 | import torch 135 | 136 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]: 137 | img = torch.zeros(2, 3, 20) 138 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 139 | out = net(img) 140 | print(out.size()) 141 | 142 | img = torch.zeros(2, 3, 20, 20) 143 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 144 | out = net(img) 145 | print(out.size()) 146 | 147 | img = torch.randn(2, 3, 8, 20, 20) 148 | net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 149 | out = net(img) 150 | print(out.size()) 151 | -------------------------------------------------------------------------------- /model/non_local/non_local_cross_dot_product.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class _NonLocalBlockND(nn.Module): 7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 8 | super(_NonLocalBlockND, self).__init__() 9 | 10 | assert dimension in [1, 2, 3] 11 | 12 | self.dimension = dimension 13 | self.sub_sample = sub_sample 14 | 15 | self.in_channels = in_channels 16 | self.inter_channels = inter_channels 17 | 18 | if self.inter_channels is None: 19 | self.inter_channels = in_channels // 2 20 | if self.inter_channels == 0: 21 | self.inter_channels = 1 22 | 23 | if dimension == 3: 24 | conv_nd = nn.Conv3d 25 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 4, 4)) 26 | bn = nn.BatchNorm3d 27 | elif dimension == 2: 28 | conv_nd = 
nn.Conv2d 29 | max_pool_layer = nn.MaxPool2d(kernel_size=(4, 4)) 30 | bn = nn.BatchNorm2d 31 | else: 32 | conv_nd = nn.Conv1d 33 | max_pool_layer = nn.MaxPool1d(kernel_size=(4)) 34 | bn = nn.BatchNorm1d 35 | 36 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 37 | kernel_size=1, stride=1, padding=0) 38 | 39 | if bn_layer: 40 | self.W = nn.Sequential( 41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 42 | kernel_size=1, stride=1, padding=0), 43 | bn(self.in_channels) 44 | ) 45 | nn.init.constant_(self.W[1].weight, 0) 46 | nn.init.constant_(self.W[1].bias, 0) 47 | else: 48 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 49 | kernel_size=1, stride=1, padding=0) 50 | nn.init.constant_(self.W.weight, 0) 51 | nn.init.constant_(self.W.bias, 0) 52 | 53 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 54 | kernel_size=1, stride=1, padding=0) 55 | 56 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 57 | kernel_size=1, stride=1, padding=0) 58 | 59 | if sub_sample: 60 | self.g = nn.Sequential(self.g, max_pool_layer) 61 | self.phi = nn.Sequential(self.phi, max_pool_layer) 62 | 63 | def forward(self, x, ref, return_nl_map=False): 64 | """ 65 | :param x: (b, c, t, h, w) 66 | :param return_nl_map: if True return z, nl_map, else only return z. 67 | :return: 68 | """ 69 | 70 | batch_size = x.size(0) 71 | 72 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 73 | g_x = g_x.permute(0, 2, 1) 74 | 75 | theta_ref = self.theta(ref).view(batch_size, self.inter_channels, -1) 76 | theta_ref = theta_ref.permute(0, 2, 1) 77 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 78 | f = torch.matmul(theta_ref, phi_x) 79 | N = f.size(-1) 80 | f_div_C = f / N 81 | 82 | y = torch.matmul(f_div_C, g_x) 83 | y = y.permute(0, 2, 1).contiguous() 84 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 85 | W_y = self.W(y) 86 | z = W_y + x 87 | 88 | if return_nl_map: 89 | return z, f_div_C 90 | return z 91 | 92 | 93 | class NONLocalBlock1D(_NonLocalBlockND): 94 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 95 | super(NONLocalBlock1D, self).__init__(in_channels, 96 | inter_channels=inter_channels, 97 | dimension=1, sub_sample=sub_sample, 98 | bn_layer=bn_layer) 99 | 100 | 101 | class NONLocalBlock2D(_NonLocalBlockND): 102 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 103 | super(NONLocalBlock2D, self).__init__(in_channels, 104 | inter_channels=inter_channels, 105 | dimension=2, sub_sample=sub_sample, 106 | bn_layer=bn_layer) 107 | 108 | 109 | class NONLocalBlock3D(_NonLocalBlockND): 110 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 111 | super(NONLocalBlock3D, self).__init__(in_channels, 112 | inter_channels=inter_channels, 113 | dimension=3, sub_sample=sub_sample, 114 | bn_layer=bn_layer) 115 | 116 | 117 | if __name__ == '__main__': 118 | import torch 119 | 120 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]: 121 | img = torch.zeros(2, 3, 20) 122 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 123 | out = net(img) 124 | print(out.size()) 125 | 126 | img = torch.zeros(2, 3, 20, 20) 127 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 128 | out = net(img) 129 | print(out.size()) 130 | 131 | img = torch.randn(2, 3, 8, 20, 
20) 132 | net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 133 | out = net(img) 134 | print(out.size()) 135 | 136 | 137 | 138 | -------------------------------------------------------------------------------- /model/non_local/non_local_dot_product.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class _NonLocalBlockND(nn.Module): 7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 8 | super(_NonLocalBlockND, self).__init__() 9 | 10 | assert dimension in [1, 2, 3] 11 | 12 | self.dimension = dimension 13 | self.sub_sample = sub_sample 14 | 15 | self.in_channels = in_channels 16 | self.inter_channels = inter_channels 17 | 18 | if self.inter_channels is None: 19 | self.inter_channels = in_channels // 2 20 | if self.inter_channels == 0: 21 | self.inter_channels = 1 22 | 23 | if dimension == 3: 24 | conv_nd = nn.Conv3d 25 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 4, 4)) 26 | bn = nn.BatchNorm3d 27 | elif dimension == 2: 28 | conv_nd = nn.Conv2d 29 | max_pool_layer = nn.MaxPool2d(kernel_size=(4, 4)) 30 | bn = nn.BatchNorm2d 31 | else: 32 | conv_nd = nn.Conv1d 33 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 34 | bn = nn.BatchNorm1d 35 | 36 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 37 | kernel_size=1, stride=1, padding=0) 38 | 39 | if bn_layer: 40 | self.W = nn.Sequential( 41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 42 | kernel_size=1, stride=1, padding=0), 43 | bn(self.in_channels) 44 | ) 45 | nn.init.constant_(self.W[1].weight, 0) 46 | nn.init.constant_(self.W[1].bias, 0) 47 | else: 48 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 49 | kernel_size=1, stride=1, padding=0) 50 | nn.init.constant_(self.W.weight, 0) 51 | nn.init.constant_(self.W.bias, 0) 52 | 53 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 54 | kernel_size=1, stride=1, padding=0) 55 | 56 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 57 | kernel_size=1, stride=1, padding=0) 58 | 59 | if sub_sample: 60 | self.g = nn.Sequential(self.g, max_pool_layer) 61 | self.phi = nn.Sequential(self.phi, max_pool_layer) 62 | 63 | def forward(self, x, return_nl_map=False): 64 | """ 65 | :param x: (b, c, t, h, w) 66 | :param return_nl_map: if True return z, nl_map, else only return z. 
67 | :return: 68 | """ 69 | 70 | batch_size = x.size(0) 71 | 72 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 73 | g_x = g_x.permute(0, 2, 1) 74 | 75 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 76 | theta_x = theta_x.permute(0, 2, 1) 77 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 78 | f = torch.matmul(theta_x, phi_x) 79 | N = f.size(-1) 80 | f_div_C = f / N 81 | 82 | y = torch.matmul(f_div_C, g_x) 83 | y = y.permute(0, 2, 1).contiguous() 84 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 85 | W_y = self.W(y) 86 | z = W_y + x 87 | 88 | if return_nl_map: 89 | return z, f_div_C 90 | return z 91 | 92 | 93 | class NONLocalBlock1D(_NonLocalBlockND): 94 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 95 | super(NONLocalBlock1D, self).__init__(in_channels, 96 | inter_channels=inter_channels, 97 | dimension=1, sub_sample=sub_sample, 98 | bn_layer=bn_layer) 99 | 100 | 101 | class NONLocalBlock2D(_NonLocalBlockND): 102 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 103 | super(NONLocalBlock2D, self).__init__(in_channels, 104 | inter_channels=inter_channels, 105 | dimension=2, sub_sample=sub_sample, 106 | bn_layer=bn_layer) 107 | 108 | 109 | class NONLocalBlock3D(_NonLocalBlockND): 110 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 111 | super(NONLocalBlock3D, self).__init__(in_channels, 112 | inter_channels=inter_channels, 113 | dimension=3, sub_sample=sub_sample, 114 | bn_layer=bn_layer) 115 | 116 | 117 | if __name__ == '__main__': 118 | import torch 119 | 120 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]: 121 | img = torch.zeros(2, 3, 20) 122 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 123 | out = net(img) 124 | print(out.size()) 125 | 126 | img = torch.zeros(2, 3, 20, 20) 127 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 128 | out = net(img) 129 | print(out.size()) 130 | 131 | img = torch.randn(2, 3, 8, 20, 20) 132 | net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 133 | out = net(img) 134 | print(out.size()) 135 | 136 | 137 | 138 | -------------------------------------------------------------------------------- /model/non_local/non_local_embedded_gaussian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class _NonLocalBlockND(nn.Module): 7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 8 | """ 9 | :param in_channels: 10 | :param inter_channels: 11 | :param dimension: 12 | :param sub_sample: 13 | :param bn_layer: 14 | """ 15 | 16 | super(_NonLocalBlockND, self).__init__() 17 | 18 | assert dimension in [1, 2, 3] 19 | 20 | self.dimension = dimension 21 | self.sub_sample = sub_sample 22 | 23 | self.in_channels = in_channels 24 | self.inter_channels = inter_channels 25 | 26 | if self.inter_channels is None: 27 | self.inter_channels = in_channels // 2 28 | if self.inter_channels == 0: 29 | self.inter_channels = 1 30 | 31 | if dimension == 3: 32 | conv_nd = nn.Conv3d 33 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 34 | bn = nn.BatchNorm3d 35 | elif dimension == 2: 36 | conv_nd = nn.Conv2d 37 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 38 | bn = nn.BatchNorm2d 39 | else: 40 | conv_nd = nn.Conv1d 
41 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 42 | bn = nn.BatchNorm1d 43 | 44 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 45 | kernel_size=1, stride=1, padding=0) 46 | 47 | if bn_layer: 48 | self.W = nn.Sequential( 49 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 50 | kernel_size=1, stride=1, padding=0), 51 | bn(self.in_channels) 52 | ) 53 | nn.init.constant_(self.W[1].weight, 0) 54 | nn.init.constant_(self.W[1].bias, 0) 55 | else: 56 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 57 | kernel_size=1, stride=1, padding=0) 58 | nn.init.constant_(self.W.weight, 0) 59 | nn.init.constant_(self.W.bias, 0) 60 | 61 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 62 | kernel_size=1, stride=1, padding=0) 63 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 64 | kernel_size=1, stride=1, padding=0) 65 | 66 | if sub_sample: 67 | self.g = nn.Sequential(self.g, max_pool_layer) 68 | self.phi = nn.Sequential(self.phi, max_pool_layer) 69 | 70 | def forward(self, x, return_nl_map=False): 71 | """ 72 | :param x: (b, c, t, h, w) 73 | :param return_nl_map: if True return z, nl_map, else only return z. 74 | :return: 75 | """ 76 | 77 | batch_size = x.size(0) 78 | 79 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 80 | g_x = g_x.permute(0, 2, 1) 81 | 82 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 83 | theta_x = theta_x.permute(0, 2, 1) 84 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 85 | f = torch.matmul(theta_x, phi_x) 86 | f_div_C = F.softmax(f, dim=-1) 87 | 88 | y = torch.matmul(f_div_C, g_x) 89 | y = y.permute(0, 2, 1).contiguous() 90 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 91 | W_y = self.W(y) 92 | z = W_y + x 93 | 94 | if return_nl_map: 95 | return z, f_div_C 96 | return z 97 | 98 | 99 | class NONLocalBlock1D(_NonLocalBlockND): 100 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 101 | super(NONLocalBlock1D, self).__init__(in_channels, 102 | inter_channels=inter_channels, 103 | dimension=1, sub_sample=sub_sample, 104 | bn_layer=bn_layer) 105 | 106 | 107 | class NONLocalBlock2D(_NonLocalBlockND): 108 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 109 | super(NONLocalBlock2D, self).__init__(in_channels, 110 | inter_channels=inter_channels, 111 | dimension=2, sub_sample=sub_sample, 112 | bn_layer=bn_layer,) 113 | 114 | 115 | class NONLocalBlock3D(_NonLocalBlockND): 116 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 117 | super(NONLocalBlock3D, self).__init__(in_channels, 118 | inter_channels=inter_channels, 119 | dimension=3, sub_sample=sub_sample, 120 | bn_layer=bn_layer,) 121 | 122 | 123 | if __name__ == '__main__': 124 | import torch 125 | 126 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]: 127 | img = torch.zeros(2, 3, 20) 128 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 129 | out = net(img) 130 | print(out.size()) 131 | 132 | img = torch.zeros(2, 3, 20, 20) 133 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 134 | out = net(img) 135 | print(out.size()) 136 | 137 | img = torch.randn(2, 3, 8, 20, 20) 138 | net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 139 | out = net(img) 140 | print(out.size()) 141 | 142 | 143 | 
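The non-local variants in this directory differ only in how the pairwise affinity `f` is formed and normalized (softmax for the Gaussian and embedded-Gaussian versions, division by `N` for the dot-product and concatenation versions), but they share one useful property: `W` is zero-initialized in every case, so a freshly constructed block is an identity mapping and can be inserted into a pretrained network without disturbing it. A minimal sketch verifying this (not part of the repository; it assumes the repo root is on `PYTHONPATH`):

```python
import torch
from model.non_local.non_local_embedded_gaussian import NONLocalBlock2D

# W's 1x1 conv (and BN, when enabled) start at zero, so z = W_y + x reduces to z = x.
block = NONLocalBlock2D(in_channels=8, sub_sample=False, bn_layer=False)
x = torch.randn(2, 8, 16, 16)
with torch.no_grad():
    z = block(x)
print(torch.equal(z, x))  # True right after construction
```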
-------------------------------------------------------------------------------- /model/non_local/non_local_gaussian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class _NonLocalBlockND(nn.Module): 7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 8 | super(_NonLocalBlockND, self).__init__() 9 | 10 | assert dimension in [1, 2, 3] 11 | 12 | self.dimension = dimension 13 | self.sub_sample = sub_sample 14 | 15 | self.in_channels = in_channels 16 | self.inter_channels = inter_channels 17 | 18 | if self.inter_channels is None: 19 | self.inter_channels = in_channels // 2 20 | if self.inter_channels == 0: 21 | self.inter_channels = 1 22 | 23 | if dimension == 3: 24 | conv_nd = nn.Conv3d 25 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 26 | bn = nn.BatchNorm3d 27 | elif dimension == 2: 28 | conv_nd = nn.Conv2d 29 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 30 | bn = nn.BatchNorm2d 31 | else: 32 | conv_nd = nn.Conv1d 33 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 34 | bn = nn.BatchNorm1d 35 | 36 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 37 | kernel_size=1, stride=1, padding=0) 38 | 39 | if bn_layer: 40 | self.W = nn.Sequential( 41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 42 | kernel_size=1, stride=1, padding=0), 43 | bn(self.in_channels) 44 | ) 45 | nn.init.constant_(self.W[1].weight, 0) 46 | nn.init.constant_(self.W[1].bias, 0) 47 | else: 48 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 49 | kernel_size=1, stride=1, padding=0) 50 | nn.init.constant_(self.W.weight, 0) 51 | nn.init.constant_(self.W.bias, 0) 52 | 53 | if sub_sample: 54 | self.g = nn.Sequential(self.g, max_pool_layer) 55 | self.phi = max_pool_layer 56 | 57 | def forward(self, x, return_nl_map=False): 58 | """ 59 | :param x: (b, c, t, h, w) 60 | :param return_nl_map: if True return z, nl_map, else only return z. 
61 | :return: 62 | """ 63 | 64 | batch_size = x.size(0) 65 | 66 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 67 | 68 | g_x = g_x.permute(0, 2, 1) 69 | 70 | theta_x = x.view(batch_size, self.in_channels, -1) 71 | theta_x = theta_x.permute(0, 2, 1) 72 | 73 | if self.sub_sample: 74 | phi_x = self.phi(x).view(batch_size, self.in_channels, -1) 75 | else: 76 | phi_x = x.view(batch_size, self.in_channels, -1) 77 | 78 | f = torch.matmul(theta_x, phi_x) 79 | f_div_C = F.softmax(f, dim=-1) 80 | 81 | # if self.store_last_batch_nl_map: 82 | # self.nl_map = f_div_C 83 | 84 | y = torch.matmul(f_div_C, g_x) 85 | y = y.permute(0, 2, 1).contiguous() 86 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 87 | W_y = self.W(y) 88 | z = W_y + x 89 | 90 | if return_nl_map: 91 | return z, f_div_C 92 | return z 93 | 94 | 95 | class NONLocalBlock1D(_NonLocalBlockND): 96 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 97 | super(NONLocalBlock1D, self).__init__(in_channels, 98 | inter_channels=inter_channels, 99 | dimension=1, sub_sample=sub_sample, 100 | bn_layer=bn_layer) 101 | 102 | 103 | class NONLocalBlock2D(_NonLocalBlockND): 104 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 105 | super(NONLocalBlock2D, self).__init__(in_channels, 106 | inter_channels=inter_channels, 107 | dimension=2, sub_sample=sub_sample, 108 | bn_layer=bn_layer) 109 | 110 | 111 | class NONLocalBlock3D(_NonLocalBlockND): 112 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 113 | super(NONLocalBlock3D, self).__init__(in_channels, 114 | inter_channels=inter_channels, 115 | dimension=3, sub_sample=sub_sample, 116 | bn_layer=bn_layer) 117 | 118 | 119 | if __name__ == '__main__': 120 | import torch 121 | 122 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]: 123 | img = torch.zeros(2, 3, 20) 124 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 125 | out = net(img) 126 | print(out.size()) 127 | 128 | img = torch.zeros(2, 3, 20, 20) 129 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 130 | out = net(img) 131 | print(out.size()) 132 | 133 | img = torch.randn(2, 3, 8, 20, 20) 134 | net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 135 | out = net(img) 136 | print(out.size()) 137 | 138 | 139 | 140 | 141 | 142 | 143 | -------------------------------------------------------------------------------- /model/utils/interp_methods.py: -------------------------------------------------------------------------------- 1 | from math import pi 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | torch = None 7 | 8 | try: 9 | import numpy 10 | except ImportError: 11 | numpy = None 12 | 13 | if numpy is None and torch is None: 14 | raise ImportError("Must have either Numpy or PyTorch but both not found") 15 | 16 | 17 | def set_framework_dependencies(x): 18 | if type(x) is numpy.ndarray: 19 | to_dtype = lambda a: a 20 | fw = numpy 21 | else: 22 | to_dtype = lambda a: a.to(x.dtype) 23 | fw = torch 24 | eps = fw.finfo(fw.float32).eps 25 | return fw, to_dtype, eps 26 | 27 | 28 | def support_sz(sz): 29 | def wrapper(f): 30 | f.support_sz = sz 31 | return f 32 | return wrapper 33 | 34 | @support_sz(4) 35 | def cubic(x): 36 | fw, to_dtype, eps = set_framework_dependencies(x) 37 | absx = fw.abs(x) 38 | absx2 = absx ** 2 39 | absx3 = absx ** 3 40 | return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) 
+ 41 | (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) * 42 | to_dtype((1. < absx) & (absx <= 2.))) 43 | 44 | @support_sz(4) 45 | def lanczos2(x): 46 | fw, to_dtype, eps = set_framework_dependencies(x) 47 | return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) / 48 | ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2)) 49 | 50 | @support_sz(6) 51 | def lanczos3(x): 52 | fw, to_dtype, eps = set_framework_dependencies(x) 53 | return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) / 54 | ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3)) 55 | 56 | @support_sz(2) 57 | def linear(x): 58 | fw, to_dtype, eps = set_framework_dependencies(x) 59 | return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) + (1 - x) * 60 | to_dtype((0 <= x) & (x <= 1))) 61 | 62 | @support_sz(1) 63 | def box(x): 64 | fw, to_dtype, eps = set_framework_dependencies(x) 65 | return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1)) 66 | -------------------------------------------------------------------------------- /model/utils/psconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class PyConv2d(nn.Module): 5 | """PyConv2d with padding (general case). Applies a 2D PyConv over an input signal composed of several input planes. 6 | Args: 7 | in_channels (int): Number of channels in the input image 8 | out_channels (list): Number of channels for each pyramid level produced by the convolution 9 | pyconv_kernels (list): Spatial size of the kernel for each pyramid level 10 | pyconv_groups (list): Number of blocked connections from input channels to output channels for each pyramid level 11 | stride (int or tuple, optional): Stride of the convolution. Default: 1 12 | dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 13 | bias (bool, optional): If ``True``, adds a learnable bias to the output. 
Default: ``False`` 14 | Example:: 15 | >>> # PyConv with two pyramid levels, kernels: 3x3, 5x5 16 | >>> m = PyConv2d(in_channels=64, out_channels=[32, 32], pyconv_kernels=[3, 5], pyconv_groups=[1, 4]) 17 | >>> input = torch.randn(4, 64, 56, 56) 18 | >>> output = m(input) 19 | >>> # PyConv with three pyramid levels, kernels: 3x3, 5x5, 7x7 20 | >>> m = PyConv2d(in_channels=64, out_channels=[16, 16, 32], pyconv_kernels=[3, 5, 7], pyconv_groups=[1, 4, 8]) 21 | >>> input = torch.randn(4, 64, 56, 56) 22 | >>> output = m(input) 23 | """ 24 | def __init__(self, in_channels, out_channels, pyconv_kernels, pyconv_groups, stride=1, dilation=1, bias=False): 25 | super(PyConv2d, self).__init__() 26 | 27 | assert len(out_channels) == len(pyconv_kernels) == len(pyconv_groups) 28 | 29 | self.pyconv_levels = [None] * len(pyconv_kernels) 30 | for i in range(len(pyconv_kernels)): 31 | self.pyconv_levels[i] = nn.Conv2d(in_channels, out_channels[i], kernel_size=pyconv_kernels[i], 32 | stride=stride, padding=pyconv_kernels[i] // 2, groups=pyconv_groups[i], 33 | dilation=dilation, bias=bias) 34 | self.pyconv_levels = nn.ModuleList(self.pyconv_levels) 35 | 36 | def forward(self, x): 37 | out = [] 38 | for level in self.pyconv_levels: 39 | out.append(level(x)) 40 | 41 | return torch.cat(out, 1) 42 | 43 | ################################################################ 44 | 45 | class PSConv2d(nn.Module): 46 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, parts=4, bias=False): 47 | super(PSConv2d, self).__init__() 48 | self.gwconv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, dilation, dilation, groups=parts, bias=bias) 49 | self.gwconv_shift = nn.Conv2d(in_channels, out_channels, kernel_size, stride, 2 * dilation, 2 * dilation, groups=parts, bias=bias) 50 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 51 | 52 | def backward_hook(grad): 53 | out = grad.clone() 54 | out[self.mask] = 0 55 | return out 56 | 57 | self.mask = torch.zeros(self.conv.weight.shape).byte().cuda() 58 | _in_channels = in_channels // parts 59 | _out_channels = out_channels // parts 60 | for i in range(parts): 61 | self.mask[i * _out_channels: (i + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, : , :] = 1 62 | self.mask[(i + parts//2)%parts * _out_channels: ((i + parts//2)%parts + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, :, :] = 1 63 | self.conv.weight.data[self.mask] = 0 64 | self.conv.weight.register_hook(backward_hook) 65 | 66 | self.weight = self.conv.weight 67 | self.bias = self.conv.bias 68 | 69 | def forward(self, x): 70 | x1, x2 = x.chunk(2, dim=1) 71 | x_shift = self.gwconv_shift(torch.cat((x2, x1), dim=1)) 72 | return self.gwconv(x) + self.conv(x) + x_shift 73 | 74 | 75 | # PSConv-based Group Convolution 76 | class PSGConv2d(nn.Module): 77 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, parts=4, bias=False): 78 | super(PSGConv2d, self).__init__() 79 | self.gwconv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups=groups * parts, bias=bias) 80 | self.gwconv_shift = nn.Conv2d(in_channels, out_channels, kernel_size, stride, 2 * padding, 2 * dilation, groups=groups * parts, bias=bias) 81 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=bias) 82 | 83 | def backward_hook(grad): 84 | out = grad.clone() 85 | out[self.mask] = 0 86 | return out 87 | 88 | self.mask 
= torch.zeros(self.conv.weight.shape).bool().cuda() 89 | _in_channels = in_channels // (groups * parts) 90 | _out_channels = out_channels // (groups * parts) 91 | for i in range(parts): 92 | for j in range(groups): 93 | self.mask[(i + j * groups) * _out_channels: (i + j * groups + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, : , :] = 1 94 | self.mask[((i + parts // 2) % parts + j * groups) * _out_channels: ((i + parts // 2) % parts + j * groups + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, :, :] = 1 95 | self.conv.weight.data[self.mask] = 0 96 | self.conv.weight.register_hook(backward_hook) 97 | self.groups = groups 98 | 99 | self.weight = self.conv.weight 100 | self.bias = self.conv.bias 101 | 102 | def forward(self, x): 103 | x_split = (z.chunk(2, dim=1) for z in x.chunk(self.groups, dim=1)) 104 | x_merge = torch.cat(tuple(torch.cat((x2, x1), dim=1) for (x1, x2) in x_split), dim=1) 105 | x_shift = self.gwconv_shift(x_merge) 106 | gx = self.gwconv(x) 107 | cx = self.conv(x) 108 | # print(x.shape, gx.shape, cx.shape, x_merge.shape, x_shift.shape) 109 | return gx + cx + x_shift 110 | 111 | -------------------------------------------------------------------------------- /option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser(description='EBSR') 4 | 5 | parser.add_argument('--n_resblocks', type=int, default=16, 6 | help='number of residual blocks') 7 | parser.add_argument('--n_feats', type=int, default=64, 8 | help='number of feature maps') 9 | parser.add_argument('--n_colors', type=int, default=3, 10 | help='number of color channels to use') 11 | parser.add_argument('--lr', type=float, default=1e-4, 12 | help='learning rate') 13 | parser.add_argument('--burst_size', type=int, default=14, 14 | help='burst size, max 14') 15 | parser.add_argument('--burst_channel', type=int, default=1, 16 | help='number of channels per burst frame') 17 | parser.add_argument('--sift_lr', action='store_true', 18 | help='use SIFT to pre-align burst frames') 19 | parser.add_argument('--lrcn', action='store_true', 20 | help='use long-range concatenating network') 21 | 22 | # Hardware specifications 23 | parser.add_argument('--n_threads', type=int, default=6, 24 | help='number of threads for data loading') 25 | parser.add_argument('--cpu', action='store_true', 26 | help='use cpu only') 27 | parser.add_argument('--n_GPUs', type=int, default=2, 28 | help='number of GPUs') 29 | parser.add_argument('--seed', type=int, default=1, 30 | help='random seed') 31 | parser.add_argument('--local_rank', type=int, default=-1, 32 | help='local process rank for distributed training') 33 | parser.add_argument('--fp16', action='store_true', 34 | help='use fp16 precision') 35 | parser.add_argument('--load_head', action='store_true', 36 | help='load head from other model') 37 | parser.add_argument('--load_sr', action='store_true', 38 | help='load sr module from other model') 39 | parser.add_argument('--finetune_head', action='store_true', 40 | help='finetune the head module') 41 | parser.add_argument('--finetune_large', action='store_true', 42 | help='finetune the large module') 43 | parser.add_argument('--finetune_large_skip', action='store_true', 44 | help='finetune the large_skip module') 45 | parser.add_argument('--finetune_pcd', action='store_true', 46 | help='finetune the PCD alignment module') 47 | parser.add_argument('--use_tree', action='store_true', 48 | help='use the tree module') 49 | 50 | # Data specifications 51 |
parser.add_argument('--root', type=str, default='/data/dataset/ntire21/burstsr/synthetic', 52 | help='dataset directory') 53 | parser.add_argument('--mode', type=str, default='train', 54 | help='running mode (e.g. train)') 55 | parser.add_argument('--scale', type=str, default='4', 56 | help='super resolution scale') 57 | parser.add_argument('--patch_size', type=int, default=256, 58 | help='output patch size') 59 | parser.add_argument('--rgb_range', type=int, default=1, 60 | help='maximum value of RGB') 61 | 62 | parser.add_argument('--chop', action='store_true', 63 | help='enable memory-efficient forward') 64 | parser.add_argument('--no_augment', action='store_true', 65 | help='do not use data augmentation') 66 | 67 | # Model specifications 68 | parser.add_argument('--model', default='LRSC_EDVR', 69 | help='model name') 70 | 71 | parser.add_argument('--act', type=str, default='relu', 72 | help='activation function') 73 | parser.add_argument('--pre_train', type=str, default='', 74 | help='pre-trained model directory') 75 | parser.add_argument('--extend', type=str, default='.', 76 | help='pre-trained model directory') 77 | 78 | parser.add_argument('--res_scale', type=float, default=1, 79 | help='residual scaling') 80 | parser.add_argument('--shift_mean', default=True, 81 | help='subtract pixel mean from the input') 82 | parser.add_argument('--dilation', action='store_true', 83 | help='use dilated convolution') 84 | parser.add_argument('--precision', type=str, default='single', 85 | choices=('single', 'half'), 86 | help='FP precision for test (single | half)') 87 | 88 | 89 | # Option for Residual channel attention network (RCAN) 90 | parser.add_argument('--n_resgroups', type=int, default=20, 91 | help='number of residual groups') 92 | parser.add_argument('--reduction', type=int, default=16, 93 | help='feature map reduction factor') 94 | parser.add_argument('--DA', action='store_true', 95 | help='use Dual Attention') 96 | parser.add_argument('--CA', action='store_true', 97 | help='use Channel Attention') 98 | parser.add_argument('--non_local', action='store_true', 99 | help='use non-local blocks') 100 | 101 | # Training specifications 102 | parser.add_argument('--reset', action='store_true', 103 | help='reset the training') 104 | parser.add_argument('--test_every', type=int, default=1000, 105 | help='run test every N batches') 106 | parser.add_argument('--epochs', type=int, default=602, 107 | help='number of epochs to train') 108 | parser.add_argument('--batch_size', type=int, default=8, 109 | help='input batch size for training') 110 | parser.add_argument('--split_batch', type=int, default=1, 111 | help='split the batch into smaller chunks') 112 | parser.add_argument('--self_ensemble', action='store_true', 113 | help='use self-ensemble method for test') 114 | parser.add_argument('--test_only', action='store_true', 115 | help='set this option to test the model') 116 | parser.add_argument('--gan_k', type=int, default=1, 117 | help='k value for adversarial loss') 118 | 119 | # Optimization specifications 120 | 121 | parser.add_argument('--decay', type=str, default='150-250', 122 | help='learning rate decay type') 123 | parser.add_argument('--gamma', type=float, default=0.5, 124 | help='learning rate decay factor for step decay') 125 | parser.add_argument('--optimizer', default='ADAM', 126 | choices=('SGD', 'ADAM', 'RMSprop'), 127 | help='optimizer to use (SGD | ADAM | RMSprop)') 128 | parser.add_argument('--momentum', type=float, default=0.9, 129 | help='SGD momentum') 130 |
parser.add_argument('--betas', type=tuple, default=(0.9, 0.999), 131 | help='ADAM beta') 132 | parser.add_argument('--epsilon', type=float, default=1e-8, 133 | help='ADAM epsilon for numerical stability') 134 | parser.add_argument('--weight_decay', type=float, default=0, 135 | help='weight decay') 136 | parser.add_argument('--gclip', type=float, default=0, 137 | help='gradient clipping threshold (0 = no clipping)') 138 | 139 | # Loss specifications 140 | parser.add_argument('--loss', type=str, default='1*L1', 141 | help='loss function configuration') 142 | parser.add_argument('--skip_threshold', type=float, default='1e8', 143 | help='skipping batch that has large error') 144 | 145 | # Log specifications 146 | parser.add_argument('--save', type=str, default='test', 147 | help='file name to save') 148 | parser.add_argument('--load', type=str, default='', 149 | help='file name to load') 150 | parser.add_argument('--resume', type=int, default=0, 151 | help='resume from specific checkpoint') 152 | parser.add_argument('--save_models', action='store_true', 153 | help='save all intermediate models') 154 | parser.add_argument('--print_every', type=int, default=1, 155 | help='how many batches to wait before logging training status') 156 | parser.add_argument('--save_results', action='store_true', 157 | help='save output results') 158 | parser.add_argument('--save_gt', action='store_true', 159 | help='save low-resolution and high-resolution images together') 160 | 161 | args = parser.parse_args() 162 | 163 | args.scale = list(map(lambda x: int(x), args.scale.split('+'))) 164 | 165 | if args.epochs == 0: 166 | args.epochs = 1e8 167 | 168 | for arg in vars(args): 169 | if vars(args)[arg] == 'True': 170 | vars(args)[arg] = True 171 | elif vars(args)[arg] == 'False': 172 | vars(args)[arg] = False 173 | 174 | -------------------------------------------------------------------------------- /pwcnet/README.md: -------------------------------------------------------------------------------- 1 | # pytorch-pwc 2 | This is a personal reimplementation of PWC-Net [1] using PyTorch. Should you be making use of this work, please cite the paper accordingly. Also, make sure to adhere to the licensing terms of the authors. Should you be making use of this particular implementation, please acknowledge it appropriately [2]. 3 | 4 | Paper 5 | 6 | For the original version of this work, please see: https://github.com/NVlabs/PWC-Net 7 |
8 | Another optical flow implementation from me: https://github.com/sniklaus/pytorch-liteflownet 9 |
10 | And another optical flow implementation from me: https://github.com/sniklaus/pytorch-unflow 11 |
12 | Yet another optical flow implementation from me: https://github.com/sniklaus/pytorch-spynet 13 | 14 | ## background 15 | The authors of PWC-Net are thankfully already providing a reference implementation in PyTorch. However, its initial version did not reach the performance of the original Caffe version. This is why I created this repository, in which I replicated the performance of the official Caffe version by utilizing its weights. 16 | 17 | The official PyTorch implementation has adopted my approach of using the Caffe weights since then, which is why they are all performing equally well now. Many people have reported issues with CUDA when trying to get the official PyTorch version to run though, while my reimplementation does not seem to be subject to such problems. 18 | 19 | ## setup 20 | To download the pre-trained models, run `bash download.bash`. These originate from the original authors; I just converted them to PyTorch. 21 | 22 | The correlation layer is implemented in CUDA using CuPy, which is why CuPy is a required dependency. It can be installed using `pip install cupy` or alternatively using one of the provided binary packages as outlined in the CuPy repository. 23 | 24 | ## usage 25 | To run it on your own pair of images, use the following command. You can choose between two models; please see their paper / the code for more details. 26 | 27 | ``` 28 | python run.py --model default --first ./images/first.png --second ./images/second.png --out ./out.flo 29 | ``` 30 | 31 | I am afraid that I cannot guarantee that this reimplementation is correct. However, it produced results identical to the Caffe implementation of the original authors in the examples that I tried. Please feel free to contribute to this repository by submitting issues and pull requests. 32 | 33 | ## comparison 34 |

![Comparison](comparison/comparison.gif)

35 | 36 | ## license 37 | As stated in the licensing terms of the authors of the paper, the models are free for non-commercial share-alike purpose. Please make sure to further consult their licensing terms. 38 | 39 | ## references 40 | ``` 41 | [1] @inproceedings{Sun_CVPR_2018, 42 | author = {Deqing Sun and Xiaodong Yang and Ming-Yu Liu and Jan Kautz}, 43 | title = {{PWC-Net}: {CNNs} for Optical Flow Using Pyramid, Warping, and Cost Volume}, 44 | booktitle = {IEEE Conference on Computer Vision and Pattern Recognition}, 45 | year = {2018} 46 | } 47 | ``` 48 | 49 | ``` 50 | [2] @misc{pytorch-pwc, 51 | author = {Simon Niklaus}, 52 | title = {A Reimplementation of {PWC-Net} Using {PyTorch}}, 53 | year = {2018}, 54 | howpublished = {\url{https://github.com/sniklaus/pytorch-pwc}} 55 | } 56 | ``` -------------------------------------------------------------------------------- /pwcnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/pwcnet/__init__.py -------------------------------------------------------------------------------- /pwcnet/comparison/comparison.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/pwcnet/comparison/comparison.gif -------------------------------------------------------------------------------- /pwcnet/comparison/comparison.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import math 4 | import moviepy 5 | import moviepy.editor 6 | import numpy 7 | import PIL 8 | import PIL.Image 9 | import PIL.ImageFont 10 | import PIL.ImageDraw 11 | 12 | intX = 32 13 | intY = 436 - 64 14 | 15 | objImages = [ { 16 | 'strFile': 'official - caffe.png', 17 | 'strText': 'official - Caffe' 18 | }, { 19 | 'strFile': 'this - pytorch.png', 20 | 'strText': 'this - PyTorch' 21 | } ] 22 | 23 | npyImages = [] 24 | 25 | for objImage in objImages: 26 | objOutput = PIL.Image.open(objImage['strFile']).convert('RGB') 27 | 28 | for intU in [ intShift - 10 for intShift in range(20) ]: 29 | for intV in [ intShift - 10 for intShift in range(20) ]: 30 | if math.sqrt(math.pow(intU, 2.0) + math.pow(intV, 2.0)) <= 5.0: 31 | PIL.ImageDraw.Draw(objOutput).text((intX + intU, intY + intV), objImage['strText'], (255, 255, 255), PIL.ImageFont.truetype('freefont/FreeSerifBold.ttf', 32)) 32 | # end 33 | # end 34 | # end 35 | 36 | PIL.ImageDraw.Draw(objOutput).text((intX, intY), objImage['strText'], (0, 0, 0), PIL.ImageFont.truetype('freefont/FreeSerifBold.ttf', 32)) 37 | 38 | npyImages.append(numpy.array(objOutput)) 39 | # end 40 | 41 | moviepy.editor.ImageSequenceClip(sequence=npyImages, fps=1).write_gif(filename='comparison.gif', program='ImageMagick', opt='optimizeplus') -------------------------------------------------------------------------------- /pwcnet/comparison/official - caffe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/pwcnet/comparison/official - caffe.png -------------------------------------------------------------------------------- /pwcnet/comparison/this - pytorch.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/pwcnet/comparison/this - pytorch.png -------------------------------------------------------------------------------- /pwcnet/correlation/README.md: -------------------------------------------------------------------------------- 1 | This is an adaptation of the FlowNet2 implementation in order to compute cost volumes. Should you be making use of this work, please make sure to adhere to the licensing terms of the original authors. Should you be making use or modify this particular implementation, please acknowledge it appropriately. -------------------------------------------------------------------------------- /pwcnet/correlation/__pycache__/correlation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/pwcnet/correlation/__pycache__/correlation.cpython-37.pyc -------------------------------------------------------------------------------- /pwcnet/download.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget --verbose --continue --timestamping http://content.sniklaus.com/github/pytorch-pwc/network-chairs-things.pytorch 4 | wget --verbose --continue --timestamping http://content.sniklaus.com/github/pytorch-pwc/network-default.pytorch -------------------------------------------------------------------------------- /pwcnet/images/README.md: -------------------------------------------------------------------------------- 1 | The used example originates from the MPI Sintel dataset: http://sintel.is.tue.mpg.de/ -------------------------------------------------------------------------------- /pwcnet/images/first.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/pwcnet/images/first.png -------------------------------------------------------------------------------- /pwcnet/images/second.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/pwcnet/images/second.png -------------------------------------------------------------------------------- /pwcnet/out.flo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/pwcnet/out.flo -------------------------------------------------------------------------------- /pwcnet/requirements.txt: -------------------------------------------------------------------------------- 1 | cupy>=5.0.0 2 | numpy>=1.15.0 3 | Pillow>=5.0.0 4 | torch>=1.3.0 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | imageio 3 | opencv-python 4 | tensorboardX 5 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/cal_mean_std.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | 5 | from datasets.burstsr_dataset import BurstSRDataset, flatten_raw_image 6 | from datasets.synthetic_burst_train_set import SyntheticBurst 7 | from datasets.zurich_raw2rgb_dataset import ZurichRAW2RGB 8 | 9 | def main(): 10 | train_zurich_raw2rgb = ZurichRAW2RGB(root='/data/dataset/ntire21/burstsr/synthetic', split='train') 11 | train_data = SyntheticBurst(train_zurich_raw2rgb, burst_size=14, crop_sz=384) 12 | means = [] 13 | stds = [] 14 | 15 | for data in tqdm(train_data): 16 | print(data.shape) 17 | break 18 | 19 | 20 | if __name__ == '__main__': 21 | # if not args.cpu: torch.cuda.set_device(0) 22 | main() 23 | -------------------------------------------------------------------------------- /scripts/demo.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | rlaunch --cpu=4 --gpu=1 --memory=10240 -- python ./scripts/evaluate_burstsr_val.py 3 | -------------------------------------------------------------------------------- /scripts/download_burstsr_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib.request 3 | import zipfile 4 | import shutil 5 | import argparse 6 | 7 | 8 | def download_burstsr_dataset(download_path): 9 | out_dir = download_path + '/burstsr_dataset' 10 | 11 | # Download train folders 12 | for i in range(9): 13 | if not os.path.isfile('{}/train_{:02d}.zip'.format(out_dir, i)): 14 | print('Downloading train_{:02d}'.format(i)) 15 | 16 | urllib.request.urlretrieve('https://data.vision.ee.ethz.ch/bhatg/BurstSRChallenge/train_{:02d}.zip'.format(i), 17 | '{}/tmp.zip'.format(out_dir)) 18 | 19 | os.rename('{}/tmp.zip'.format(out_dir), '{}/train_{:02d}.zip'.format(out_dir, i)) 20 | 21 | # Download val folder 22 | if not os.path.isfile('{}/val.zip'.format(out_dir)): 23 | print('Downloading val') 24 | 25 | urllib.request.urlretrieve('https://data.vision.ee.ethz.ch/bhatg/BurstSRChallenge/val.zip', 26 | '{}/tmp.zip'.format(out_dir)) 27 | 28 | os.rename('{}/tmp.zip'.format(out_dir), '{}/val.zip'.format(out_dir)) 29 | 30 | # Unpack train set 31 | for i in range(9): 32 | print('Unpacking train_{:02d}'.format(i)) 33 | with zipfile.ZipFile('{}/train_{:02d}.zip'.format(out_dir, i), 'r') as zip_ref: 34 | zip_ref.extractall('{}'.format(out_dir)) 35 | 36 | # Move files to a common directory 37 | os.makedirs('{}/train'.format(out_dir), exist_ok=True) 38 | 39 | for i in range(9): 40 | file_list = os.listdir('{}/train_{:02d}'.format(out_dir, i)) 41 | 42 | for b in file_list: 43 | source_dir = '{}/train_{:02d}/{}'.format(out_dir, i, b) 44 | dst_dir = '{}/train/{}'.format(out_dir, b) 45 | 46 | if os.path.isdir(source_dir): 47 | shutil.move(source_dir, dst_dir) 48 | 49 | # Delete individual subsets 50 | for i in range(9): 51 | shutil.rmtree('{}/train_{:02d}'.format(out_dir, i)) 52 | 53 | # Unpack val set 54 | print('Unpacking val') 55 | with zipfile.ZipFile('{}/val.zip'.format(out_dir), 'r') as zip_ref: 56 | zip_ref.extractall('{}'.format(out_dir)) 57 | 58 | 59 | def main(): 60 | parser = argparse.ArgumentParser(description='Downloads and unpacks BurstSR dataset') 61 | parser.add_argument('path', type=str, help='Path where the dataset will be downloaded') 62 | 63 | args = parser.parse_args() 64 | 65 | download_burstsr_dataset(args.path) 66 | 67 | 68 | if __name__ == '__main__': 69 | main() 70 | 71 | 72 | 
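As written, `scripts/cal_mean_std.py` above never fills its `means` and `stds` lists: the loop prints the first sample's shape and immediately breaks (leftover debugging), and since `SyntheticBurst` yields tuples, `data.shape` would fail anyway. A possible completion is sketched below; it assumes each sample is a `(burst, frame_gt, flow_vectors, meta_info)` tuple with `burst` shaped `(burst_size, 4, H, W)`, as in `scripts/test_synthetic_bursts.py`:

```python
# Hypothetical completion of the loop in scripts/cal_mean_std.py.
means, stds = [], []
for burst, frame_gt, flow_vectors, meta_info in tqdm(train_data):
    means.append(burst.mean(dim=(0, 2, 3)))  # per-channel mean over frames and pixels
    stds.append(burst.std(dim=(0, 2, 3)))    # per-channel std over the same axes
mean = torch.stack(means).mean(dim=0)        # average the per-sample statistics
std = torch.stack(stds).mean(dim=0)
print('mean: {}, std: {}'.format(mean.tolist(), std.tolist()))
```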
-------------------------------------------------------------------------------- /scripts/evaluate.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | rlaunch --cpu=4 --gpu=1 --memory=10240 -- python scripts/evaluate_burstsr_val.py 3 | -------------------------------------------------------------------------------- /scripts/evaluate_burstsr_val.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from datasets.burstsr_dataset import BurstSRDataset 3 | from utils.metrics import AlignedPSNR 4 | from pwcnet.pwcnet import PWCNet 5 | 6 | root = '/data/dataset/ntire21/burstsr/real/NTIRE/burstsr_dataset' 7 | 8 | class SimpleBaseline: 9 | def __init__(self): 10 | pass 11 | 12 | def __call__(self, burst): 13 | burst_rgb = burst[:, 0, [0, 1, 3]] 14 | burst_rgb = burst_rgb.view(-1, *burst_rgb.shape[-3:]) 15 | burst_rgb = F.interpolate(burst_rgb, scale_factor=8, mode='bilinear') 16 | return burst_rgb 17 | 18 | 19 | def main(): 20 | # Load dataset 21 | dataset = BurstSRDataset(root=root, 22 | split='val', burst_size=14, crop_sz=80, random_flip=False) 23 | 24 | # TODO Set your network here 25 | net = SimpleBaseline() 26 | 27 | device = 'cuda' 28 | 29 | # Load alignment network, used in AlignedPSNR 30 | alignment_net = PWCNet(load_pretrained=True, 31 | weights_path='PATH_TO_PWCNET_WEIGHTS') 32 | alignment_net = alignment_net.to(device) 33 | aligned_psnr_fn = AlignedPSNR(alignment_net=alignment_net, boundary_ignore=40) 34 | 35 | scores_all = [] 36 | for idx in range(len(dataset)): 37 | burst, frame_gt, meta_info_burst, meta_info_gt = dataset[idx] 38 | burst = burst.unsqueeze(0).to(device) 39 | frame_gt = frame_gt.unsqueeze(0).to(device) 40 | 41 | net_pred = net(burst) 42 | 43 | # Calculate Aligned PSNR 44 | score = aligned_psnr_fn(net_pred, frame_gt, burst) 45 | 46 | scores_all.append(score) 47 | 48 | mean_psnr = sum(scores_all) / len(scores_all) 49 | 50 | print('Mean PSNR is {:0.3f}'.format(mean_psnr.item())) 51 | 52 | 53 | if __name__ == '__main__': 54 | main() 55 | -------------------------------------------------------------------------------- /scripts/save_results_synburst_val.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import cv2 3 | from datasets.synthetic_burst_val_set import SyntheticBurstVal 4 | import torch 5 | import numpy as np 6 | import os 7 | 8 | 9 | class SimpleBaseline: 10 | def __init__(self): 11 | pass 12 | 13 | def __call__(self, burst): 14 | burst_rgb = burst[:, 0, [0, 1, 3]] 15 | burst_rgb = burst_rgb.view(-1, *burst_rgb.shape[-3:]) 16 | burst_rgb = F.interpolate(burst_rgb, scale_factor=8, mode='bilinear') 17 | return burst_rgb 18 | 19 | 20 | def main(): 21 | dataset = SyntheticBurstVal('PATH_TO_SyntheticBurstVal') 22 | out_dir = 'PATH_WHERE_RESULTS_ARE_SAVED' 23 | 24 | # TODO Set your network here 25 | net = SimpleBaseline() 26 | 27 | device = 'cuda' 28 | os.makedirs(out_dir, exist_ok=True) 29 | 30 | for idx in range(len(dataset)): 31 | burst, burst_name = dataset[idx] 32 | 33 | burst = burst.to(device).unsqueeze(0) 34 | 35 | with torch.no_grad(): 36 | net_pred = net(burst) 37 | 38 | # Normalize to 0 2^14 range and convert to numpy array 39 | net_pred_np = (net_pred.squeeze(0).permute(1, 2, 0).clamp(0.0, 1.0) * 2 ** 14).cpu().numpy().astype(np.uint16) 40 | 41 | # Save predictions as png 42 | cv2.imwrite('{}/{}.png'.format(out_dir, burst_name), net_pred_np) 43 | 44 | 45 | if __name__ 
== '__main__': 46 | main() 47 | -------------------------------------------------------------------------------- /scripts/test_burstsr_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import cv2 3 | from datasets.burstsr_dataset import BurstSRDataset 4 | from torch.utils.data.dataloader import DataLoader 5 | from utils.metrics import AlignedPSNR 6 | from utils.postprocessing_functions import BurstSRPostProcess 7 | from utils.data_format_utils import convert_dict 8 | from pwcnet.pwcnet import PWCNet 9 | 10 | 11 | def main(): 12 | # Load dataset 13 | dataset = BurstSRDataset(root='PATH_TO_BURST_SR', 14 | split='val', burst_size=3, crop_sz=56, random_flip=False) 15 | 16 | data_loader = DataLoader(dataset, batch_size=2) 17 | 18 | # Load alignment network, used in AlignedPSNR 19 | alignment_net = PWCNet(load_pretrained=True, 20 | weights_path='PATH_TO_PWCNET_WEIGHTS') 21 | alignment_net = alignment_net.to('cuda') 22 | 23 | aligned_psnr_fn = AlignedPSNR(alignment_net=alignment_net, boundary_ignore=40) 24 | 25 | # Postprocessing function to obtain sRGB images 26 | postprocess_fn = BurstSRPostProcess(return_np=True) 27 | 28 | for d in data_loader: 29 | burst, frame_gt, meta_info_burst, meta_info_gt = d 30 | 31 | # A simple baseline which upsamples the base image using bilinear upsampling 32 | burst_rgb = burst[:, 0, [0, 1, 3]] 33 | burst_rgb = burst_rgb.view(-1, *burst_rgb.shape[-3:]) 34 | burst_rgb = F.interpolate(burst_rgb, scale_factor=8, mode='bilinear') 35 | 36 | # Calculate Aligned PSNR 37 | score = aligned_psnr_fn(burst_rgb.cuda(), frame_gt.cuda(), burst.cuda()) 38 | print('PSNR is {:0.3f}'.format(score)) 39 | 40 | meta_info_gt = convert_dict(meta_info_gt, burst.shape[0]) 41 | 42 | # Apply simple post-processing to obtain RGB images 43 | pred_0 = postprocess_fn.process(burst_rgb[0], meta_info_gt[0]) 44 | gt_0 = postprocess_fn.process(frame_gt[0], meta_info_gt[0]) 45 | 46 | pred_0 = cv2.cvtColor(pred_0, cv2.COLOR_RGB2BGR) 47 | gt_0 = cv2.cvtColor(gt_0, cv2.COLOR_RGB2BGR) 48 | 49 | # Visualize input, ground truth 50 | cv2.imshow('Input (Demosaicked + Upsampled)', pred_0) 51 | cv2.imshow('GT', gt_0) 52 | 53 | input_key = cv2.waitKey(0) 54 | if input_key == ord('q'): 55 | return 56 | 57 | 58 | if __name__ == '__main__': 59 | main() 60 | -------------------------------------------------------------------------------- /scripts/test_synthetic_bursts.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import cv2 3 | from datasets.synthetic_burst_train_set import SyntheticBurst 4 | from torch.utils.data.dataloader import DataLoader 5 | from utils.metrics import PSNR 6 | from utils.postprocessing_functions import SimplePostProcess 7 | from utils.data_format_utils import convert_dict 8 | from datasets.zurich_raw2rgb_dataset import ZurichRAW2RGB 9 | 10 | 11 | def main(): 12 | zurich_raw2rgb = ZurichRAW2RGB(root='PATH_TO_ZURICH_RAW_TO_RGB', split='test') 13 | dataset = SyntheticBurst(zurich_raw2rgb, burst_size=3, crop_sz=256) 14 | 15 | data_loader = DataLoader(dataset, batch_size=2) 16 | 17 | # Function to calculate PSNR.
Note that the boundary pixels (40 pixels) will be ignored during PSNR computation 18 | psnr_fn = PSNR(boundary_ignore=40) 19 | 20 | # Postprocessing function to obtain sRGB images 21 | postprocess_fn = SimplePostProcess(return_np=True) 22 | 23 | for d in data_loader: 24 | burst, frame_gt, flow_vectors, meta_info = d 25 | 26 | # A simple baseline which upsamples the base image using bilinear upsampling 27 | burst_rgb = burst[:, 0, [0, 1, 3]] 28 | burst_rgb = burst_rgb.view(-1, *burst_rgb.shape[-3:]) 29 | burst_rgb = F.interpolate(burst_rgb, scale_factor=8, mode='bilinear') 30 | 31 | # Calculate PSNR 32 | score = psnr_fn(burst_rgb, frame_gt) 33 | 34 | print('PSNR is {:0.3f}'.format(score)) 35 | 36 | meta_info = convert_dict(meta_info, burst.shape[0]) 37 | 38 | # Apply simple post-processing to obtain RGB images 39 | pred_0 = postprocess_fn.process(burst_rgb[0], meta_info[0]) 40 | gt_0 = postprocess_fn.process(frame_gt[0], meta_info[0]) 41 | 42 | pred_0 = cv2.cvtColor(pred_0, cv2.COLOR_RGB2BGR) 43 | gt_0 = cv2.cvtColor(gt_0, cv2.COLOR_RGB2BGR) 44 | 45 | # Visualize input, ground truth 46 | cv2.imshow('Input (Demosaicked + Upsampled)', pred_0) 47 | cv2.imshow('GT', gt_0) 48 | 49 | input_key = cv2.waitKey(0) 50 | if input_key == ord('q'): 51 | return 52 | 53 | 54 | if __name__ == '__main__': 55 | main() 56 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | 2 | import cv2 3 | import torch 4 | import numpy as np 5 | import os 6 | from tqdm import tqdm 7 | import random 8 | import utility 9 | from option import args 10 | 11 | from datasets.synthetic_burst_val_set import SyntheticBurstVal 12 | from datasets.burstsr_dataset import flatten_raw_image_batch 13 | import model 14 | 15 | import torch.multiprocessing as mp 16 | import torch.backends.cudnn as cudnn 17 | import torch.utils.data.distributed 18 | import time 19 | 20 | 21 | checkpoint = utility.checkpoint(args) 22 | 23 | def sample_images(burst_size=14): 24 | _burst_size = 14 25 | 26 | ids = random.sample(range(1, _burst_size), k=burst_size - 1) 27 | ids = [0, ] + ids 28 | return ids 29 | 30 | 31 | def ttaup(burst): 32 | burst0 = burst.clone() 33 | burst0 = flatten_raw_image_batch(burst0.unsqueeze(0)).cuda() 34 | 35 | burst3 = burst0.clone().permute(0, 1, 2, 4, 3).cuda() 36 | 37 | ids = sample_images(burst.shape[0]) 38 | burst4 = burst0[:, ids].clone() 39 | 40 | return burst0, burst3, burst4 41 | 42 | 43 | def ttadown(bursts): 44 | burst0 = bursts[0] 45 | 46 | burst3 = bursts[1].permute(0, 1, 3, 2) 47 | burst4 = bursts[2] 48 | 49 | out = (burst0 + burst3 + burst4) / 3 50 | return out 51 | 52 | 53 | def main(): 54 | mp.spawn(main_worker, nprocs=1, args=(1, args)) 55 | 56 | 57 | def main_worker(local_rank, nprocs, args): 58 | 59 | cudnn.benchmark = True 60 | args.local_rank = local_rank 61 | utility.setup(local_rank, nprocs) 62 | torch.cuda.set_device(local_rank) 63 | 64 | 65 | dataset = SyntheticBurstVal(args.root) 66 | out_dir = 'val' 67 | 68 | _model = model.Model(args, checkpoint) 69 | 70 | os.makedirs(out_dir, exist_ok=True) 71 | 72 | tt = [] 73 | for idx in tqdm(range(len(dataset))): 74 | burst, burst_name = dataset[idx] 75 | bursts = ttaup(burst) 76 | 77 | srs = [] 78 | with torch.no_grad(): 79 | for x in bursts: 80 | tic = time.time() 81 | sr = _model(x, 0) 82 | toc = time.time() 83 | tt.append(toc-tic) 84 | srs.append(sr) 85 | 86 | sr = ttadown(srs) 87 | # Normalize to 0 2^14 range and convert to numpy array 88 |
net_pred_np = (sr.squeeze(0).permute(1, 2, 0).clamp(0.0, 1.0) * 2 ** 14).cpu().numpy().astype(np.uint16) 89 | cv2.imwrite('{}/{}.png'.format(out_dir, burst_name), net_pred_np) 90 | 91 | print('avg time: {:.4f}'.format(np.mean(tt))) 92 | utility.cleanup() 93 | 94 | 95 | if __name__ == '__main__': 96 | main() 97 | -------------------------------------------------------------------------------- /test_real.py: -------------------------------------------------------------------------------- 1 | 2 | import cv2 3 | import torch 4 | import numpy as np 5 | import os 6 | from tqdm import tqdm 7 | import random 8 | import utility 9 | from option import args 10 | import torchvision.utils as tvutils 11 | from pwcnet.pwcnet import PWCNet 12 | 13 | from utils.postprocessing_functions import BurstSRPostProcess 14 | from datasets.burstsr_dataset import BurstSRDataset, flatten_raw_image_batch, pack_raw_image 15 | from utils.metrics import AlignedPSNR 16 | from utils.data_format_utils import convert_dict 17 | from data_processing.camera_pipeline import demosaic 18 | import model 19 | 20 | import torch.multiprocessing as mp 21 | import torch.backends.cudnn as cudnn 22 | import torch.utils.data.distributed 23 | import time 24 | 25 | from torchsummaryX import summary 26 | 27 | 28 | checkpoint = utility.checkpoint(args) 29 | 30 | 31 | def main(): 32 | mp.spawn(main_worker, nprocs=1, args=(1, args)) 33 | 34 | 35 | def main_worker(local_rank, nprocs, args): 36 | cudnn.benchmark = True 37 | args.local_rank = local_rank 38 | utility.setup(local_rank, nprocs) 39 | torch.cuda.set_device(local_rank) 40 | 41 | dataset = BurstSRDataset(root=args.root, burst_size=14, crop_sz=80, split='val') 42 | out_dir = 'val/ebsr_real' 43 | 44 | _model = model.Model(args, checkpoint) 45 | 46 | for param in _model.parameters(): 47 | param.requires_grad = False 48 | 49 | alignment_net = PWCNet(load_pretrained=True, 50 | weights_path='./pwcnet/pwcnet-network-default.pth') 51 | alignment_net = alignment_net.to('cuda') 52 | for param in alignment_net.parameters(): 53 | param.requires_grad = False 54 | 55 | aligned_psnr_fn = AlignedPSNR(alignment_net=alignment_net, boundary_ignore=40) 56 | 57 | postprocess_fn = BurstSRPostProcess(return_np=True) 58 | 59 | os.makedirs(out_dir, exist_ok=True) 60 | 61 | tt = [] 62 | psnrs, ssims, lpipss = [], [], [] 63 | for idx in tqdm(range(len(dataset))): 64 | burst_, gt, meta_info_burst, meta_info_gt = dataset[idx] 65 | burst_ = burst_.unsqueeze(0).cuda() 66 | gt = gt.unsqueeze(0).cuda() 67 | burst = flatten_raw_image_batch(burst_) 68 | 69 | with torch.no_grad(): 70 | tic = time.time() 71 | sr = _model(burst, 0) 72 | toc = time.time() 73 | tt.append(toc-tic) 74 | 75 | sr_int = (sr.clamp(0.0, 1.0) * 2 ** 14).short() 76 | sr = sr_int.float() / (2 ** 14) 77 | 78 | psnr, ssim, lpips = aligned_psnr_fn(sr, gt, burst_) 79 | psnrs.append(psnr.item()) 80 | ssims.append(ssim.item()) 81 | lpipss.append(lpips.item()) 82 | 83 | os.makedirs(f'{out_dir}/{idx}', exist_ok=True) 84 | sr_ = postprocess_fn.process(sr[0], meta_info_burst) 85 | sr_ = cv2.cvtColor(sr_, cv2.COLOR_RGB2BGR) 86 | cv2.imwrite('{}/{}_sr.png'.format(out_dir, idx), sr_) 87 | 88 | del burst 89 | del sr 90 | del gt 91 | 92 | 93 | print(f'avg PSNR: {np.mean(psnrs):.6f}') 94 | print(f'avg SSIM: {np.mean(ssims):.6f}') 95 | print(f'avg LPIPS: {np.mean(lpipss):.6f}') 96 | print(f' avg time: {np.mean(tt):.6f}') 97 | 98 | # utility.cleanup() 99 | 100 | 101 | if __name__ == '__main__': 102 | main() 103 | 
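Both `test.py` and `test_real.py` quantize predictions to 14-bit precision: outputs in `[0, 1]` are scaled by `2**14` and either written as 16-bit PNGs (`test.py`) or truncated and rescaled before scoring (`test_real.py`), so the saved files and the reported metrics see the same values. A minimal sketch of that round trip (the `pred.png` path is just for illustration):

```python
import cv2
import numpy as np
import torch

sr = torch.rand(3, 64, 64)  # stand-in for a network output in [0, 1]
# Scale to the 14-bit range and store in a 16-bit PNG, as in test.py.
sr_np = (sr.permute(1, 2, 0).clamp(0.0, 1.0) * 2 ** 14).numpy().astype(np.uint16)
cv2.imwrite('pred.png', sr_np)

# Reading back and dividing by 2**14 recovers the quantized prediction exactly.
restored = cv2.imread('pred.png', cv2.IMREAD_UNCHANGED).astype(np.float32) / 2 ** 14
print(np.abs(restored - sr_np.astype(np.float32) / 2 ** 14).max())  # 0.0
```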
-------------------------------------------------------------------------------- /utility.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import datetime 4 | from multiprocessing import Process 5 | from multiprocessing import Queue 6 | 7 | import matplotlib.pyplot as plt 8 | 9 | import numpy as np 10 | import imageio 11 | import os 12 | import sys 13 | 14 | import torch 15 | import torch.optim as optim 16 | import torch.optim.lr_scheduler as lrs 17 | 18 | import torch.distributed as dist 19 | import matplotlib 20 | 21 | matplotlib.use('Agg') 22 | 23 | 24 | def reduce_mean(tensor, nprocs): 25 | rt = tensor.clone() 26 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 27 | rt /= nprocs 28 | return rt 29 | 30 | 31 | def setup(rank, world_size): 32 | if sys.platform == 'win32': 33 | # Distributed package only covers collective communications with Gloo 34 | # backend and FileStore on Windows platform. Set init_method parameter 35 | # in init_process_group to a local file. 36 | # Example init_method="file:///f:/libtmp/some_file" 37 | init_method = "tcp://localhost:1234" 38 | 39 | # initialize the process group 40 | dist.init_process_group( 41 | "gloo", 42 | init_method=init_method, 43 | rank=rank, 44 | world_size=world_size 45 | ) 46 | else: 47 | os.environ['MASTER_ADDR'] = 'localhost' 48 | os.environ['MASTER_PORT'] = '12355' 49 | 50 | # initialize the process group 51 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 52 | 53 | 54 | def cleanup(): 55 | dist.destroy_process_group() 56 | 57 | 58 | def mkdir(path): 59 | if not os.path.exists(path): 60 | os.makedirs(path) 61 | 62 | 63 | class timer(): 64 | def __init__(self): 65 | self.acc = 0 66 | self.tic() 67 | 68 | def tic(self): 69 | self.t0 = time.time() 70 | 71 | def toc(self, restart=False): 72 | diff = time.time() - self.t0 73 | if restart: self.t0 = time.time() 74 | return diff 75 | 76 | def hold(self): 77 | self.acc += self.toc() 78 | 79 | def release(self): 80 | ret = self.acc 81 | self.acc = 0 82 | 83 | return ret 84 | 85 | def reset(self): 86 | self.acc = 0 87 | 88 | 89 | class checkpoint(): 90 | def __init__(self, args): 91 | self.args = args 92 | self.ok = True 93 | self.log = torch.Tensor() 94 | now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S') 95 | 96 | if not args.load: 97 | if not args.save: 98 | args.save = now 99 | self.dir = os.path.join('..', 'experiment', args.save) 100 | else: 101 | self.dir = os.path.join('..', 'experiment', args.load) 102 | if os.path.exists(self.dir): 103 | self.log = torch.load(self.get_path('psnr_log.pt')) 104 | print('Continue from epoch {}...'.format(len(self.log))) 105 | else: 106 | args.load = '' 107 | 108 | if args.reset: 109 | os.system('rm -rf ' + self.dir) 110 | args.load = '' 111 | 112 | os.makedirs(self.dir, exist_ok=True) 113 | os.makedirs(self.get_path('model'), exist_ok=True) 114 | # for d in args.data_test: 115 | # os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True) 116 | 117 | open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w' 118 | self.log_file = open(self.get_path('log.txt'), open_type) 119 | with open(self.get_path('config.txt'), open_type) as f: 120 | f.write(now + '\n\n') 121 | for arg in vars(args): 122 | f.write('{}: {}\n'.format(arg, getattr(args, arg))) 123 | f.write('\n') 124 | 125 | self.n_processes = 8 126 | 127 | def get_path(self, *subdir): 128 | return os.path.join(self.dir, *subdir) 129 | 130 | def save(self, trainer, epoch, is_best=False): 131 | 
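# Persist everything needed to resume a run: model weights, loss state and plots, the PSNR curve, and the optimizer state.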
trainer.model.save(self.get_path('model'), epoch, is_best=is_best) 132 | trainer.loss.save(self.dir) 133 | trainer.loss.plot_loss(self.dir, epoch) 134 | 135 | self.plot_psnr(epoch) 136 | trainer.optimizer.save(self.dir) 137 | torch.save(self.log, self.get_path('psnr_log.pt')) 138 | 139 | def add_log(self, log): 140 | self.log = torch.cat([self.log, log]) 141 | 142 | def write_log(self, log, refresh=False): 143 | print(log) 144 | self.log_file.write(log + '\n') 145 | if refresh: 146 | self.log_file.close() 147 | self.log_file = open(self.get_path('log.txt'), 'a') 148 | 149 | def done(self): 150 | self.log_file.close() 151 | 152 | def plot_psnr(self, epoch): 153 | axis = np.linspace(1, epoch, epoch) 154 | for idx_data, d in enumerate(self.args.data_test): 155 | label = 'SR on {}'.format(d) 156 | fig = plt.figure() 157 | plt.title(label) 158 | for idx_scale, scale in enumerate(self.args.scale): 159 | plt.plot( 160 | axis, 161 | self.log[:, idx_data, idx_scale].numpy(), 162 | label='Scale {}'.format(scale) 163 | ) 164 | plt.legend() 165 | plt.xlabel('Epochs') 166 | plt.ylabel('PSNR') 167 | plt.grid(True) 168 | plt.savefig(self.get_path('test_{}.pdf'.format(d))) 169 | plt.close(fig) 170 | 171 | def begin_background(self): 172 | self.queue = Queue() 173 | 174 | def bg_target(queue): 175 | while True: 176 | if not queue.empty(): 177 | filename, tensor = queue.get() 178 | if filename is None: break 179 | imageio.imwrite(filename, tensor.numpy()) 180 | 181 | self.process = [ 182 | Process(target=bg_target, args=(self.queue,)) \ 183 | for _ in range(self.n_processes) 184 | ] 185 | 186 | for p in self.process: p.start() 187 | 188 | def end_background(self): 189 | for _ in range(self.n_processes): self.queue.put((None, None)) 190 | while not self.queue.empty(): time.sleep(1) 191 | for p in self.process: p.join() 192 | 193 | def save_results(self, dataset, filename, save_list, scale): 194 | if self.args.save_results: 195 | filename = self.get_path( 196 | 'results-{}'.format(dataset.dataset.name), 197 | '{}_x{}_'.format(filename, scale) 198 | ) 199 | 200 | postfix = ('SR', 'LR', 'HR') 201 | for v, p in zip(save_list, postfix): 202 | normalized = v[0].mul(255 / self.args.rgb_range) 203 | tensor_cpu = normalized.byte().permute(1, 2, 0).cpu() 204 | self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu)) 205 | 206 | 207 | def quantize(img, rgb_range): 208 | pixel_range = 255 / rgb_range 209 | return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range) 210 | 211 | 212 | def calc_psnr(sr, hr, scale, rgb_range, dataset=None): 213 | if hr.nelement() == 1: return 0 214 | 215 | diff = (sr - hr) / rgb_range 216 | if dataset and dataset.dataset.benchmark: 217 | shave = scale 218 | if diff.size(1) > 1: 219 | gray_coeffs = [65.738, 129.057, 25.064] 220 | convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256 221 | diff = diff.mul(convert).sum(dim=1) 222 | else: 223 | shave = scale + 6 224 | 225 | valid = diff[..., shave:-shave, shave:-shave] 226 | mse = valid.pow(2).mean() 227 | 228 | return -10 * math.log10(mse) 229 | 230 | 231 | def make_optimizer(args, target): 232 | ''' 233 | make optimizer and scheduler together 234 | ''' 235 | # optimizer 236 | trainable = filter(lambda x: x.requires_grad, target.parameters()) 237 | kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay} 238 | 239 | if args.optimizer == 'SGD': 240 | optimizer_class = optim.SGD 241 | kwargs_optimizer['momentum'] = args.momentum 242 | elif args.optimizer == 'ADAM': 243 | optimizer_class = optim.Adam 244 | 
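# Adam additionally takes its betas/eps hyper-parameters from the command-line options.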
kwargs_optimizer['betas'] = args.betas 245 | kwargs_optimizer['eps'] = args.epsilon 246 | elif args.optimizer == 'RMSprop': 247 | optimizer_class = optim.RMSprop 248 | kwargs_optimizer['eps'] = args.epsilon 249 | 250 | # scheduler 251 | milestones = list(map(lambda x: int(x), args.decay.split('-'))) 252 | kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma} 253 | scheduler_class = lrs.MultiStepLR 254 | 255 | class CustomOptimizer(optimizer_class): 256 | def __init__(self, *args, **kwargs): 257 | super(CustomOptimizer, self).__init__(*args, **kwargs) 258 | 259 | def _register_scheduler(self, scheduler_class, **kwargs): 260 | self.scheduler = scheduler_class(self, **kwargs) 261 | 262 | def save(self, save_dir): 263 | torch.save(self.state_dict(), self.get_dir(save_dir)) 264 | 265 | def load(self, load_dir, epoch=1): 266 | self.load_state_dict(torch.load(self.get_dir(load_dir))) 267 | if epoch > 1: 268 | for _ in range(epoch): self.scheduler.step() 269 | 270 | def get_dir(self, dir_path): 271 | return os.path.join(dir_path, 'optimizer.pt') 272 | 273 | def schedule(self): 274 | self.scheduler.step() 275 | 276 | def get_lr(self): 277 | return self.scheduler.get_last_lr()[0] 278 | 279 | def get_last_epoch(self): 280 | return self.scheduler.last_epoch 281 | 282 | optimizer = CustomOptimizer(trainable, **kwargs_optimizer) 283 | optimizer._register_scheduler(scheduler_class, **kwargs_scheduler) 284 | return optimizer 285 | 286 | 287 | def write_gray_to_tfboard(img): 288 | img_debug = img[0, ...].detach().cpu().numpy() 289 | 290 | # img_debug = cv2.normalize(img_debug, None, 0, 255, 291 | # cv2.NORM_MINMAX, cv2.CV_8U) 292 | img_debug = img_debug * 255 293 | img_debug = np.clip(img_debug, 0, 255) 294 | img_debug = img_debug.astype(np.uint8) 295 | return img_debug[0, ...] 
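`quantize` and `calc_psnr` above follow the standard SR evaluation recipe: normalize the error by `rgb_range`, optionally reduce RGB to luma with the ITU-R BT.601 coefficients, shave boundary pixels (`scale` on benchmark data, `scale + 6` otherwise), and report `-10 * log10(MSE)`. A self-contained sanity check of the core formula (the `shave` and `rgb_range` values here are arbitrary choices for the demo):

```python
import math
import torch

def psnr(sr, hr, rgb_range=255.0, shave=10):
    # Same arithmetic as utility.calc_psnr, minus the dataset-specific branches.
    diff = (sr - hr) / rgb_range
    valid = diff[..., shave:-shave, shave:-shave]
    mse = valid.pow(2).mean().item()
    return -10 * math.log10(mse)

hr = torch.rand(1, 3, 64, 64) * 255
sr = hr + torch.randn_like(hr)   # roughly unit-variance error
print(psnr(sr, hr))              # about 10 * log10(255**2) ~ 48 dB
```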
296 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/utils/__init__.py -------------------------------------------------------------------------------- /utils/data_format_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 as cv 4 | 5 | 6 | def numpy_to_torch(a: np.ndarray): 7 | return torch.from_numpy(a).float().permute(2, 0, 1) 8 | 9 | 10 | def torch_to_numpy(a: torch.Tensor): 11 | return a.permute(1, 2, 0).cpu().numpy() 12 | 13 | 14 | def torch_to_npimage(a: torch.Tensor, unnormalize=True): 15 | a_np = torch_to_numpy(a) 16 | 17 | if unnormalize: 18 | a_np = a_np * 255 19 | a_np = a_np.astype(np.uint8) 20 | return cv.cvtColor(a_np, cv.COLOR_RGB2BGR) 21 | 22 | 23 | def npimage_to_torch(a, normalize=True, input_bgr=True): 24 | if input_bgr: 25 | a = cv.cvtColor(a, cv.COLOR_BGR2RGB) 26 | a_t = numpy_to_torch(a) 27 | 28 | if normalize: 29 | a_t = a_t / 255.0 30 | 31 | return a_t 32 | 33 | 34 | def convert_dict(base_dict, batch_sz): 35 | out_dict = [] 36 | for b_elem in range(batch_sz): 37 | b_info = {} 38 | for k, v in base_dict.items(): 39 | if isinstance(v, (list, torch.Tensor)): 40 | b_info[k] = v[b_elem] 41 | out_dict.append(b_info) 42 | 43 | return out_dict -------------------------------------------------------------------------------- /utils/debayer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | import torch.nn.functional 4 | 5 | class Debayer3x3(torch.nn.Module): 6 | '''Demosaicing of Bayer images using 3x3 convolutions. 7 | 8 | Requires BG-Bayer color filter array layout. That is, 9 | the image[1,1]='B', image[1,2]='G'. This corresponds 10 | to OpenCV naming conventions. 11 | 12 | Compared to Debayer2x2 this method does not use upsampling. 13 | Instead, we identify five 3x3 interpolation kernels that 14 | are sufficient to reconstruct every color channel at every 15 | pixel location. 16 | 17 | We convolve the image with these 5 kernels using stride=1 18 | and a one pixel replication padding. Finally, we gather 19 | the correct channel values for each pixel location. To do so, 20 | we recognize that the Bayer pattern repeats horizontally and 21 | vertically every 2 pixels. Therefore, we define the correct 22 | index lookups for a 2x2 grid cell and then repeat them to the image 23 | dimensions. 24 | 25 | Note that in every 2x2 grid cell we have red, blue and two greens 26 | (G1,G2). The lookups for the two greens differ.
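For example, at a red pixel the red output comes from the identity kernel, green from the four-neighbour (plus-shaped) average and blue from the diagonal average; at green pixels the missing red and blue come from the horizontal and vertical 0.5-kernels.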
27 | ''' 28 | 29 | def __init__(self): 30 | super(Debayer3x3, self).__init__() 31 | 32 | self.kernels = torch.nn.Parameter( 33 | torch.tensor([ 34 | [0,0,0], 35 | [0,1,0], 36 | [0,0,0], 37 | 38 | [0, 0.25, 0], 39 | [0.25, 0, 0.25], 40 | [0, 0.25, 0], 41 | 42 | [0.25, 0, 0.25], 43 | [0, 0, 0], 44 | [0.25, 0, 0.25], 45 | 46 | [0, 0, 0], 47 | [0.5, 0, 0.5], 48 | [0, 0, 0], 49 | 50 | [0, 0.5, 0], 51 | [0, 0, 0], 52 | [0, 0.5, 0], 53 | ]).view(5,1,3,3), requires_grad=False 54 | ) 55 | 56 | 57 | self.index = torch.nn.Parameter( 58 | torch.tensor([ 59 | # dest channel r 60 | [0, 3], # pixel is R,G1 61 | [4, 2], # pixel is G2,B 62 | # dest channel g 63 | [1, 0], # pixel is R,G1 64 | [0, 1], # pixel is G2,B 65 | # dest channel b 66 | [2, 4], # pixel is R,G1 67 | [3, 0], # pixel is G2,B 68 | ]).view(1,3,2,2), requires_grad=False 69 | ) 70 | 71 | def forward(self, x): 72 | '''Debayer image. 73 | 74 | Parameters 75 | ---------- 76 | x : Bx1xHxW tensor 77 | Images to debayer 78 | 79 | Returns 80 | ------- 81 | rgb : Bx3xHxW tensor 82 | Color images in RGB channel order. 83 | ''' 84 | B,C,H,W = x.shape 85 | 86 | x = torch.nn.functional.pad(x, (1,1,1,1), mode='replicate') 87 | c = torch.nn.functional.conv2d(x, self.kernels, stride=1) 88 | rgb = torch.gather(c, 1, self.index.repeat(B,1,H//2,W//2)) 89 | return rgb 90 | 91 | class Debayer2x2(torch.nn.Module): 92 | '''Demosaicing of Bayer images using 2x2 convolutions. 93 | 94 | Requires BG-Bayer color filter array layout. That is, 95 | the image[1,1]='B', image[1,2]='G'. This corresponds 96 | to OpenCV naming conventions. 97 | ''' 98 | 99 | def __init__(self): 100 | super(Debayer2x2, self).__init__() 101 | 102 | self.kernels = torch.nn.Parameter( 103 | torch.tensor([ 104 | [1, 0], 105 | [0, 0], 106 | 107 | [0, 0.5], 108 | [0.5, 0], 109 | 110 | [0, 0], 111 | [0, 1], 112 | ]).view(3,1,2,2), requires_grad=False 113 | ) 114 | 115 | def forward(self, x): 116 | '''Debayer image. 117 | 118 | Parameters 119 | ---------- 120 | x : Bx1xHxW tensor 121 | Images to debayer 122 | 123 | Returns 124 | ------- 125 | rgb : Bx3xHxW tensor 126 | Color images in RGB channel order. 127 | ''' 128 | 129 | x = torch.nn.functional.conv2d(x, self.kernels, stride=2) 130 | x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) 131 | return x 132 | 133 | class DebayerSplit(torch.nn.Module): 134 | '''Demosaicing of Bayer images using 3x3 green convolution and red,blue upsampling. 135 | 136 | Requires BG-Bayer color filter array layout. That is, 137 | the image[1,1]='B', image[1,2]='G'. This corresponds 138 | to OpenCV naming conventions. 139 | ''' 140 | def __init__(self): 141 | super().__init__() 142 | 143 | self.pad = torch.nn.ReflectionPad2d(1) 144 | self.kernel = torch.nn.Parameter( 145 | torch.tensor([ 146 | [0,1,0], 147 | [1,0,1], 148 | [0,1,0] 149 | ])[None, None] * 0.25) 150 | 151 | def forward(self, x): 152 | '''Debayer image. 153 | 154 | Parameters 155 | ---------- 156 | x : Bx1xHxW tensor 157 | Images to debayer 158 | 159 | Returns 160 | ------- 161 | rgb : Bx3xHxW tensor 162 | Color images in RGB channel order. 
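Notes
-----
Red and blue are taken at their native Bayer sites and bilinearly
upsampled to full resolution, while green is interpolated with the
plus-shaped kernel and then overwritten with the measured values at
the green sites.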
163 | ''' 164 | B,_,H,W = x.shape 165 | red = x[:, :, ::2, ::2] 166 | blue = x[:, :, 1::2, 1::2] 167 | 168 | green = torch.nn.functional.conv2d(self.pad(x), self.kernel) 169 | green[:, :, ::2, 1::2] = x[:, :, ::2, 1::2] 170 | green[:, :, 1::2, ::2] = x[:, :, 1::2, ::2] 171 | 172 | return torch.cat(( 173 | torch.nn.functional.interpolate(red, size=(H, W), mode='bilinear', align_corners=False), 174 | green, 175 | torch.nn.functional.interpolate(blue, size=(H, W), mode='bilinear', align_corners=False)), 176 | dim=1) -------------------------------------------------------------------------------- /utils/interp_methods.py: -------------------------------------------------------------------------------- 1 | from math import pi 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | torch = None 7 | 8 | try: 9 | import numpy 10 | except ImportError: 11 | numpy = None 12 | 13 | if numpy is None and torch is None: 14 | raise ImportError("Must have either Numpy or PyTorch, but neither was found") 15 | 16 | 17 | def set_framework_dependencies(x): 18 | if numpy is not None and type(x) is numpy.ndarray: 19 | to_dtype = lambda a: a 20 | fw = numpy 21 | else: 22 | to_dtype = lambda a: a.to(x.dtype) 23 | fw = torch 24 | eps = fw.finfo(fw.float32).eps 25 | return fw, to_dtype, eps 26 | 27 | 28 | def support_sz(sz): 29 | def wrapper(f): 30 | f.support_sz = sz 31 | return f 32 | return wrapper 33 | 34 | @support_sz(4) 35 | def cubic(x): 36 | fw, to_dtype, eps = set_framework_dependencies(x) 37 | absx = fw.abs(x) 38 | absx2 = absx ** 2 39 | absx3 = absx ** 3 40 | return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) + 41 | (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) * 42 | to_dtype((1. < absx) & (absx <= 2.))) 43 | 44 | @support_sz(4) 45 | def lanczos2(x): 46 | fw, to_dtype, eps = set_framework_dependencies(x) 47 | return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) / 48 | ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2)) 49 | 50 | @support_sz(6) 51 | def lanczos3(x): 52 | fw, to_dtype, eps = set_framework_dependencies(x) 53 | return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) / 54 | ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3)) 55 | 56 | @support_sz(2) 57 | def linear(x): 58 | fw, to_dtype, eps = set_framework_dependencies(x) 59 | return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) + (1 - x) * 60 | to_dtype((0 <= x) & (x <= 1))) 61 | 62 | @support_sz(1) 63 | def box(x): 64 | fw, to_dtype, eps = set_framework_dependencies(x) 65 | return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1)) 66 | -------------------------------------------------------------------------------- /utils/postprocessing_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import utils.data_format_utils as df_utils 4 | from data_processing.camera_pipeline import apply_gains, apply_ccm, apply_smoothstep, gamma_compression 5 | 6 | 7 | class SimplePostProcess: 8 | def __init__(self, gains=True, ccm=True, gamma=True, smoothstep=True, return_np=False): 9 | self.gains = gains 10 | self.ccm = ccm 11 | self.gamma = gamma 12 | self.smoothstep = smoothstep 13 | self.return_np = return_np 14 | 15 | def process(self, image, meta_info): 16 | return process_linear_image_rgb(image, meta_info, self.gains, self.ccm, self.gamma, 17 | self.smoothstep, self.return_np) 18 | 19 | 20 | def process_linear_image_rgb(image, meta_info, gains=True, ccm=True, gamma=True, smoothstep=True, return_np=False): 21 | if gains: 22 | image = apply_gains(image, meta_info['rgb_gain'],
meta_info['red_gain'], meta_info['blue_gain']) 23 | 24 | if ccm: 25 | image = apply_ccm(image, meta_info['cam2rgb']) 26 | 27 | if meta_info['gamma'] and gamma: 28 | image = gamma_compression(image) 29 | 30 | if meta_info['smoothstep'] and smoothstep: 31 | image = apply_smoothstep(image) 32 | 33 | image = image.clamp(0.0, 1.0) 34 | 35 | if return_np: 36 | image = df_utils.torch_to_npimage(image) 37 | return image 38 | 39 | 40 | class BurstSRPostProcess: 41 | def __init__(self, no_white_balance=False, gamma=True, smoothstep=True, return_np=False): 42 | self.no_white_balance = no_white_balance 43 | self.gamma = gamma 44 | self.smoothstep = smoothstep 45 | self.return_np = return_np 46 | 47 | def process(self, image, meta_info, external_norm_factor=None): 48 | return process_burstsr_image_rgb(image, meta_info, external_norm_factor=external_norm_factor, 49 | no_white_balance=self.no_white_balance, gamma=self.gamma, 50 | smoothstep=self.smoothstep, return_np=self.return_np) 51 | 52 | 53 | def process_burstsr_image_rgb(im, meta_info, return_np=False, external_norm_factor=None, gamma=True, smoothstep=True, 54 | no_white_balance=False): 55 | im = im * meta_info.get('norm_factor', 1.0) 56 | 57 | if not meta_info.get('black_level_subtracted', False): 58 | im = (im - torch.tensor(meta_info['black_level'])[[0, 1, -1]].view(3, 1, 1).to(im.device)) 59 | 60 | if not meta_info.get('while_balance_applied', False) and not no_white_balance: 61 | im = im * (meta_info['cam_wb'][[0, 1, -1]].view(3, 1, 1) / meta_info['cam_wb'][1]).to(im.device) 62 | 63 | im_out = im 64 | 65 | if external_norm_factor is None: 66 | im_out = im_out / im_out.max() 67 | else: 68 | im_out = im_out / external_norm_factor 69 | 70 | im_out = im_out.clamp(0.0, 1.0) 71 | 72 | if gamma: 73 | im_out = im_out ** (1.0 / 2.2) 74 | 75 | if smoothstep: 76 | # Smooth curve 77 | im_out = 3 * im_out ** 2 - 2 * im_out ** 3 78 | 79 | if return_np: 80 | im_out = im_out.permute(1, 2, 0).cpu().numpy() * 255.0 81 | im_out = im_out.astype(np.uint8) 82 | 83 | return im_out 84 | -------------------------------------------------------------------------------- /utils/spatial_color_alignment.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def gauss_1d(sz, sigma, center, end_pad=0, density=False): 7 | """ Returns a 1-D Gaussian """ 8 | k = torch.arange(-(sz-1)/2, (sz+1)/2 + end_pad).reshape(1, -1) 9 | gauss = torch.exp(-1.0/(2*sigma**2) * (k - center.reshape(-1, 1))**2) 10 | if density: 11 | gauss /= math.sqrt(2*math.pi) * sigma 12 | return gauss 13 | 14 | 15 | def gauss_2d(sz, sigma, center, end_pad=(0, 0), density=False): 16 | """ Returns a 2-D Gaussian """ 17 | if isinstance(sigma, (float, int)): 18 | sigma = (sigma, sigma) 19 | if isinstance(sz, int): 20 | sz = (sz, sz) 21 | 22 | if isinstance(center, (list, tuple)): 23 | center = torch.tensor(center).view(1, 2) 24 | 25 | return gauss_1d(sz[0], sigma[0], center[:, 0], end_pad[0], density).reshape(center.shape[0], 1, -1) * \ 26 | gauss_1d(sz[1], sigma[1], center[:, 1], end_pad[1], density).reshape(center.shape[0], -1, 1) 27 | 28 | 29 | def get_gaussian_kernel(sd): 30 | """ Returns a Gaussian kernel with standard deviation sd """ 31 | ksz = int(4 * sd + 1) 32 | assert ksz % 2 == 1 33 | K = gauss_2d(ksz, sd, (0.0, 0.0), density=True) 34 | K = K / K.sum() 35 | return K.unsqueeze(0), ksz 36 | 37 | 38 | def apply_kernel(im, ksz, gauss_kernel): 39 | shape = im.shape 40 | im = im.view(-1, 1, 
*im.shape[-2:]) 41 | 42 | pad = [ksz // 2, ksz // 2, ksz // 2, ksz // 2] 43 | im = F.pad(im, pad, mode='reflect') 44 | im_mean = F.conv2d(im, gauss_kernel).view(shape) 45 | return im_mean 46 | 47 | 48 | def match_colors(im_ref, im_q, im_test, ksz, gauss_kernel): 49 | """ Estimates a color transformation matrix between im_ref and im_q. Applies the estimated transformation to 50 | im_test 51 | """ 52 | gauss_kernel = gauss_kernel.to(im_ref.device) 53 | bi = 5 54 | 55 | # Apply Gaussian smoothing 56 | im_ref_mean = apply_kernel(im_ref, ksz, gauss_kernel)[:, :, bi:-bi, bi:-bi].contiguous() 57 | im_q_mean = apply_kernel(im_q, ksz, gauss_kernel)[:, :, bi:-bi, bi:-bi].contiguous() 58 | 59 | im_ref_mean_re = im_ref_mean.view(*im_ref_mean.shape[:2], -1) 60 | im_q_mean_re = im_q_mean.view(*im_q_mean.shape[:2], -1) 61 | 62 | # Estimate color transformation matrix by minimizing the least squares error 63 | c_mat_all = [] 64 | for ir, iq in zip(im_ref_mean_re, im_q_mean_re): 65 | c = torch.lstsq(ir.t(), iq.t()) 66 | c = c.solution[:3] 67 | c_mat_all.append(c) 68 | 69 | c_mat = torch.stack(c_mat_all, dim=0) 70 | im_q_mean_conv = torch.matmul(im_q_mean_re.permute(0, 2, 1), c_mat).permute(0, 2, 1) 71 | im_q_mean_conv = im_q_mean_conv.view(im_q_mean.shape) 72 | 73 | err = ((im_q_mean_conv - im_ref_mean) * 255.0).norm(dim=1) 74 | 75 | thresh = 20 76 | 77 | # If error is larger than a threshold, ignore these pixels 78 | valid = err < thresh 79 | 80 | pad = (im_q.shape[-1] - valid.shape[-1]) // 2 81 | pad = [pad, pad, pad, pad] 82 | valid = F.pad(valid, pad) 83 | 84 | upsample_factor = im_test.shape[-1] / valid.shape[-1] 85 | valid = F.interpolate(valid.unsqueeze(1).float(), scale_factor=upsample_factor, mode='bilinear', align_corners=False) 86 | valid = valid > 0.9 87 | 88 | # Apply the transformation to test image 89 | im_test_re = im_test.view(*im_test.shape[:2], -1) 90 | im_t_conv = torch.matmul(im_test_re.permute(0, 2, 1), c_mat).permute(0, 2, 1) 91 | im_t_conv = im_t_conv.view(im_test.shape) 92 | 93 | return im_t_conv, valid 94 | 95 | -------------------------------------------------------------------------------- /utils/stn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class SpatialTransformer(nn.Module): 7 | """ 8 | [SpatialTransformer] represents a spatial transformation block 9 | that uses the output from the UNet to perform a grid_sample 10 | https://pytorch.org/docs/stable/nn.functional.html#grid-sample 11 | """ 12 | def __init__(self, size, mode='bilinear'): 13 | """ 14 | Instantiate the block 15 | :param size: size of input to the spatial transformer block 16 | :param mode: method of interpolation for grid_sampler 17 | """ 18 | super(SpatialTransformer, self).__init__() 19 | if isinstance(size, int): 20 | size = (size, size) 21 | # Create sampling grid 22 | vectors = [ torch.arange(0, s) for s in size ] 23 | grids = torch.meshgrid(vectors) 24 | grid = torch.stack(grids) # y, x, z 25 | grid = torch.unsqueeze(grid, 0) #add batch 26 | grid = grid.type(torch.FloatTensor) 27 | self.register_buffer('grid', grid) 28 | 29 | self.mode = mode 30 | 31 | def forward(self, src, flow): 32 | """ 33 | Push the src and flow through the spatial transform block 34 | :param src: the original moving image 35 | :param flow: the output from the U-Net 36 | """ 37 | new_locs = self.grid + flow 38 | 39 | shape = flow.shape[2:] 40 | 41 | # Need to normalize grid values to [-1, 1] for
resampler 42 | for i in range(len(shape)): 43 | new_locs[:,i,...] = 2*(new_locs[:,i,...]/(shape[i]-1) - 0.5) 44 | 45 | if len(shape) == 2: 46 | new_locs = new_locs.permute(0, 2, 3, 1) 47 | new_locs = new_locs[..., [1,0]] 48 | elif len(shape) == 3: 49 | new_locs = new_locs.permute(0, 2, 3, 4, 1) 50 | new_locs = new_locs[..., [2,1,0]] 51 | 52 | return F.grid_sample(src, new_locs, mode=self.mode, align_corners=True) 53 | -------------------------------------------------------------------------------- /utils/warp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def warp(feat, flow, mode='bilinear', padding_mode='zeros'): 7 | """ 8 | warp an image/tensor (im2) back to im1, according to the optical flow im1 --> im2 9 | 10 | input flow must be in format (x, y) at every pixel 11 | feat: [B, C, H, W] (im2) 12 | flow: [B, 2, H, W] flow (x, y) 13 | 14 | """ 15 | B, C, H, W = feat.size() 16 | # print(feat.device, flow.device) 17 | 18 | # mesh grid 19 | rowv, colv = torch.meshgrid([torch.arange(0.5, H + 0.5), torch.arange(0.5, W + 0.5)]) 20 | grid = torch.stack((colv, rowv), dim=0).unsqueeze(0).float().to(flow.device) 21 | # print(grid.device, flow.device, feat.device) 22 | # grid = grid.cuda() 23 | grid = grid + flow 24 | 25 | # scale grid to [-1,1] 26 | grid_norm_c = 2.0 * grid[:, 0] / W - 1.0 27 | grid_norm_r = 2.0 * grid[:, 1] / H - 1.0 28 | 29 | grid_norm = torch.stack((grid_norm_c, grid_norm_r), dim=1).to(flow.device) 30 | 31 | grid_norm = grid_norm.permute(0, 2, 3, 1) 32 | 33 | output = F.grid_sample(feat, grid_norm, mode=mode, align_corners=False, padding_mode=padding_mode) 34 | 35 | return output 36 | --------------------------------------------------------------------------------
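`warp` above is the standard backward-warping primitive: a half-pixel-centered base grid is offset by the flow, normalized to the [-1, 1] coordinates that `grid_sample` expects, and used to resample `feat`. A usage sketch, assuming the repo root is on `PYTHONPATH` so `utils.warp` imports as below; with `align_corners=False` and the half-pixel grid, zero flow reproduces the input exactly:

```python
import torch
from utils.warp import warp  # assumes the repo root is on PYTHONPATH

feat = torch.rand(1, 3, 32, 48)    # the image/feature map to warp (im2)
flow = torch.zeros(1, 2, 32, 48)   # per-pixel (x, y) displacements

out = warp(feat, flow)             # zero flow -> identity resampling
print(torch.allclose(out, feat, atol=1e-5))  # True

flow[:, 0] = 2.0                   # sample 2 pixels to the right
shifted = warp(feat, flow)         # shifted[..., x] == feat[..., x + 2] in the interior
print(shifted.shape)               # torch.Size([1, 3, 32, 48])
```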