├── .gitignore
├── README.md
├── data_processing
│   ├── __init__.py
│   ├── camera_pipeline.py
│   └── synthetic_burst_generation.py
├── datasets
│   ├── __init__.py
│   ├── burstsr_dataset.py
│   ├── burstsr_test_dataset.py
│   ├── synthetic_burst_test_set.py
│   ├── synthetic_burst_train_set.py
│   ├── synthetic_burst_val_set.py
│   └── zurich_raw2rgb_dataset.py
├── figs
│   └── ts.png
├── loss
│   ├── Charbonnier.py
│   ├── __init__.py
│   ├── adversarial.py
│   ├── discriminator.py
│   ├── filter.py
│   ├── hist_entropy.py
│   ├── mssim.py
│   └── vgg.py
├── main.py
├── model
│   ├── DCNv2
│   │   ├── DCNv2.egg-info
│   │   │   ├── PKG-INFO
│   │   │   ├── SOURCES.txt
│   │   │   ├── dependency_links.txt
│   │   │   └── top_level.txt
│   │   ├── LICENSE
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   └── dcn_v2.cpython-37.pyc
│   │   ├── build
│   │   │   ├── lib.linux-x86_64-3.7
│   │   │   │   └── _ext.cpython-37m-x86_64-linux-gnu.so
│   │   │   └── temp.linux-x86_64-3.7
│   │   │       └── data
│   │   │           └── work
│   │   │               └── pylibs
│   │   │                   └── DCNv2-pytorch_1.6
│   │   │                       └── src
│   │   │                           ├── cpu
│   │   │                           │   ├── dcn_v2_cpu.o
│   │   │                           │   ├── dcn_v2_im2col_cpu.o
│   │   │                           │   └── dcn_v2_psroi_pooling_cpu.o
│   │   │                           ├── cuda
│   │   │                           │   ├── dcn_v2_cuda.o
│   │   │                           │   ├── dcn_v2_im2col_cuda.o
│   │   │                           │   └── dcn_v2_psroi_pooling_cuda.o
│   │   │                           └── vision.o
│   │   ├── dcn_v2.py
│   │   ├── dist
│   │   │   └── DCNv2-0.1-py3.7-linux-x86_64.egg
│   │   ├── files.txt
│   │   ├── make.sh
│   │   ├── setup.py
│   │   ├── src
│   │   │   ├── cpu
│   │   │   │   ├── dcn_v2_cpu.cpp
│   │   │   │   ├── dcn_v2_im2col_cpu.cpp
│   │   │   │   ├── dcn_v2_im2col_cpu.h
│   │   │   │   ├── dcn_v2_psroi_pooling_cpu.cpp
│   │   │   │   └── vision.h
│   │   │   ├── cuda
│   │   │   │   ├── dcn_v2_cuda.cu
│   │   │   │   ├── dcn_v2_im2col_cuda.cu
│   │   │   │   ├── dcn_v2_im2col_cuda.h
│   │   │   │   ├── dcn_v2_psroi_pooling_cuda.cu
│   │   │   │   └── vision.h
│   │   │   ├── dcn_v2.h
│   │   │   └── vision.cpp
│   │   └── test.py
│   ├── __init__.py
│   ├── arch_util.py
│   ├── common.py
│   ├── ebsr.py
│   ├── non_local
│   │   ├── network.py
│   │   ├── non_local_concatenation.py
│   │   ├── non_local_cross_dot_product.py
│   │   ├── non_local_dot_product.py
│   │   ├── non_local_embedded_gaussian.py
│   │   └── non_local_gaussian.py
│   └── utils
│       ├── interp_methods.py
│       ├── psconv.py
│       └── resize_right.py
├── option.py
├── pwcnet
│   ├── LICENSE
│   ├── README.md
│   ├── __init__.py
│   ├── comparison
│   │   ├── comparison.gif
│   │   ├── comparison.py
│   │   ├── official - caffe.png
│   │   └── this - pytorch.png
│   ├── correlation
│   │   ├── README.md
│   │   ├── __pycache__
│   │   │   └── correlation.cpython-37.pyc
│   │   └── correlation.py
│   ├── download.bash
│   ├── images
│   │   ├── README.md
│   │   ├── first.png
│   │   └── second.png
│   ├── out.flo
│   ├── pwcnet.py
│   ├── requirements.txt
│   └── run.py
├── requirements.txt
├── scripts
│   ├── __init__.py
│   ├── cal_mean_std.py
│   ├── demo.sh
│   ├── download_burstsr_dataset.py
│   ├── evaluate.sh
│   ├── evaluate_burstsr_val.py
│   ├── save_results_synburst_val.py
│   ├── test_burstsr_dataset.py
│   └── test_synthetic_bursts.py
├── test.py
├── test_real.py
├── trainer.py
├── utility.py
└── utils
    ├── __init__.py
    ├── data_format_utils.py
    ├── debayer.py
    ├── interp_methods.py
    ├── metrics.py
    ├── postprocessing_functions.py
    ├── resize_right.py
    ├── spatial_color_alignment.py
    ├── stn.py
    └── warp.py
/.gitignore:
--------------------------------------------------------------------------------
1 | demo.sh
2 | /checkpoints
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EBSR: Feature Enhanced Burst Super-Resolution With Deformable Alignment (CVPRW 2021)
2 |
3 |
4 | ### Updates
5 | - **2022.04.22** 🎉🎉🎉 We won 1st place in the NTIRE 2022 BurstSR Challenge again [[Paper]](https://arxiv.org/abs/2204.08332) [[Code]](https://github.com/Algolzw/BSRT).
6 | - **2022.01.22** We updated the code to support real-track testing and provide the model weights [here](https://drive.google.com/file/d/1Zz21YwNtiKZCjerrZsdvcWyubqTJBwaD/view?usp=sharing).
7 | - **2021** We now support single-GPU training and provide the pretrained model [here](https://drive.google.com/file/d/1_WA2chhITIsCj6qImcEM2lD6c-iJsRpy/view?usp=sharing).
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | This repository is the official PyTorch implementation of the CVPRW 2021 paper **"EBSR: Feature Enhanced Burst Super-Resolution With Deformable Alignment"**, which took 1st place in the real track of the NTIRE 2021 Burst Super-Resolution Challenge (2nd place in the synthetic track).
16 |
17 | ## Dependencies
18 | - OS: Ubuntu 18.04
19 | - Python: 3.7
20 | - NVIDIA:
21 |   - CUDA: 10.1
22 |   - cuDNN: 7.6.1
23 | - Other requirements are listed in `requirements.txt`
24 |
25 | ## Quick Start
26 | 1. Create a conda virtual environment and activate it:
27 | ```bash
28 | conda create -n pytorch_1.6 python=3.7
29 | source activate pytorch_1.6
30 | ```
31 | 2. Install PyTorch and torchvision following the official instructions:
32 | ```bash
33 | conda install pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=10.1 -c pytorch
34 | ```
35 | 3. Install build requirements:
36 | ```bash
37 | pip3 install -r requirements.txt
38 | ```
39 | 4. Install Apex to use DistributedDataParallel, following [NVIDIA Apex](https://github.com/NVIDIA/apex) (optional):
40 | ```bash
41 | git clone https://github.com/NVIDIA/apex
42 | cd apex
43 | pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
44 | ```
45 | 5. Install DCNv2:
46 | ```bash
47 | cd DCNv2-pytorch_1.6
48 | python3 setup.py build develop # build
49 | python3 test.py # run examples and check
50 | ```
51 | ## Training
52 | ```bash
53 | # Modify the root path of the training dataset, the model save path, etc.
54 | # This example requires more than 1 GPU.
55 | python main.py --n_GPUs 4 --lr 0.0002 --decay 200-400 --save ebsr --model EBSR --fp16 --lrcn --non_local --n_feats 128 --n_resblocks 8 --n_resgroups 5 --batch_size 16 --burst_size 14 --patch_size 256 --scale 4 --loss 1*L1
56 | ```
57 | ## Test
58 | ```bash
59 | # Modify the test dataset path and the trained model path.
60 | python test.py --root /data/dataset/ntire21/burstsr/synthetic/syn_burst_val --model EBSR --lrcn --non_local --n_feats 128 --n_resblocks 8 --n_resgroups 5 --burst_size 14 --scale 4 --pre_train ./checkpoints/EBSRbest_epoch.pth
61 | ```
62 | or test on the validation dataset:
63 | ```bash
64 | python main.py --n_GPUs 1 --test_only --model EBSR --lrcn --non_local --n_feats 128 --n_resblocks 8 --n_resgroups 5 --burst_size 14 --scale 4 --pre_train ./checkpoints/EBSRbest_epoch.pth
65 | ```
66 | ### Real track evaluation
67 | You may need to download the pretrained PWC-Net weights into the pwcnet directory ([here](https://drive.google.com/file/d/1dD6vB9QN3qwmOBi3AGKzJbbSojwDDlgV/view?usp=sharing)).
68 |
69 | ```bash
70 | python test_real.py --n_GPUs 1 --model EBSR --lrcn --non_local --n_feats 128 --n_resblocks 8 --n_resgroups 5 --burst_size 14 --scale 4 --pre_train ./checkpoints/BBSR_realbest_epoch.pth --root burstsr_validation_dataset...
71 |
72 | ```
73 |
74 | ## Citations
75 | If EBSR helps your research or work, please consider citing it.
76 | The following is a BibTeX reference:
77 |
78 | ```
79 | @InProceedings{Luo_2021_CVPR,
80 | author = {Luo, Ziwei and Yu, Lei and Mo, Xuan and Li, Youwei and Jia, Lanpeng and Fan, Haoqiang and Sun, Jian and Liu, Shuaicheng},
81 | title = {EBSR: Feature Enhanced Burst Super-Resolution With Deformable Alignment},
82 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
83 | month = {June},
84 | year = {2021},
85 | pages = {471-478}
86 | }
87 | ```
88 |
89 | ## Contact
90 | email: ziwei.ro@gmail.com, yl_yjsy@163.com
91 |
--------------------------------------------------------------------------------
/data_processing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/data_processing/__init__.py
--------------------------------------------------------------------------------
/data_processing/camera_pipeline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import math
4 | import cv2 as cv
5 | import numpy as np
6 | import utils.data_format_utils as df_utils
7 | """ Based on http://timothybrooks.com/tech/unprocessing
8 | Functions for forward and inverse camera pipeline. All functions input a torch float tensor of shape (c, h, w).
9 | Additionally, some also support batch operations, i.e. inputs of shape (b, c, h, w)
10 | """
11 |
12 |
13 | def random_ccm():
14 | """Generates random RGB -> Camera color correction matrices."""
15 | # Takes a random convex combination of XYZ -> Camera CCMs.
16 | xyz2cams = [[[1.0234, -0.2969, -0.2266],
17 | [-0.5625, 1.6328, -0.0469],
18 | [-0.0703, 0.2188, 0.6406]],
19 | [[0.4913, -0.0541, -0.0202],
20 | [-0.613, 1.3513, 0.2906],
21 | [-0.1564, 0.2151, 0.7183]],
22 | [[0.838, -0.263, -0.0639],
23 | [-0.2887, 1.0725, 0.2496],
24 | [-0.0627, 0.1427, 0.5438]],
25 | [[0.6596, -0.2079, -0.0562],
26 | [-0.4782, 1.3016, 0.1933],
27 | [-0.097, 0.1581, 0.5181]]]
28 |
29 | num_ccms = len(xyz2cams)
30 | xyz2cams = torch.tensor(xyz2cams)
31 |
32 | weights = torch.FloatTensor(num_ccms, 1, 1).uniform_(0.0, 1.0)
33 | weights_sum = weights.sum()
34 | xyz2cam = (xyz2cams * weights).sum(dim=0) / weights_sum
35 |
36 | # Multiplies with RGB -> XYZ to get RGB -> Camera CCM.
37 | rgb2xyz = torch.tensor([[0.4124564, 0.3575761, 0.1804375],
38 | [0.2126729, 0.7151522, 0.0721750],
39 | [0.0193339, 0.1191920, 0.9503041]])
40 | rgb2cam = torch.mm(xyz2cam, rgb2xyz)
41 |
42 | # Normalizes each row.
43 | rgb2cam = rgb2cam / rgb2cam.sum(dim=-1, keepdims=True)
44 | return rgb2cam
45 |
46 |
47 | def random_gains():
48 | """Generates random gains for brightening and white balance."""
49 | # RGB gain represents brightening.
50 | rgb_gain = 1.0 / random.gauss(mu=0.8, sigma=0.1)
51 |
52 | # Red and blue gains represent white balance.
53 | red_gain = random.uniform(1.9, 2.4)
54 | blue_gain = random.uniform(1.5, 1.9)
55 | return rgb_gain, red_gain, blue_gain
56 |
57 |
58 | def apply_smoothstep(image):
59 | """Apply global tone mapping curve."""
60 | image_out = 3 * image**2 - 2 * image**3
61 | return image_out
62 |
63 |
64 | def invert_smoothstep(image):
65 | """Approximately inverts a global tone mapping curve."""
66 | image = image.clamp(0.0, 1.0)
67 | return 0.5 - torch.sin(torch.asin(1.0 - 2.0 * image) / 3.0)
68 |
69 |
70 | def gamma_expansion(image):
71 | """Converts from gamma to linear space."""
72 | # Clamps to prevent numerical instability of gradients near zero.
73 | return image.clamp(1e-8) ** 2.2
74 |
75 |
76 | def gamma_compression(image):
77 | """Converts from linear to gammaspace."""
78 | # Clamps to prevent numerical instability of gradients near zero.
79 | return image.clamp(1e-8) ** (1.0 / 2.2)
80 |
81 |
82 | def apply_ccm(image, ccm):
83 | """Applies a color correction matrix."""
84 | assert image.dim() == 3 and image.shape[0] == 3
85 |
86 | shape = image.shape
87 | image = image.view(3, -1)
88 | ccm = ccm.to(image.device).type_as(image)
89 |
90 | image = torch.mm(ccm, image)
91 |
92 | return image.view(shape)
93 |
94 |
95 | def apply_gains(image, rgb_gain, red_gain, blue_gain):
96 | """Inverts gains while safely handling saturated pixels."""
97 | assert image.dim() == 3 and image.shape[0] in [3, 4]
98 |
99 | if image.shape[0] == 3:
100 | gains = torch.tensor([red_gain, 1.0, blue_gain]) * rgb_gain
101 | else:
102 | gains = torch.tensor([red_gain, 1.0, 1.0, blue_gain]) * rgb_gain
103 | gains = gains.view(-1, 1, 1)
104 | gains = gains.to(image.device).type_as(image)
105 |
106 | return (image * gains).clamp(0.0, 1.0)
107 |
108 |
109 | def safe_invert_gains(image, rgb_gain, red_gain, blue_gain):
110 | """Inverts gains while safely handling saturated pixels."""
111 | assert image.dim() == 3 and image.shape[0] == 3
112 |
113 | gains = torch.tensor([1.0 / red_gain, 1.0, 1.0 / blue_gain]) / rgb_gain
114 | gains = gains.view(-1, 1, 1)
115 |
116 | # Prevents dimming of saturated pixels by smoothly masking gains near white.
117 | gray = image.mean(dim=0, keepdims=True)
118 | inflection = 0.9
119 | mask = ((gray - inflection).clamp(0.0) / (1.0 - inflection)) ** 2.0
120 |
121 | safe_gains = torch.max(mask + (1.0 - mask) * gains, gains)
122 | return image * safe_gains
123 |
124 |
125 | def mosaic(image, mode='rggb'):
126 | """Extracts RGGB Bayer planes from an RGB image."""
127 | shape = image.shape
128 | if image.dim() == 3:
129 | image = image.unsqueeze(0)
130 |
131 | if mode == 'rggb':
132 | red = image[:, 0, 0::2, 0::2]
133 | green_red = image[:, 1, 0::2, 1::2]
134 | green_blue = image[:, 1, 1::2, 0::2]
135 | blue = image[:, 2, 1::2, 1::2]
136 | image = torch.stack((red, green_red, green_blue, blue), dim=1)
137 | elif mode == 'grbg':
138 | green_red = image[:, 1, 0::2, 0::2]
139 | red = image[:, 0, 0::2, 1::2]
140 | blue = image[:, 2, 0::2, 1::2]
141 | green_blue = image[:, 1, 1::2, 1::2]
142 |
143 | image = torch.stack((green_red, red, blue, green_blue), dim=1)
144 |
145 | if len(shape) == 3:
146 | return image.view((4, shape[-2] // 2, shape[-1] // 2))
147 | else:
148 | return image.view((-1, 4, shape[-2] // 2, shape[-1] // 2))
149 |
150 |
151 | def demosaic(image):
152 | assert isinstance(image, torch.Tensor)
153 | image = image.clamp(0.0, 1.0) * 255
154 |
155 | if image.dim() == 4:
156 | num_images = image.shape[0]  # batch size (image.dim() would always be 4 here)
157 | batch_input = True
158 | else:
159 | num_images = 1
160 | batch_input = False
161 | image = image.unsqueeze(0)
162 |
163 | # Generate single channel input for opencv
164 | im_sc = torch.zeros((num_images, image.shape[-2] * 2, image.shape[-1] * 2, 1))
165 | im_sc[:, ::2, ::2, 0] = image[:, 0, :, :]
166 | im_sc[:, ::2, 1::2, 0] = image[:, 1, :, :]
167 | im_sc[:, 1::2, ::2, 0] = image[:, 2, :, :]
168 | im_sc[:, 1::2, 1::2, 0] = image[:, 3, :, :]
169 |
170 | im_sc = im_sc.numpy().astype(np.uint8)
171 |
172 | out = []
173 |
174 | for im in im_sc:
175 | # cv.imwrite('frames/tmp.png', im)
176 | im_dem_np = cv.cvtColor(im, cv.COLOR_BAYER_BG2RGB)  # or cv.COLOR_BAYER_BG2RGB_VNG
177 |
178 | # Convert to torch image
179 | im_t = df_utils.npimage_to_torch(im_dem_np, input_bgr=False)
180 | out.append(im_t)
181 |
182 | if batch_input:
183 | return torch.stack(out, dim=0)
184 | else:
185 | return out[0]
186 |
187 |
188 | def random_noise_levels():
189 | """Generates random noise levels from a log-log linear distribution."""
190 | log_min_shot_noise = math.log(0.0001)
191 | log_max_shot_noise = math.log(0.012)
192 | log_shot_noise = random.uniform(log_min_shot_noise, log_max_shot_noise)
193 | shot_noise = math.exp(log_shot_noise)
194 |
195 | line = lambda x: 2.18 * x + 1.20
196 | log_read_noise = line(log_shot_noise) + random.gauss(mu=0.0, sigma=0.26)
197 | read_noise = math.exp(log_read_noise)
198 | return shot_noise, read_noise
199 |
200 |
201 | def add_noise(image, shot_noise=0.01, read_noise=0.0005):
202 | """Adds random shot (proportional to image) and read (independent) noise."""
203 | variance = image * shot_noise + read_noise
204 | noise = torch.FloatTensor(image.shape).normal_().to(image.device)*variance.sqrt()
205 | return image + noise
206 |
207 |
208 | def process_linear_image_rgb(image, meta_info, return_np=False):
209 | image = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain'])
210 | image = apply_ccm(image, meta_info['cam2rgb'])
211 |
212 | if meta_info['gamma']:
213 | image = gamma_compression(image)
214 |
215 | if meta_info['smoothstep']:
216 | image = apply_smoothstep(image)
217 |
218 | image = image.clamp(0.0, 1.0)
219 |
220 | if return_np:
221 | image = df_utils.torch_to_npimage(image)
222 | return image
223 |
224 |
225 | def process_linear_image_raw(image, meta_info):
226 | image = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain'])
227 | image = demosaic(image)
228 | image = apply_ccm(image, meta_info['cam2rgb'])
229 |
230 | if meta_info['gamma']:
231 | image = gamma_compression(image)
232 |
233 | if meta_info['smoothstep']:
234 | image = apply_smoothstep(image)
235 | return image.clamp(0.0, 1.0)
236 |
--------------------------------------------------------------------------------
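A minimal sketch of how these functions chain into the inverse ("unprocessing") pipeline, for orientation only. The composition actually used for burst synthesis lives in `data_processing/synthetic_burst_generation.py` (`rgb2rawburst`); the module alias `cp` and the input size below are illustrative assumptions.

```python
import torch
from data_processing import camera_pipeline as cp  # alias assumed for brevity

rgb = torch.rand(3, 96, 96)  # sRGB image in [0, 1], shape (c, h, w)

# Sample random pipeline parameters.
rgb2cam = cp.random_ccm()
rgb_gain, red_gain, blue_gain = cp.random_gains()

# sRGB -> linear camera space -> mosaicked RAW with noise.
linear = cp.invert_smoothstep(rgb)    # undo global tone mapping
linear = cp.gamma_expansion(linear)   # undo gamma (back to linear space)
cam = cp.apply_ccm(linear, rgb2cam)   # RGB -> camera color space
cam = cp.safe_invert_gains(cam, rgb_gain, red_gain, blue_gain)
raw = cp.mosaic(cam.clamp(0.0, 1.0), mode='rggb')  # (4, h/2, w/2) RGGB planes
shot_noise, read_noise = cp.random_noise_levels()
raw = cp.add_noise(raw, shot_noise, read_noise)
```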
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/burstsr_test_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn.functional as F
4 | import random
5 | from .burstsr_dataset import SamsungRAWImage, flatten_raw_image, pack_raw_image
6 |
7 |
8 | class BurstSRDataset(torch.utils.data.Dataset):
9 | """ Real-world burst super-resolution dataset. """
10 | def __init__(self, root, burst_size=8, crop_sz=80, center_crop=False, random_flip=False, split='test'):
11 | """
12 | args:
13 | root : path of the root directory
14 | burst_size : Burst size. Maximum allowed burst size is 14.
15 | crop_sz: Size of the extracted crop. Maximum allowed crop size is 80
16 | center_crop: Whether to extract a centered crop instead of a random one
17 | random_flip: Whether to apply random horizontal and vertical flips
18 | split: Only 'test' is supported by this dataset
19 | """
20 | assert burst_size <= 14, 'burst_size must be less than or equal to 14'
21 | assert crop_sz <= 80, 'crop_sz must be less than or equal to 80'
22 | assert split in ['test']
23 |
24 | root = root + '/' + split
25 | super().__init__()
26 |
27 | self.burst_size = burst_size
28 | self.crop_sz = crop_sz
29 | self.split = split
30 | self.center_crop = center_crop
31 | self.random_flip = random_flip
32 |
33 | self.root = root
34 |
35 | self.substract_black_level = True
36 | self.white_balance = False
37 |
38 | self.burst_list = self._get_burst_list()
39 |
40 | def _get_burst_list(self):
41 | burst_list = sorted(os.listdir('{}'.format(self.root)))
42 |
43 | return burst_list
44 |
45 | def get_burst_info(self, burst_id):
46 | burst_info = {'burst_size': 14, 'burst_name': self.burst_list[burst_id]}
47 | return burst_info
48 |
49 | def _get_raw_image(self, burst_id, im_id):
50 | raw_image = SamsungRAWImage.load('{}/{}/samsung_{:02d}'.format(self.root, self.burst_list[burst_id], im_id))
51 | return raw_image
52 |
53 | def get_burst(self, burst_id, im_ids, info=None):
54 | frames = [self._get_raw_image(burst_id, i) for i in im_ids]
55 |
56 | if info is None:
57 | info = self.get_burst_info(burst_id)
58 |
59 | return frames, info
60 |
61 | def _sample_images(self):
62 | burst_size = 14
63 |
64 | ids = random.sample(range(1, burst_size), k=self.burst_size - 1)
65 | ids = [0, ] + ids
66 | return ids
67 |
68 | def __len__(self):
69 | return len(self.burst_list)
70 |
71 | def __getitem__(self, index):
72 | # Sample the images in the burst, in case a burst_size < 14 is used.
73 | im_ids = self._sample_images()
74 |
75 | # Read the burst images along with HR ground truth
76 | frames, meta_info = self.get_burst(index, im_ids)
77 |
78 | # Extract crop if needed
79 | if frames[0].shape()[-1] != self.crop_sz:
80 | if getattr(self, 'center_crop', False):
81 | r1 = (frames[0].shape()[-2] - self.crop_sz) // 2
82 | c1 = (frames[0].shape()[-1] - self.crop_sz) // 2
83 | else:
84 | r1 = random.randint(0, frames[0].shape()[-2] - self.crop_sz)
85 | c1 = random.randint(0, frames[0].shape()[-1] - self.crop_sz)
86 | r2 = r1 + self.crop_sz
87 | c2 = c1 + self.crop_sz
88 |
89 | frames = [im.get_crop(r1, r2, c1, c2) for im in frames]
90 |
91 | # Load the RAW image data
92 | burst_image_data = [im.get_image_data(normalize=True, substract_black_level=self.substract_black_level,
93 | white_balance=self.white_balance) for im in frames]
94 |
95 | if self.random_flip:
96 | burst_image_data = [flatten_raw_image(im) for im in burst_image_data]
97 |
98 | pad = [0, 0, 0, 0]
99 | if random.random() > 0.5:
100 | burst_image_data = [im.flip([1, ])[:, 1:-1].contiguous() for im in burst_image_data]
101 | pad[1] = 1
102 |
103 | if random.random() > 0.5:
104 | burst_image_data = [im.flip([0, ])[1:-1, :].contiguous() for im in burst_image_data]
105 | pad[3] = 1
106 |
107 | burst_image_data = [pack_raw_image(im) for im in burst_image_data]
108 | burst_image_data = [F.pad(im.unsqueeze(0), pad, mode='replicate').squeeze(0) for im in burst_image_data]
109 |
110 | burst_image_meta_info = frames[0].get_all_meta_data()
111 |
112 | burst_image_meta_info['black_level_subtracted'] = self.substract_black_level
113 | burst_image_meta_info['while_balance_applied'] = self.white_balance
114 | burst_image_meta_info['norm_factor'] = frames[0].norm_factor
115 |
116 | burst = torch.stack(burst_image_data, dim=0)
117 |
118 | burst_exposure = frames[0].get_exposure_time()
119 |
120 | burst_f_number = frames[0].get_f_number()
121 |
122 | burst_iso = frames[0].get_iso()
123 |
124 | burst_image_meta_info['exposure'] = burst_exposure
125 | burst_image_meta_info['f_number'] = burst_f_number
126 | burst_image_meta_info['iso'] = burst_iso
127 |
128 | burst = burst.float()
129 |
130 | meta_info_burst = burst_image_meta_info
131 |
132 | for k, v in meta_info_burst.items():
133 | if isinstance(v, (list, tuple)):
134 | meta_info_burst[k] = torch.tensor(v)
135 |
136 | return burst, meta_info_burst
--------------------------------------------------------------------------------
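A usage sketch, assuming the BurstSR test split is unpacked under `<root>/test/<burst_name>/samsung_XX` (the layout `_get_raw_image` reads); the root path is a placeholder.

```python
from torch.utils.data import DataLoader
from datasets.burstsr_test_dataset import BurstSRDataset

dataset = BurstSRDataset(root='/path/to/burstsr', burst_size=14,
                         crop_sz=80, center_crop=True)
burst, meta_info = dataset[0]   # packed RAW burst tensor + metadata dict
print(burst.shape)              # (burst_size, 4, h, w) RGGB planes

loader = DataLoader(dataset, batch_size=1, num_workers=2)
```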
/datasets/synthetic_burst_test_set.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import cv2
3 | import numpy as np
4 |
5 |
6 | class SyntheticBurstVal(torch.utils.data.Dataset):
7 | """ Synthetic burst validation set. The validation burst have been generated using the same synthetic pipeline as
8 | employed in SyntheticBurst dataset.
9 | """
10 | def __init__(self, root):
11 | self.root = root
12 | self.burst_list = list(range(500))
13 | self.burst_size = 14
14 |
15 | def __len__(self):
16 | return len(self.burst_list)
17 |
18 | def _read_burst_image(self, index, image_id):
19 | im = cv2.imread('{}/{:04d}/im_raw_{:02d}.png'.format(self.root, index, image_id), cv2.IMREAD_UNCHANGED)
20 | im_t = torch.from_numpy(im.astype(np.float32)).permute(2, 0, 1).float() / (2**14)
21 | return im_t
22 |
23 | def __getitem__(self, index):
24 | """ Generates a synthetic burst
25 | args:
26 | index: Index of the burst
27 |
28 | returns:
29 | burst: LR RAW burst, a torch tensor of shape
30 | [14, 4, 48, 48]
31 | The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaick.
32 | seq_name: Name of the burst sequence
33 | """
34 | burst_name = '{:04d}'.format(index)
35 | burst = [self._read_burst_image(index, i) for i in range(self.burst_size)]
36 | burst = torch.stack(burst, 0)
37 |
38 | return burst, burst_name
--------------------------------------------------------------------------------
/datasets/synthetic_burst_train_set.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from PIL import Image
4 | from data_processing.synthetic_burst_generation import rgb2rawburst, random_crop #syn_burst_utils
5 | import torchvision.transforms as tfm
6 |
7 |
8 | class SyntheticBurst(torch.utils.data.Dataset):
9 | """ Synthetic burst dataset for joint denoising, demosaicking, and super-resolution. RAW Burst sequences are
10 | synthetically generated on the fly as follows. First, a single image is loaded from the base_dataset. The sampled
11 | image is converted to linear sensor space using the inverse camera pipeline employed in [1]. A burst
12 | sequence is then generated by adding random translations and rotations to the converted image. The generated burst
13 | is then mosaicked and corrupted by random noise to obtain the RAW burst.
14 |
15 | [1] Unprocessing Images for Learned Raw Denoising, Brooks, Tim and Mildenhall, Ben and Xue, Tianfan and Chen,
16 | Jiawen and Sharlet, Dillon and Barron, Jonathan T, CVPR 2019
17 | """
18 | def __init__(self, base_dataset, burst_size=8, crop_sz=384, transform=tfm.ToTensor()):
19 | self.base_dataset = base_dataset
20 |
21 | self.burst_size = burst_size
22 | self.crop_sz = crop_sz
23 | self.transform = transform
24 |
25 | self.downsample_factor = 4
26 | self.burst_transformation_params = {'max_translation': 24.0,
27 | 'max_rotation': 1.0,
28 | 'max_shear': 0.0,
29 | 'max_scale': 0.0,
30 | 'border_crop': 24}
31 |
32 | self.image_processing_params = {'random_ccm': True, 'random_gains': True, 'smoothstep': True,
33 | 'gamma': True,
34 | 'add_noise': True}
35 | self.interpolation_type = 'bilinear'
36 |
37 | def __len__(self):
38 | return len(self.base_dataset)
39 |
40 | def __getitem__(self, index):
41 | """ Generates a synthetic burst
42 | args:
43 | index: Index of the image in the base_dataset used to generate the burst
44 |
45 | returns:
46 | burst: Generated LR RAW burst, a torch tensor of shape
47 | [burst_size, 4, self.crop_sz / (2*self.downsample_factor), self.crop_sz / (2*self.downsample_factor)]
48 | The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaick.
49 | The extra factor 2 in the denominator (2*self.downsample_factor) corresponds to the mosaicking
50 | operation.
51 |
52 | frame_gt: The HR RGB ground truth in the linear sensor space, a torch tensor of shape
53 | [3, self.crop_sz, self.crop_sz]
54 |
55 | flow_vectors: The ground truth flow vectors between a burst image and the base image (i.e. the first image in the burst).
56 | The flow_vectors can be used to warp the burst images to the base frame, using the 'warp'
57 | function in utils.warp package.
58 | flow_vectors is torch tensor of shape
59 | [burst_size, 2, self.crop_sz / self.downsample_factor, self.crop_sz / self.downsample_factor].
60 | Note that the flow_vectors are in the LR RGB space, before mosaicking. Hence it has twice
61 | the number of rows and columns, compared to the output burst.
62 |
63 | NOTE: The flow_vectors are only available during training for the purpose of using any
64 | auxiliary losses if needed. The flow_vectors will NOT be provided for the bursts in the
65 | test set
66 |
67 | meta_info: A dictionary containing the parameters used to generate the synthetic burst.
68 | """
69 | frame = self.base_dataset[index]
70 |
71 | # Augmentation, e.g. convert to tensor
72 | if self.transform is not None:
73 | # frame = Image.fromarray(frame)
74 | frame = self.transform(frame)
75 |
76 | # Extract a random crop from the image
77 | crop_sz = self.crop_sz + 2 * self.burst_transformation_params.get('border_crop', 0)
78 | frame_crop = random_crop(frame, crop_sz)
79 |
80 | # Generate RAW burst
81 | burst, frame_gt, burst_rgb, flow_vectors, meta_info = rgb2rawburst(frame_crop,
82 | self.burst_size,
83 | self.downsample_factor,
84 | burst_transformation_params=self.burst_transformation_params,
85 | image_processing_params=self.image_processing_params,
86 | interpolation_type=self.interpolation_type
87 | )
88 |
89 | if self.burst_transformation_params.get('border_crop') is not None:
90 | border_crop = self.burst_transformation_params.get('border_crop')
91 | frame_gt = frame_gt[:, border_crop:-border_crop, border_crop:-border_crop]
92 |
93 | return burst, frame_gt, flow_vectors, meta_info
94 |
--------------------------------------------------------------------------------
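How this dataset is wired up in practice (mirroring the construction in `main.py`); the dataset root is a placeholder:

```python
from torch.utils.data import DataLoader
from datasets.zurich_raw2rgb_dataset import ZurichRAW2RGB
from datasets.synthetic_burst_train_set import SyntheticBurst

base = ZurichRAW2RGB(root='/path/to/zurich-raw-to-rgb', split='train')
train_data = SyntheticBurst(base, burst_size=14, crop_sz=256)

burst, frame_gt, flow_vectors, meta_info = train_data[0]
# burst:    (14, 4, 32, 32)  -- crop_sz / (2 * downsample_factor) = 256 / 8
# frame_gt: (3, 256, 256)    -- HR ground truth in linear sensor space

train_loader = DataLoader(train_data, batch_size=4, shuffle=True,
                          pin_memory=True, drop_last=True)
```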
/datasets/synthetic_burst_val_set.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import cv2
3 | import numpy as np
4 |
5 |
6 | class SyntheticBurstVal(torch.utils.data.Dataset):
7 | """ Synthetic burst validation set. The validation burst have been generated using the same synthetic pipeline as
8 | employed in SyntheticBurst dataset.
9 | """
10 | def __init__(self, root):
11 | self.root = root
12 | self.burst_list = list(range(100))
13 | self.burst_size = 14
14 |
15 | def __len__(self):
16 | return len(self.burst_list)
17 |
18 | def _read_burst_image(self, index, image_id):
19 | im = cv2.imread('{}/{:04d}/im_raw_{:02d}.png'.format(self.root, index, image_id), cv2.IMREAD_UNCHANGED)
20 | im_t = torch.from_numpy(im.astype(np.float32)).permute(2, 0, 1).float() / (2**14)
21 | return im_t
22 |
23 | def __getitem__(self, index):
24 | """ Generates a synthetic burst
25 | args:
26 | index: Index of the burst
27 |
28 | returns:
29 | burst: LR RAW burst, a torch tensor of shape
30 | [14, 4, 48, 48]
31 | The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaick.
32 | seq_name: Name of the burst sequence
33 | """
34 | burst_name = '{:04d}'.format(index)
35 | burst = [self._read_burst_image(index, i) for i in range(self.burst_size)]
36 | burst = torch.stack(burst, 0)
37 |
38 | return burst, burst_name
39 |
--------------------------------------------------------------------------------
/datasets/zurich_raw2rgb_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from cv2 import imread
4 |
5 |
6 | class ZurichRAW2RGB(torch.utils.data.Dataset):
7 | """ Canon RGB images from the "Zurich RAW to RGB mapping" dataset. You can download the full
8 | dataset (22 GB) from http://people.ee.ethz.ch/~ihnatova/pynet.html#dataset. Alternatively, you can download only the
9 | Canon RGB images (5.5 GB) from https://data.vision.ee.ethz.ch/bhatg/zurich-raw-to-rgb.zip
10 | """
11 | def __init__(self, root, split='train'):
12 | super().__init__()
13 |
14 | if split in ['train', 'test']:
15 | self.img_pth = os.path.join(root, split, 'canon')
16 | else:
17 | raise Exception('Unknown split {}'.format(split))
18 |
19 | self.image_list = self._get_image_list(split)
20 |
21 | def _get_image_list(self, split):
22 | if split == 'train':
23 | image_list = ['{:d}.jpg'.format(i) for i in range(46839)]
24 | elif split == 'test':
25 | image_list = ['{:d}.jpg'.format(i) for i in range(1204)]
26 | else:
27 | raise Exception
28 |
29 | return image_list
30 |
31 | def _get_image(self, im_id):
32 | path = os.path.join(self.img_pth, self.image_list[im_id])
33 | img = imread(path)
34 | return img
35 |
36 | def get_image(self, im_id):
37 | frame = self._get_image(im_id)
38 |
39 | return frame
40 |
41 | def __len__(self):
42 | return len(self.image_list)
43 |
44 | def __getitem__(self, index):
45 | frame = self._get_image(index)
46 |
47 | return frame
48 |
--------------------------------------------------------------------------------
/figs/ts.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/figs/ts.png
--------------------------------------------------------------------------------
/loss/Charbonnier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class CharbonnierLoss(nn.Module):
6 | """L1 charbonnier loss."""
7 |
8 | def __init__(self, epsilon=1e-3, reduce=True):
9 | super(CharbonnierLoss, self).__init__()
10 | self.eps = epsilon * epsilon
11 | self.reduce = reduce
12 |
13 | def forward(self, X, Y):
14 | diff = torch.add(X, -Y)
15 | error = torch.sqrt(diff * diff + self.eps)
16 | if self.reduce:
17 | loss = torch.mean(error)
18 | else:
19 | loss = error
20 | return loss
--------------------------------------------------------------------------------
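Since `self.eps = epsilon * epsilon`, the loss computes sqrt((x - y)^2 + epsilon^2): roughly quadratic for residuals well below epsilon and close to |x - y| above it. A quick illustrative check:

```python
import torch
from loss.Charbonnier import CharbonnierLoss

loss_fn = CharbonnierLoss(epsilon=1e-3, reduce=False)
x = torch.zeros(3)
y = torch.tensor([0.0, 1e-3, 1.0])
print(loss_fn(x, y))  # ~[1.0e-3, 1.4e-3, 1.0]: smooth near zero, L1-like for large errors
```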
/loss/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib import import_module
3 |
4 | import matplotlib
5 | matplotlib.use('Agg')
6 | import matplotlib.pyplot as plt
7 |
8 | import numpy as np
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | class Loss(nn.modules.loss._Loss):
15 | def __init__(self, args, ckp):
16 | super(Loss, self).__init__()
17 | if args.local_rank == 0:
18 | print('Preparing loss function:')
19 |
20 | self.n_GPUs = args.n_GPUs
21 | self.loss = []
22 | self.loss_module = nn.ModuleList()
23 | for loss in args.loss.split('+'):
24 | weight, loss_type = loss.split('*')
25 | if loss_type == 'MSE':
26 | loss_function = nn.MSELoss()
27 | elif loss_type == 'L1':
28 | loss_function = nn.L1Loss()
29 | elif loss_type.find('VGG') >= 0:
30 | module = import_module('loss.vgg')
31 | loss_function = getattr(module, 'VGG')(
32 | loss_type[3:],
33 | rgb_range=args.rgb_range
34 | )
35 | elif loss_type.find('GAN') >= 0:
36 | module = import_module('loss.adversarial')
37 | loss_function = getattr(module, 'Adversarial')(
38 | args,
39 | loss_type
40 | )
41 | elif loss_type == 'FILTER':
42 | module = import_module('loss.filter')
43 | loss_function = getattr(module, 'Filter')(args)
44 | elif loss_type == 'SSIM':
45 | module = import_module('loss.mssim')
46 | loss_function = getattr(module, 'SSIM')()  # SSIM's constructor takes no args namespace
47 | elif loss_type == 'MSSSIM':
48 | module = import_module('loss.mssim')
49 | loss_function = getattr(module, 'MSSSIM')()  # MSSSIM's constructor takes no args namespace
50 |
51 | self.loss.append({
52 | 'type': loss_type,
53 | 'weight': float(weight),
54 | 'function': loss_function}
55 | )
56 | if loss_type.find('GAN') >= 0:
57 | self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})
58 |
59 | if len(self.loss) > 1:
60 | self.loss.append({'type': 'Total', 'weight': 0, 'function': None})
61 |
62 | for l in self.loss:
63 | if l['function'] is not None:
64 | if args.local_rank == 0:
65 | print('{:.3f} * {}'.format(l['weight'], l['type']))
66 | self.loss_module.append(l['function'])
67 |
68 | self.log = torch.Tensor()
69 |
70 | device = torch.device('cpu' if args.cpu else 'cuda')
71 | self.loss_module.to(device)
72 | if args.precision == 'half': self.loss_module.half()
73 | if not args.cpu and args.n_GPUs > 1:
74 | self.loss_module = nn.DataParallel(
75 | self.loss_module, range(args.n_GPUs)
76 | )
77 |
78 | if args.load != '': self.load(ckp.dir, cpu=args.cpu)
79 |
80 | def forward(self, sr, hr):
81 | losses = []
82 | for i, l in enumerate(self.loss):
83 | if l['function'] is not None:
84 | loss = l['function'](sr, hr)
85 | effective_loss = l['weight'] * loss
86 | losses.append(effective_loss)
87 | self.log[-1, i] += effective_loss.item()
88 | elif l['type'] == 'DIS':
89 | self.log[-1, i] += self.loss[i - 1]['function'].loss
90 |
91 | loss_sum = sum(losses)
92 | if len(self.loss) > 1:
93 | self.log[-1, -1] += loss_sum.item()
94 |
95 | return loss_sum
96 |
97 | def step(self):
98 | for l in self.get_loss_module():
99 | if hasattr(l, 'scheduler'):
100 | l.scheduler.step()
101 |
102 | def start_log(self):
103 | self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))
104 |
105 | def end_log(self, n_batches):
106 | self.log[-1].div_(n_batches)
107 |
108 | def display_loss(self, batch):
109 | n_samples = batch + 1
110 | log = []
111 | for l, c in zip(self.loss, self.log[-1]):
112 | log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))
113 |
114 | return ''.join(log)
115 |
116 | def plot_loss(self, apath, epoch):
117 | axis = np.linspace(1, epoch, epoch)
118 | for i, l in enumerate(self.loss):
119 | label = '{} Loss'.format(l['type'])
120 | fig = plt.figure()
121 | plt.title(label)
122 | plt.plot(axis, self.log[:, i].numpy(), label=label)
123 | plt.legend()
124 | plt.xlabel('Epochs')
125 | plt.ylabel('Loss')
126 | plt.grid(True)
127 | plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type'])))
128 | plt.close(fig)
129 |
130 | def get_loss_module(self):
131 | if self.n_GPUs == 1:
132 | return self.loss_module
133 | else:
134 | return self.loss_module.module
135 |
136 | def save(self, apath):
137 | torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
138 | torch.save(self.log, os.path.join(apath, 'loss_log.pt'))
139 |
140 | def load(self, apath, cpu=False):
141 | if cpu:
142 | kwargs = {'map_location': lambda storage, loc: storage}
143 | else:
144 | kwargs = {}
145 |
146 | self.load_state_dict(torch.load(
147 | os.path.join(apath, 'loss.pt'),
148 | **kwargs
149 | ))
150 | self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
151 | for l in self.get_loss_module():
152 | if hasattr(l, 'scheduler'):
153 | for _ in range(len(self.log)): l.scheduler.step()
154 |
155 |
--------------------------------------------------------------------------------
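The `--loss` option is a '+'-separated list of `weight*TYPE` terms (the training command in the top-level README uses `--loss 1*L1`). A sketch of the parsing performed above, with a hypothetical two-term spec:

```python
spec = '1*L1+0.1*VGG54'  # hypothetical combination for illustration
for term in spec.split('+'):
    weight, loss_type = term.split('*')
    print(float(weight), loss_type)
# 1.0 L1     -> nn.L1Loss()
# 0.1 VGG54  -> loss.vgg.VGG('54', ...), i.e. conv5_4 features
```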
/loss/adversarial.py:
--------------------------------------------------------------------------------
1 | import utility
2 | from types import SimpleNamespace
3 |
4 | from model import common
5 | from loss import discriminator
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 |
12 | class Adversarial(nn.Module):
13 | def __init__(self, args, gan_type):
14 | super(Adversarial, self).__init__()
15 | self.gan_type = gan_type
16 | self.gan_k = args.gan_k
17 | self.dis = discriminator.Discriminator(args)
18 | # if gan_type == 'WGAN_GP':
19 | if True:
20 | # see https://arxiv.org/pdf/1704.00028.pdf pp.4
21 | optim_dict = {
22 | 'optimizer': 'ADAM',
23 | 'betas': (0.5, 0.9),
24 | 'epsilon': 1e-8,
25 | 'lr': 1e-5,
26 | 'weight_decay': args.weight_decay,
27 | 'decay': args.decay,
28 | 'gamma': args.gamma
29 | }
30 | optim_args = SimpleNamespace(**optim_dict)
31 | else:
32 | optim_args = args
33 |
34 | self.optimizer = utility.make_optimizer(optim_args, self.dis)
35 |
36 | def forward(self, fake, real):
37 | # updating discriminator...
38 | self.loss = 0
39 | fake_detach = fake.detach() # do not backpropagate through G
40 | for _ in range(self.gan_k):
41 | self.optimizer.zero_grad()
42 | # d: B x 1 tensor
43 | d_fake = self.dis(fake_detach)
44 | d_real = self.dis(real)
45 | retain_graph = False
46 | if self.gan_type in ['GAN', 'SNGAN']:
47 | loss_d = self.bce(d_real, d_fake)
48 | elif self.gan_type.find('WGAN') >= 0:
49 | loss_d = (d_fake - d_real).mean()
50 | if self.gan_type.find('GP') >= 0:
51 | epsilon = torch.rand(fake.size(0), device=fake.device).view(-1, 1, 1, 1)  # one epsilon per sample
52 | hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
53 | hat.requires_grad = True
54 | d_hat = self.dis(hat)
55 | gradients = torch.autograd.grad(
56 | outputs=d_hat.sum(), inputs=hat,
57 | retain_graph=True, create_graph=True, only_inputs=True
58 | )[0]
59 | gradients = gradients.view(gradients.size(0), -1)
60 | gradient_norm = gradients.norm(2, dim=1)
61 | gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
62 | loss_d += gradient_penalty
63 | # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
64 | elif self.gan_type == 'RGAN':
65 | better_real = d_real - d_fake.mean(dim=0, keepdim=True)
66 | better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
67 | loss_d = self.bce(better_real, better_fake)
68 | retain_graph = True
69 |
70 | # Discriminator update
71 | self.loss += loss_d.item()
72 | loss_d.backward(retain_graph=retain_graph)
73 | self.optimizer.step()
74 |
75 | if self.gan_type == 'WGAN':
76 | for p in self.dis.parameters():
77 | p.data.clamp_(-1, 1)
78 |
79 | self.loss /= self.gan_k
80 |
81 | # updating generator...
82 | d_fake_bp = self.dis(fake) # for backpropagation, use fake as it is
83 | if self.gan_type in ['GAN', 'SNGAN']:
84 | label_real = torch.ones_like(d_fake_bp)
85 | loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
86 | elif self.gan_type.find('WGAN') >= 0:
87 | loss_g = -d_fake_bp.mean()
88 | elif self.gan_type == 'RGAN':
89 | better_real = d_real.detach() - d_fake_bp.mean(dim=0, keepdim=True)
90 | better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True).detach()
91 | loss_g = self.bce(better_fake, better_real)
92 |
93 | # Generator loss
94 | return loss_g
95 |
96 | def state_dict(self, *args, **kwargs):
97 | state_discriminator = self.dis.state_dict(*args, **kwargs)
98 | state_optimizer = self.optimizer.state_dict()
99 |
100 | return dict(**state_discriminator, **state_optimizer)
101 |
102 | def bce(self, real, fake):
103 | label_real = torch.ones_like(real)
104 | label_fake = torch.zeros_like(fake)
105 | bce_real = F.binary_cross_entropy_with_logits(real, label_real)
106 | bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
107 | bce_loss = bce_real + bce_fake
108 | return bce_loss
109 |
110 | # Some references
111 | # https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
112 | # OR
113 | # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
114 |
--------------------------------------------------------------------------------
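For reference, the `WGAN`+`GP` branch above implements the gradient-penalty discriminator objective of Gulrajani et al. (arXiv:1704.00028), with the interpolate built as `hat = fake * (1 - epsilon) + real * epsilon` and a penalty weight of 10:

```latex
L_D = \mathbb{E}\big[D(\tilde{x})\big] - \mathbb{E}\big[D(x)\big]
    + \lambda\, \mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big],
\qquad \hat{x} = \epsilon x + (1 - \epsilon)\tilde{x},\quad \epsilon \sim U[0, 1],\quad \lambda = 10.
```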
/loss/discriminator.py:
--------------------------------------------------------------------------------
1 | from model import common
2 |
3 | import torch.nn as nn
4 | from torch.nn.utils import spectral_norm
5 | class Discriminator(nn.Module):
6 | '''
7 | output is not normalized
8 | '''
9 | def __init__(self, args, gan_type='GAN'):
10 | super(Discriminator, self).__init__()
11 |
12 | in_channels = args.n_colors
13 | out_channels = 32
14 | depth = 6
15 |
16 | def _block(_in_channels, _out_channels, stride=1):
17 |
18 | Conv = nn.Conv2d(
19 | _in_channels,
20 | _out_channels,
21 | 3,
22 | padding=1,
23 | stride=stride,
24 | bias=False
25 | )
26 |
27 | if gan_type == 'SNGAN':
28 | return nn.Sequential(
29 | spectral_norm(Conv),
30 | nn.BatchNorm2d(_out_channels),
31 | nn.LeakyReLU(negative_slope=0.2, inplace=True)
32 | )
33 | else:
34 | return nn.Sequential(
35 | Conv,
36 | nn.BatchNorm2d(_out_channels),
37 | nn.LeakyReLU(negative_slope=0.2, inplace=True)
38 | )
39 |
40 | m_features = [_block(in_channels, out_channels)]
41 | for i in range(depth):
42 | in_channels = out_channels
43 | # if i % 2 == 1:
44 | # stride = 1
45 | # out_channels *= 2
46 | # else:
47 | out_channels *= 2
48 | stride = 2
49 | m_features.append(_block(in_channels, out_channels, stride=stride))
50 |
51 | patch_size = args.patch_size // 2**(depth-1)
52 |
53 | # print(out_channels, patch_size)
54 |
55 | m_classifier = [
56 | nn.Flatten(),
57 | nn.Linear(out_channels*patch_size**2, 512),
58 | nn.LeakyReLU(0.2, True),
59 | nn.Linear(512, 1)
60 | ]
61 |
62 | self.features = nn.Sequential(*m_features)
63 | self.classifier = nn.Sequential(*m_classifier)
64 |
65 | def forward(self, x):
66 | features = self.features(x)
67 | # print(features.shape)
68 | output = self.classifier(features)
69 |
70 | return output
71 |
72 |
--------------------------------------------------------------------------------
/loss/filter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class Filter(nn.Module):
6 | def __init__(self, args):
7 | super().__init__()
8 | self.args = args
9 |
10 | kernel = torch.tensor([[1, 4, 1], [4, -20, 4], [1, 4, 1]])
11 | self.conv = nn.Conv2d(args.n_colors, args.n_colors, 3, 3)
12 | with torch.no_grad():
13 | self.conv.weight.copy_(kernel.float())
14 | self.loss = nn.L1Loss()
15 |
16 | def forward(self, x, y):
17 | preds_x = self.conv(x)
18 | preds_y = self.conv(y)
19 |
20 | return self.loss(preds_x, preds_y)
21 |
--------------------------------------------------------------------------------
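The fixed kernel is a scaled discrete Laplacian whose entries sum to zero, so the loss compares only high-frequency (edge) content and ignores flat regions. A small standalone sanity check of that property (not repo code):

```python
import torch
import torch.nn.functional as F

kernel = torch.tensor([[1., 4., 1.], [4., -20., 4.], [1., 4., 1.]]).view(1, 1, 3, 3)
flat = torch.full((1, 1, 8, 8), 0.5)
edge = flat.clone()
edge[..., 4:] = 1.0  # vertical step edge

print(F.conv2d(flat, kernel).abs().max())  # tensor(0.) -- zero response on flat input
print(F.conv2d(edge, kernel).abs().max())  # > 0, response concentrated near the step
```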
/loss/hist_entropy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class HistEntropy(nn.Module):
6 | def __init__(self, args):
7 | super().__init__()
8 | self.args = args
9 |
10 | def forward(self, x):
11 | p = torch.softmax(x, dim=1)
12 | logp = torch.log_softmax(x, dim=1)
13 |
14 | entropy = (-p * logp).sum(dim=(2, 3)).mean()
15 |
16 | return entropy
17 |
--------------------------------------------------------------------------------
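At each spatial location u, the softmax over dim=1 turns the logits into a channel distribution, and the elementwise terms -p log p are then summed over the spatial dimensions and averaged over batch and channel, a quantity proportional to the mean per-location Shannon entropy. For reference:

```latex
p_c(u) = \frac{e^{x_c(u)}}{\sum_{c'} e^{x_{c'}(u)}}, \qquad
H(u) = -\sum_{c} p_c(u)\,\log p_c(u)
```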
/loss/mssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from math import exp
4 | import numpy as np
5 |
6 |
7 | def gaussian(window_size, sigma):
8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9 | return gauss/gauss.sum()
10 |
11 |
12 | def create_window(window_size, channel=1):
13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
15 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
16 | return window
17 |
18 |
19 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
20 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
21 | if val_range is None:
22 | if torch.max(img1) > 128:
23 | max_val = 255
24 | else:
25 | max_val = 1
26 |
27 | if torch.min(img1) < -0.5:
28 | min_val = -1
29 | else:
30 | min_val = 0
31 | L = max_val - min_val
32 | else:
33 | L = val_range
34 |
35 | padd = 0
36 | (_, channel, height, width) = img1.size()
37 | if window is None:
38 | real_size = min(window_size, height, width)
39 | window = create_window(real_size, channel=channel).to(img1.device)
40 |
41 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
42 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
43 |
44 | mu1_sq = mu1.pow(2)
45 | mu2_sq = mu2.pow(2)
46 | mu1_mu2 = mu1 * mu2
47 |
48 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
49 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
50 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
51 |
52 | C1 = (0.01 * L) ** 2
53 | C2 = (0.03 * L) ** 2
54 |
55 | v1 = 2.0 * sigma12 + C2
56 | v2 = sigma1_sq + sigma2_sq + C2
57 | cs = torch.mean(v1 / v2) # contrast sensitivity
58 |
59 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
60 |
61 | if size_average:
62 | ret = ssim_map.mean()
63 | else:
64 | ret = ssim_map.mean(1).mean(1).mean(1)
65 |
66 | if full:
67 | return ret, cs
68 | return ret
69 |
70 |
71 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=None):
72 | device = img1.device
73 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
74 | levels = weights.size()[0]
75 | ssims = []
76 | mcs = []
77 | for _ in range(levels):
78 | sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
79 |
80 | # Relu normalize (not compliant with original definition)
81 | if normalize == "relu":
82 | ssims.append(torch.relu(sim))
83 | mcs.append(torch.relu(cs))
84 | else:
85 | ssims.append(sim)
86 | mcs.append(cs)
87 |
88 | img1 = F.avg_pool2d(img1, (2, 2))
89 | img2 = F.avg_pool2d(img2, (2, 2))
90 |
91 | ssims = torch.stack(ssims)
92 | mcs = torch.stack(mcs)
93 |
94 | # Simple normalize (not compliant with original definition)
95 | # TODO: remove support for normalize == True (kept for backward support)
96 | if normalize == "simple" or normalize == True:
97 | ssims = (ssims + 1) / 2
98 | mcs = (mcs + 1) / 2
99 |
100 | pow1 = mcs ** weights
101 | pow2 = ssims ** weights
102 |
103 | # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
104 | output = torch.prod(pow1[:-1]) * pow2[-1]
105 | return output
106 |
107 |
108 | # Classes to re-use window
109 | class SSIM(torch.nn.Module):
110 | def __init__(self, window_size=11, size_average=True, val_range=None):
111 | super(SSIM, self).__init__()
112 | self.window_size = window_size
113 | self.size_average = size_average
114 | self.val_range = val_range
115 |
116 | # Assume 1 channel for SSIM
117 | self.channel = 1
118 | self.window = create_window(window_size)
119 |
120 | def forward(self, img1, img2):
121 | (_, channel, _, _) = img1.size()
122 |
123 | if channel == self.channel and self.window.dtype == img1.dtype:
124 | window = self.window
125 | else:
126 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
127 | self.window = window
128 | self.channel = channel
129 |
130 | return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
131 |
132 | class MSSSIM(torch.nn.Module):
133 | def __init__(self, window_size=11, size_average=True, channel=3):
134 | super(MSSSIM, self).__init__()
135 | self.window_size = window_size
136 | self.size_average = size_average
137 | self.channel = channel
138 |
139 | def forward(self, img1, img2):
140 | # TODO: store window between calls if possible
141 | return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
--------------------------------------------------------------------------------
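A minimal usage sketch, assuming (B, C, H, W) float inputs in [0, 1] (so `val_range` resolves to 1):

```python
import torch
from loss.mssim import SSIM, ssim

img1 = torch.rand(1, 3, 64, 64)
img2 = (img1 * 0.9).clamp(0, 1)

print(ssim(img1, img1))       # identical images -> ~1.0
criterion = SSIM(window_size=11)
print(criterion(img1, img2))  # < 1.0 for differing images
```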
/loss/vgg.py:
--------------------------------------------------------------------------------
1 | from model import common
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torchvision.models as models
7 |
8 | class VGG(nn.Module):
9 | def __init__(self, conv_index, rgb_range=1):
10 | super(VGG, self).__init__()
11 | vgg_features = models.vgg19(pretrained=True).features
12 | modules = [m for m in vgg_features]
13 | if conv_index.find('22') >= 0:
14 | self.vgg = nn.Sequential(*modules[:8])
15 | elif conv_index.find('54') >= 0:
16 | self.vgg = nn.Sequential(*modules[:35])
17 |
18 | vgg_mean = (0.485, 0.456, 0.406)
19 | vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
20 | self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
21 | for p in self.parameters():
22 | p.requires_grad = False
23 |
24 | def forward(self, sr, hr):
25 | def _forward(x):
26 | # x = self.sub_mean(x)
27 | x = self.vgg(x)
28 | return x
29 |
30 | sr = sr.repeat(1, 3, 1, 1)
31 | hr = hr.repeat(1, 3, 1, 1)
32 |
33 | vgg_sr = _forward(sr)
34 | with torch.no_grad():
35 | vgg_hr = _forward(hr.detach())
36 |
37 | loss = F.mse_loss(vgg_sr, vgg_hr)
38 |
39 | return loss
40 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | from torch.utils.data import DataLoader
5 | from torchvision import transforms as T
6 |
7 | import utility
8 | import model
9 | import loss
10 | from option import args
11 | from trainer import Trainer
12 | from datasets.synthetic_burst_train_set import SyntheticBurst
13 | from datasets.zurich_raw2rgb_dataset import ZurichRAW2RGB
14 | import torch.multiprocessing as mp
15 | import torch.backends.cudnn as cudnn
16 | import torch.utils.data.distributed
17 |
18 | try:
19 | import apex
20 | from apex.parallel import DistributedDataParallel as DDP
21 | from apex.fp16_utils import *
22 | from apex import amp, optimizers
23 | from apex.multi_tensor_apply import multi_tensor_applier
24 | except ImportError:
25 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
26 |
27 |
28 | def init_seeds(seed=0, cuda_deterministic=True):
29 | random.seed(seed)
30 | np.random.seed(seed)
31 | torch.manual_seed(seed)
32 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
33 | if cuda_deterministic: # slower, more reproducible
34 | cudnn.deterministic = True
35 | cudnn.benchmark = False
36 | else: # faster, less reproducible
37 | cudnn.deterministic = False
38 | cudnn.benchmark = True
39 |
40 |
41 | checkpoint = utility.checkpoint(args)
42 |
43 |
44 | def main():
45 | if args.n_GPUs > 1:
46 | mp.spawn(main_worker, nprocs=args.n_GPUs, args=(args.n_GPUs, args))
47 | else:
48 | main_worker(0, args.n_GPUs, args)
49 |
50 |
51 | def main_worker(local_rank, nprocs, args):
52 | if checkpoint.ok:
53 | args.local_rank = local_rank
54 | if nprocs > 1:
55 | init_seeds(local_rank+1)
56 | cudnn.benchmark = True
57 | utility.setup(local_rank, nprocs)
58 | torch.cuda.set_device(args.local_rank)
59 |
60 | batch_size = int(args.batch_size / nprocs)
61 | train_zurich_raw2rgb = ZurichRAW2RGB(root=args.root, split='train')
62 | train_data = SyntheticBurst(train_zurich_raw2rgb, burst_size=args.burst_size, crop_sz=args.patch_size)
63 |
64 | valid_zurich_raw2rgb = ZurichRAW2RGB(root=args.root, split='test')
65 | valid_data = SyntheticBurst(valid_zurich_raw2rgb, burst_size=args.burst_size, crop_sz=384)
66 |
67 | if nprocs > 1:
68 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
69 | valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_data, shuffle=False)
70 | train_loader = DataLoader(dataset=train_data, batch_size=batch_size, num_workers=8,
71 | pin_memory=True, drop_last=True, sampler=train_sampler)
72 | valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, num_workers=4,
73 | pin_memory=True, drop_last=True, sampler=valid_sampler)
74 | else:
75 | train_sampler = None
76 | train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, num_workers=8,
77 | shuffle=True, pin_memory=True, drop_last=True) # args.cpus
78 | valid_loader = DataLoader(dataset=valid_data, batch_size=args.batch_size, num_workers=4, shuffle=False,
79 | pin_memory=True, drop_last=True) # args.cpus
80 |
81 |
82 | _model = model.Model(args, checkpoint)
83 | _loss = loss.Loss(args, checkpoint) if not args.test_only else None
84 | t = Trainer(args, train_loader, train_sampler, valid_loader, _model, _loss, checkpoint)
85 | while not t.terminate():
86 | t.train()
87 |
88 | del _model
89 | del _loss
90 | del train_loader
91 | del valid_loader
92 |
93 | utility.cleanup()
94 |
95 | checkpoint.done()
96 |
97 |
98 | if __name__ == '__main__':
99 | main()
100 |
--------------------------------------------------------------------------------
/model/DCNv2/DCNv2.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 1.0
2 | Name: DCNv2
3 | Version: 0.1
4 | Summary: deformable convolutional networks
5 | Home-page: https://github.com/charlesshang/DCNv2
6 | Author: charlesshang
7 | Author-email: UNKNOWN
8 | License: UNKNOWN
9 | Description: UNKNOWN
10 | Platform: UNKNOWN
11 |
--------------------------------------------------------------------------------
/model/DCNv2/DCNv2.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | README.md
2 | setup.py
3 | /data/work/pylibs/DCNv2-pytorch_1.6/src/vision.cpp
4 | /data/work/pylibs/DCNv2-pytorch_1.6/src/cpu/dcn_v2_cpu.cpp
5 | /data/work/pylibs/DCNv2-pytorch_1.6/src/cpu/dcn_v2_im2col_cpu.cpp
6 | /data/work/pylibs/DCNv2-pytorch_1.6/src/cpu/dcn_v2_psroi_pooling_cpu.cpp
7 | /data/work/pylibs/DCNv2-pytorch_1.6/src/cuda/dcn_v2_cuda.cu
8 | /data/work/pylibs/DCNv2-pytorch_1.6/src/cuda/dcn_v2_im2col_cuda.cu
9 | /data/work/pylibs/DCNv2-pytorch_1.6/src/cuda/dcn_v2_psroi_pooling_cuda.cu
10 | DCNv2.egg-info/PKG-INFO
11 | DCNv2.egg-info/SOURCES.txt
12 | DCNv2.egg-info/dependency_links.txt
13 | DCNv2.egg-info/top_level.txt
--------------------------------------------------------------------------------
/model/DCNv2/DCNv2.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/model/DCNv2/DCNv2.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | _ext
2 |
--------------------------------------------------------------------------------
/model/DCNv2/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2019, Charles Shang
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/model/DCNv2/README.md:
--------------------------------------------------------------------------------
1 | ## Deformable Convolutional Networks V2 with PyTorch 1.0
2 |
3 | ### Build
4 | ```bash
5 | ./make.sh # build
6 | python test.py # run examples and gradient check
7 | ```
8 |
9 | ### An Example
10 | - deformable conv
11 | ```python
12 | from dcn_v2 import DCN
13 | input = torch.randn(2, 64, 128, 128).cuda()
14 | # wrap all things (offset and mask) in DCN
15 | dcn = DCN(64, 64, kernel_size=(3,3), stride=1, padding=1, deformable_groups=2).cuda()
16 | output = dcn(input)
17 | print(output.shape)
18 | ```
19 | - deformable roi pooling
20 | ```python
21 | from dcn_v2 import DCNPooling
22 | input = torch.randn(2, 32, 64, 64).cuda()
23 | batch_inds = torch.randint(2, (20, 1)).cuda().float()
24 | x = torch.randint(256, (20, 1)).cuda().float()
25 | y = torch.randint(256, (20, 1)).cuda().float()
26 | w = torch.randint(64, (20, 1)).cuda().float()
27 | h = torch.randint(64, (20, 1)).cuda().float()
28 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
29 |
30 | # modulated deformable pooling (V2)
31 | # wrap all things (offset and mask) in DCNPooling
32 | dpooling = DCNPooling(spatial_scale=1.0 / 4,
33 | pooled_size=7,
34 | output_dim=32,
35 | no_trans=False,
36 | group_size=1,
37 | trans_std=0.1).cuda()
38 |
39 | dout = dpooling(input, rois)
40 | ```
41 | ### Note
42 | The master branch now targets PyTorch 1.0 (the new ATen API); you can switch back to PyTorch 0.4 with:
43 | ```bash
44 | git checkout pytorch_0.4
45 | ```
46 |
47 | ### Known Issues:
48 |
49 | - [x] Gradient check w.r.t offset (solved)
50 | - [ ] Backward is not reentrant (minor)
51 |
52 | This is an adaptation of the official [Deformable-ConvNets](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op).
53 |
54 | I have run the gradient check many times with DOUBLE type. Every tensor **except offset** passes.
55 | However, when I set the offset to 0.5, it passes. I'm still wondering what causes this problem. Could it be
56 | some non-differentiable points?
57 |
58 | Update: all gradient check passes with double precision.
59 |
60 | Another issue is that it raises `RuntimeError: Backward is not reentrant`. However, the error is very small (`<1e-7` for
61 | float, `<1e-15` for double),
62 | so it may not be a serious problem.
63 |
64 | Please post an issue or PR if you have any comments.
65 |
--------------------------------------------------------------------------------
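Both README snippets above additionally need `import torch`, and both hard-code `.cuda()`. Since the package also compiles the CPU sources under `src/cpu/`, a CPU-only run should be possible as well; the following is a minimal sketch under that assumption, not part of the original README.

```python
# Minimal sketch of the deformable-conv example on CPU (assumes the
# extension was built with the CPU sources; keep .cuda() otherwise).
import torch
from dcn_v2 import DCN

input = torch.randn(2, 64, 128, 128)
dcn = DCN(64, 64, kernel_size=(3, 3), stride=1, padding=1, deformable_groups=2)
output = dcn(input)  # offset and mask are generated inside DCN
print(output.shape)  # torch.Size([2, 64, 128, 128])
```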
/model/DCNv2/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/model/DCNv2/__init__.py
--------------------------------------------------------------------------------
/model/DCNv2/__pycache__/dcn_v2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/model/DCNv2/__pycache__/dcn_v2.cpython-37.pyc
--------------------------------------------------------------------------------
/model/DCNv2/build/lib.linux-x86_64-3.7/_ext.cpython-37m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/model/DCNv2/build/lib.linux-x86_64-3.7/_ext.cpython-37m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cpu/dcn_v2_cpu.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cpu/dcn_v2_cpu.o
--------------------------------------------------------------------------------
/model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cpu/dcn_v2_im2col_cpu.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cpu/dcn_v2_im2col_cpu.o
--------------------------------------------------------------------------------
/model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cpu/dcn_v2_psroi_pooling_cpu.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cpu/dcn_v2_psroi_pooling_cpu.o
--------------------------------------------------------------------------------
/model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cuda/dcn_v2_cuda.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cuda/dcn_v2_cuda.o
--------------------------------------------------------------------------------
/model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cuda/dcn_v2_im2col_cuda.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cuda/dcn_v2_im2col_cuda.o
--------------------------------------------------------------------------------
/model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cuda/dcn_v2_psroi_pooling_cuda.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/cuda/dcn_v2_psroi_pooling_cuda.o
--------------------------------------------------------------------------------
/model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/vision.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/model/DCNv2/build/temp.linux-x86_64-3.7/data/work/pylibs/DCNv2-pytorch_1.6/src/vision.o
--------------------------------------------------------------------------------
/model/DCNv2/dist/DCNv2-0.1-py3.7-linux-x86_64.egg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/model/DCNv2/dist/DCNv2-0.1-py3.7-linux-x86_64.egg
--------------------------------------------------------------------------------
/model/DCNv2/files.txt:
--------------------------------------------------------------------------------
1 | /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/_ext.cpython-37m-x86_64-linux-gnu.so
2 | /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/_ext.py
3 | /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/PKG-INFO
4 | /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/SOURCES.txt
5 | /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/dependency_links.txt
6 | /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/native_libs.txt
7 | /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/not-zip-safe
8 | /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/top_level.txt
9 | /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/__pycache__/_ext.cpython-37.pyc
10 |
--------------------------------------------------------------------------------
/model/DCNv2/make.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python setup.py build develop
3 |
--------------------------------------------------------------------------------
/model/DCNv2/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import glob
4 | import os
5 |
6 | import torch
7 | from setuptools import find_packages, setup
8 | from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
9 |
10 | requirements = ["torch", "torchvision"]
11 |
12 |
13 | def get_extensions():
14 | this_dir = os.path.dirname(os.path.abspath(__file__))
15 | extensions_dir = os.path.join(this_dir, "src")
16 |
17 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
18 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
19 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
20 |
21 | os.environ["CC"] = "g++"
22 | sources = main_file + source_cpu
23 | extension = CppExtension
24 | extra_compile_args = {"cxx": []}
25 | define_macros = []
26 |
27 | if True:  # always build the CUDA extension (no CPU-only fallback in this build)
28 | extension = CUDAExtension
29 | sources += source_cuda
30 | define_macros += [("WITH_CUDA", None)]
31 | extra_compile_args["nvcc"] = [
32 | "-DCUDA_HAS_FP16=1",
33 | "-D__CUDA_NO_HALF_OPERATORS__",
34 | "-D__CUDA_NO_HALF_CONVERSIONS__",
35 | "-D__CUDA_NO_HALF2_OPERATORS__",
36 | ]
37 | else:
38 | # raise NotImplementedError('Cuda is not available')
39 | pass
40 |
41 | sources = [os.path.join(extensions_dir, s) for s in sources]
42 | include_dirs = [extensions_dir]
43 | ext_modules = [
44 | extension(
45 | "_ext",
46 | sources,
47 | include_dirs=include_dirs,
48 | define_macros=define_macros,
49 | extra_compile_args=extra_compile_args,
50 | )
51 | ]
52 | return ext_modules
53 |
54 |
55 | setup(
56 | name="DCNv2",
57 | version="0.1",
58 | author="charlesshang",
59 | url="https://github.com/charlesshang/DCNv2",
60 | description="deformable convolutional networks",
61 | packages=find_packages(exclude=("configs", "tests")),
62 | # install_requires=requirements,
63 | ext_modules=get_extensions(),
64 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
65 | )
66 |
--------------------------------------------------------------------------------
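The `if True:` in `get_extensions` unconditionally selects `CUDAExtension`, and the commented-out `else` branch suggests a conditional guard was removed. For reference, this is a sketch of the conventional backend guard used in torch C++/CUDA extension setups; it is an assumption about what the disabled branch replaced, not code from this repository, which forces the CUDA build.

```python
# Sketch of the usual backend guard for torch extensions (assumption;
# not repository code).
import torch
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension

def pick_extension():
    # Build the CUDA extension only when a toolkit and device are present;
    # otherwise fall back to the CPU-only extension.
    if torch.cuda.is_available() and CUDA_HOME is not None:
        return CUDAExtension
    return CppExtension
```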
/model/DCNv2/src/cpu/dcn_v2_im2col_cpu.h:
--------------------------------------------------------------------------------
1 |
2 | /*!
3 | ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
4 | *
5 | * COPYRIGHT
6 | *
7 | * All contributions by the University of California:
8 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
9 | * All rights reserved.
10 | *
11 | * All other contributions:
12 | * Copyright (c) 2014-2017, the respective contributors
13 | * All rights reserved.
14 | *
15 | * Caffe uses a shared copyright model: each contributor holds copyright over
16 | * their contributions to Caffe. The project versioning records all such
17 | * contribution and copyright details. If a contributor wants to further mark
18 | * their specific copyright on a particular contribution, they should indicate
19 | * their copyright solely in the commit message of the change when it is
20 | * committed.
21 | *
22 | * LICENSE
23 | *
24 | * Redistribution and use in source and binary forms, with or without
25 | * modification, are permitted provided that the following conditions are met:
26 | *
27 | * 1. Redistributions of source code must retain the above copyright notice, this
28 | * list of conditions and the following disclaimer.
29 | * 2. Redistributions in binary form must reproduce the above copyright notice,
30 | * this list of conditions and the following disclaimer in the documentation
31 | * and/or other materials provided with the distribution.
32 | *
33 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
34 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
35 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
36 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
37 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
38 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
39 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
40 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
41 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
42 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
43 | *
44 | * CONTRIBUTION AGREEMENT
45 | *
46 | * By contributing to the BVLC/caffe repository through pull-request, comment,
47 | * or otherwise, the contributor releases their content to the
48 | * license and copyright terms herein.
49 | *
50 | ***************** END Caffe Copyright Notice and Disclaimer ********************
51 | *
52 | * Copyright (c) 2018 Microsoft
53 | * Licensed under The MIT License [see LICENSE for details]
54 | * \file modulated_deformable_im2col.h
55 | * \brief Function definitions of converting an image to
56 | * column matrix based on kernel, padding, dilation, and offset.
57 | * These functions are mainly used in deformable convolution operators.
58 | * \ref: https://arxiv.org/abs/1811.11168
59 | * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu
60 | */
61 |
62 | /***************** Adapted by Charles Shang *********************/
63 | // modified from the CUDA version for CPU use by Daniel K. Suhendro
64 |
65 | #ifndef DCN_V2_IM2COL_CPU
66 | #define DCN_V2_IM2COL_CPU
67 |
68 | #ifdef __cplusplus
69 | extern "C"
70 | {
71 | #endif
72 |
73 | void modulated_deformable_im2col_cpu(const float *data_im, const float *data_offset, const float *data_mask,
74 | const int batch_size, const int channels, const int height_im, const int width_im,
75 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
76 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
77 | const int dilation_h, const int dilation_w,
78 | const int deformable_group, float *data_col);
79 |
80 | void modulated_deformable_col2im_cpu(const float *data_col, const float *data_offset, const float *data_mask,
81 | const int batch_size, const int channels, const int height_im, const int width_im,
82 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
83 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
84 | const int dilation_h, const int dilation_w,
85 | const int deformable_group, float *grad_im);
86 |
87 | void modulated_deformable_col2im_coord_cpu(const float *data_col, const float *data_im, const float *data_offset, const float *data_mask,
88 | const int batch_size, const int channels, const int height_im, const int width_im,
89 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
90 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
91 | const int dilation_h, const int dilation_w,
92 | const int deformable_group,
93 | float *grad_offset, float *grad_mask);
94 |
95 | #ifdef __cplusplus
96 | }
97 | #endif
98 |
99 | #endif
--------------------------------------------------------------------------------
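The header above declares the im2col/col2im kernels that modulated deformable convolution is built on: for every kernel tap, the input is sampled bilinearly at a learned fractional offset and scaled by a learned mask before being written into the column matrix. Below is a conceptual PyTorch sketch of that per-tap sampling; the real implementation loops over batch, channels, taps, and output positions in C++.

```python
# Conceptual sketch of one modulated deformable sampling step, assuming
# the sample position stays inside the feature map.
import torch

def sample_one_tap(feat, y, x, mask):
    """Bilinearly sample feat (C, H, W) at fractional (y, x), scaled by mask."""
    C, H, W = feat.shape
    y0, x0 = int(y), int(x)                      # top-left integer corner
    y1, x1 = min(y0 + 1, H - 1), min(x0 + 1, W - 1)
    ly, lx = y - y0, x - x0                      # fractional parts
    val = ((1 - ly) * (1 - lx) * feat[:, y0, x0] +
           (1 - ly) * lx       * feat[:, y0, x1] +
           ly       * (1 - lx) * feat[:, y1, x0] +
           ly       * lx       * feat[:, y1, x1])
    return mask * val                            # DCNv2 modulation term

feat = torch.arange(16.0).reshape(1, 4, 4)
print(sample_one_tap(feat, 1.5, 2.25, mask=0.5))
```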
/model/DCNv2/src/cpu/vision.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | at::Tensor
5 | dcn_v2_cpu_forward(const at::Tensor &input,
6 | const at::Tensor &weight,
7 | const at::Tensor &bias,
8 | const at::Tensor &offset,
9 | const at::Tensor &mask,
10 | const int kernel_h,
11 | const int kernel_w,
12 | const int stride_h,
13 | const int stride_w,
14 | const int pad_h,
15 | const int pad_w,
16 | const int dilation_h,
17 | const int dilation_w,
18 | const int deformable_group);
19 |
20 | std::vector<at::Tensor>
21 | dcn_v2_cpu_backward(const at::Tensor &input,
22 | const at::Tensor &weight,
23 | const at::Tensor &bias,
24 | const at::Tensor &offset,
25 | const at::Tensor &mask,
26 | const at::Tensor &grad_output,
27 | int kernel_h, int kernel_w,
28 | int stride_h, int stride_w,
29 | int pad_h, int pad_w,
30 | int dilation_h, int dilation_w,
31 | int deformable_group);
32 |
33 |
34 | std::tuple<at::Tensor, at::Tensor>
35 | dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,
36 | const at::Tensor &bbox,
37 | const at::Tensor &trans,
38 | const int no_trans,
39 | const float spatial_scale,
40 | const int output_dim,
41 | const int group_size,
42 | const int pooled_size,
43 | const int part_size,
44 | const int sample_per_part,
45 | const float trans_std);
46 |
47 | std::tuple<at::Tensor, at::Tensor>
48 | dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,
49 | const at::Tensor &input,
50 | const at::Tensor &bbox,
51 | const at::Tensor &trans,
52 | const at::Tensor &top_count,
53 | const int no_trans,
54 | const float spatial_scale,
55 | const int output_dim,
56 | const int group_size,
57 | const int pooled_size,
58 | const int part_size,
59 | const int sample_per_part,
60 | const float trans_std);
--------------------------------------------------------------------------------
/model/DCNv2/src/cuda/dcn_v2_im2col_cuda.h:
--------------------------------------------------------------------------------
1 |
2 | /*!
3 | ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
4 | *
5 | * COPYRIGHT
6 | *
7 | * All contributions by the University of California:
8 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
9 | * All rights reserved.
10 | *
11 | * All other contributions:
12 | * Copyright (c) 2014-2017, the respective contributors
13 | * All rights reserved.
14 | *
15 | * Caffe uses a shared copyright model: each contributor holds copyright over
16 | * their contributions to Caffe. The project versioning records all such
17 | * contribution and copyright details. If a contributor wants to further mark
18 | * their specific copyright on a particular contribution, they should indicate
19 | * their copyright solely in the commit message of the change when it is
20 | * committed.
21 | *
22 | * LICENSE
23 | *
24 | * Redistribution and use in source and binary forms, with or without
25 | * modification, are permitted provided that the following conditions are met:
26 | *
27 | * 1. Redistributions of source code must retain the above copyright notice, this
28 | * list of conditions and the following disclaimer.
29 | * 2. Redistributions in binary form must reproduce the above copyright notice,
30 | * this list of conditions and the following disclaimer in the documentation
31 | * and/or other materials provided with the distribution.
32 | *
33 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
34 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
35 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
36 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
37 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
38 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
39 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
40 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
41 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
42 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
43 | *
44 | * CONTRIBUTION AGREEMENT
45 | *
46 | * By contributing to the BVLC/caffe repository through pull-request, comment,
47 | * or otherwise, the contributor releases their content to the
48 | * license and copyright terms herein.
49 | *
50 | ***************** END Caffe Copyright Notice and Disclaimer ********************
51 | *
52 | * Copyright (c) 2018 Microsoft
53 | * Licensed under The MIT License [see LICENSE for details]
54 | * \file modulated_deformable_im2col.h
55 | * \brief Function definitions of converting an image to
56 | * column matrix based on kernel, padding, dilation, and offset.
57 | * These functions are mainly used in deformable convolution operators.
58 | * \ref: https://arxiv.org/abs/1811.11168
59 | * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu
60 | */
61 |
62 | /***************** Adapted by Charles Shang *********************/
63 |
64 | #ifndef DCN_V2_IM2COL_CUDA
65 | #define DCN_V2_IM2COL_CUDA
66 |
67 | #ifdef __cplusplus
68 | extern "C"
69 | {
70 | #endif
71 |
72 | void modulated_deformable_im2col_cuda(cudaStream_t stream,
73 | const float *data_im, const float *data_offset, const float *data_mask,
74 | const int batch_size, const int channels, const int height_im, const int width_im,
75 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
76 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
77 | const int dilation_h, const int dilation_w,
78 | const int deformable_group, float *data_col);
79 |
80 | void modulated_deformable_col2im_cuda(cudaStream_t stream,
81 | const float *data_col, const float *data_offset, const float *data_mask,
82 | const int batch_size, const int channels, const int height_im, const int width_im,
83 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
84 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
85 | const int dilation_h, const int dilation_w,
86 | const int deformable_group, float *grad_im);
87 |
88 | void modulated_deformable_col2im_coord_cuda(cudaStream_t stream,
89 | const float *data_col, const float *data_im, const float *data_offset, const float *data_mask,
90 | const int batch_size, const int channels, const int height_im, const int width_im,
91 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
92 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
93 | const int dilation_h, const int dilation_w,
94 | const int deformable_group,
95 | float *grad_offset, float *grad_mask);
96 |
97 | #ifdef __cplusplus
98 | }
99 | #endif
100 |
101 | #endif
--------------------------------------------------------------------------------
/model/DCNv2/src/cuda/vision.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 | #include <ATen/ATen.h>
4 | at::Tensor
5 | dcn_v2_cuda_forward(const at::Tensor &input,
6 | const at::Tensor &weight,
7 | const at::Tensor &bias,
8 | const at::Tensor &offset,
9 | const at::Tensor &mask,
10 | const int kernel_h,
11 | const int kernel_w,
12 | const int stride_h,
13 | const int stride_w,
14 | const int pad_h,
15 | const int pad_w,
16 | const int dilation_h,
17 | const int dilation_w,
18 | const int deformable_group);
19 |
20 | std::vector<at::Tensor>
21 | dcn_v2_cuda_backward(const at::Tensor &input,
22 | const at::Tensor &weight,
23 | const at::Tensor &bias,
24 | const at::Tensor &offset,
25 | const at::Tensor &mask,
26 | const at::Tensor &grad_output,
27 | int kernel_h, int kernel_w,
28 | int stride_h, int stride_w,
29 | int pad_h, int pad_w,
30 | int dilation_h, int dilation_w,
31 | int deformable_group);
32 |
33 |
34 | std::tuple<at::Tensor, at::Tensor>
35 | dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input,
36 | const at::Tensor &bbox,
37 | const at::Tensor &trans,
38 | const int no_trans,
39 | const float spatial_scale,
40 | const int output_dim,
41 | const int group_size,
42 | const int pooled_size,
43 | const int part_size,
44 | const int sample_per_part,
45 | const float trans_std);
46 |
47 | std::tuple<at::Tensor, at::Tensor>
48 | dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad,
49 | const at::Tensor &input,
50 | const at::Tensor &bbox,
51 | const at::Tensor &trans,
52 | const at::Tensor &top_count,
53 | const int no_trans,
54 | const float spatial_scale,
55 | const int output_dim,
56 | const int group_size,
57 | const int pooled_size,
58 | const int part_size,
59 | const int sample_per_part,
60 | const float trans_std);
--------------------------------------------------------------------------------
/model/DCNv2/src/dcn_v2.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "cpu/vision.h"
4 |
5 | #ifdef WITH_CUDA
6 | #include "cuda/vision.h"
7 | #endif
8 |
9 | at::Tensor
10 | dcn_v2_forward(const at::Tensor &input,
11 | const at::Tensor &weight,
12 | const at::Tensor &bias,
13 | const at::Tensor &offset,
14 | const at::Tensor &mask,
15 | const int kernel_h,
16 | const int kernel_w,
17 | const int stride_h,
18 | const int stride_w,
19 | const int pad_h,
20 | const int pad_w,
21 | const int dilation_h,
22 | const int dilation_w,
23 | const int deformable_group)
24 | {
25 | if (input.type().is_cuda())
26 | {
27 | #ifdef WITH_CUDA
28 | return dcn_v2_cuda_forward(input, weight, bias, offset, mask,
29 | kernel_h, kernel_w,
30 | stride_h, stride_w,
31 | pad_h, pad_w,
32 | dilation_h, dilation_w,
33 | deformable_group);
34 | #else
35 | AT_ERROR("Not compiled with GPU support");
36 | #endif
37 | }
38 | else{
39 | return dcn_v2_cpu_forward(input, weight, bias, offset, mask,
40 | kernel_h, kernel_w,
41 | stride_h, stride_w,
42 | pad_h, pad_w,
43 | dilation_h, dilation_w,
44 | deformable_group);
45 | }
46 | }
47 |
48 | std::vector<at::Tensor>
49 | dcn_v2_backward(const at::Tensor &input,
50 | const at::Tensor &weight,
51 | const at::Tensor &bias,
52 | const at::Tensor &offset,
53 | const at::Tensor &mask,
54 | const at::Tensor &grad_output,
55 | int kernel_h, int kernel_w,
56 | int stride_h, int stride_w,
57 | int pad_h, int pad_w,
58 | int dilation_h, int dilation_w,
59 | int deformable_group)
60 | {
61 | if (input.type().is_cuda())
62 | {
63 | #ifdef WITH_CUDA
64 | return dcn_v2_cuda_backward(input,
65 | weight,
66 | bias,
67 | offset,
68 | mask,
69 | grad_output,
70 | kernel_h, kernel_w,
71 | stride_h, stride_w,
72 | pad_h, pad_w,
73 | dilation_h, dilation_w,
74 | deformable_group);
75 | #else
76 | AT_ERROR("Not compiled with GPU support");
77 | #endif
78 | }
79 | else{
80 | return dcn_v2_cpu_backward(input,
81 | weight,
82 | bias,
83 | offset,
84 | mask,
85 | grad_output,
86 | kernel_h, kernel_w,
87 | stride_h, stride_w,
88 | pad_h, pad_w,
89 | dilation_h, dilation_w,
90 | deformable_group);
91 | }
92 | }
93 |
94 | std::tuple<at::Tensor, at::Tensor>
95 | dcn_v2_psroi_pooling_forward(const at::Tensor &input,
96 | const at::Tensor &bbox,
97 | const at::Tensor &trans,
98 | const int no_trans,
99 | const float spatial_scale,
100 | const int output_dim,
101 | const int group_size,
102 | const int pooled_size,
103 | const int part_size,
104 | const int sample_per_part,
105 | const float trans_std)
106 | {
107 | if (input.type().is_cuda())
108 | {
109 | #ifdef WITH_CUDA
110 | return dcn_v2_psroi_pooling_cuda_forward(input,
111 | bbox,
112 | trans,
113 | no_trans,
114 | spatial_scale,
115 | output_dim,
116 | group_size,
117 | pooled_size,
118 | part_size,
119 | sample_per_part,
120 | trans_std);
121 | #else
122 | AT_ERROR("Not compiled with GPU support");
123 | #endif
124 | }
125 | else{
126 | return dcn_v2_psroi_pooling_cpu_forward(input,
127 | bbox,
128 | trans,
129 | no_trans,
130 | spatial_scale,
131 | output_dim,
132 | group_size,
133 | pooled_size,
134 | part_size,
135 | sample_per_part,
136 | trans_std);
137 | }
138 | }
139 |
140 | std::tuple<at::Tensor, at::Tensor>
141 | dcn_v2_psroi_pooling_backward(const at::Tensor &out_grad,
142 | const at::Tensor &input,
143 | const at::Tensor &bbox,
144 | const at::Tensor &trans,
145 | const at::Tensor &top_count,
146 | const int no_trans,
147 | const float spatial_scale,
148 | const int output_dim,
149 | const int group_size,
150 | const int pooled_size,
151 | const int part_size,
152 | const int sample_per_part,
153 | const float trans_std)
154 | {
155 | if (input.type().is_cuda())
156 | {
157 | #ifdef WITH_CUDA
158 | return dcn_v2_psroi_pooling_cuda_backward(out_grad,
159 | input,
160 | bbox,
161 | trans,
162 | top_count,
163 | no_trans,
164 | spatial_scale,
165 | output_dim,
166 | group_size,
167 | pooled_size,
168 | part_size,
169 | sample_per_part,
170 | trans_std);
171 | #else
172 | AT_ERROR("Not compiled with GPU support");
173 | #endif
174 | }
175 | else{
176 | return dcn_v2_psroi_pooling_cpu_backward(out_grad,
177 | input,
178 | bbox,
179 | trans,
180 | top_count,
181 | no_trans,
182 | spatial_scale,
183 | output_dim,
184 | group_size,
185 | pooled_size,
186 | part_size,
187 | sample_per_part,
188 | trans_std);
189 | }
190 | }
--------------------------------------------------------------------------------
/model/DCNv2/src/vision.cpp:
--------------------------------------------------------------------------------
1 |
2 | #include "dcn_v2.h"
3 |
4 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
5 | m.def("dcn_v2_forward", &dcn_v2_forward, "dcn_v2_forward");
6 | m.def("dcn_v2_backward", &dcn_v2_backward, "dcn_v2_backward");
7 | m.def("dcn_v2_psroi_pooling_forward", &dcn_v2_psroi_pooling_forward, "dcn_v2_psroi_pooling_forward");
8 | m.def("dcn_v2_psroi_pooling_backward", &dcn_v2_psroi_pooling_backward, "dcn_v2_psroi_pooling_backward");
9 | }
10 |
--------------------------------------------------------------------------------
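`vision.cpp` registers the four dispatch functions from `dcn_v2.h` in a compiled module named `_ext` (the extension name set in `setup.py`). The Python-side wrappers in `dcn_v2.py` call these through `torch.autograd.Function`; below is a hedged sketch of a direct forward call, with the argument order taken from `src/dcn_v2.h` and the tensor sizes purely illustrative assumptions.

```python
# Direct call into the compiled binding (sketch; normally you go through
# the DCN/DCNv2 wrappers in dcn_v2.py, which also handle autograd).
import torch
import _ext  # module name registered by setup.py

N, C_in, C_out, H, W, kH, kW, groups = 1, 4, 4, 8, 8, 3, 3, 1
input  = torch.randn(N, C_in, H, W).cuda()
weight = torch.randn(C_out, C_in, kH, kW).cuda()
bias   = torch.zeros(C_out).cuda()
offset = torch.zeros(N, groups * 2 * kH * kW, H, W).cuda()  # (dy, dx) per tap
mask   = torch.ones(N, groups * 1 * kH * kW, H, W).cuda()   # modulation per tap

# trailing args: kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
#                dilation_h, dilation_w, deformable_group
out = _ext.dcn_v2_forward(input, weight, bias, offset, mask,
                          kH, kW, 1, 1, 1, 1, 1, 1, groups)
print(out.shape)
```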
/model/DCNv2/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from __future__ import absolute_import
3 | from __future__ import print_function
4 | from __future__ import division
5 |
6 | import time
7 | import torch
8 | import torch.nn as nn
9 | from torch.autograd import gradcheck
10 |
11 | from dcn_v2 import dcn_v2_conv, DCNv2, DCN
12 | from dcn_v2 import dcn_v2_pooling, DCNv2Pooling, DCNPooling
13 |
14 | deformable_groups = 1
15 | N, inC, inH, inW = 2, 2, 4, 4
16 | outC = 2
17 | kH, kW = 3, 3
18 |
19 |
20 | def conv_identify(weight, bias):
21 | weight.data.zero_()
22 | bias.data.zero_()
23 | o, i, h, w = weight.shape
24 | y = h//2
25 | x = w//2
26 | for p in range(i):
27 | for q in range(o):
28 | if p == q:
29 | weight.data[q, p, y, x] = 1.0
30 |
31 |
32 | def check_zero_offset():
33 | conv_offset = nn.Conv2d(inC, deformable_groups * 2 * kH * kW,
34 | kernel_size=(kH, kW),
35 | stride=(1, 1),
36 | padding=(1, 1),
37 | bias=True).cuda()
38 |
39 | conv_mask = nn.Conv2d(inC, deformable_groups * 1 * kH * kW,
40 | kernel_size=(kH, kW),
41 | stride=(1, 1),
42 | padding=(1, 1),
43 | bias=True).cuda()
44 |
45 | dcn_v2 = DCNv2(inC, outC, (kH, kW),
46 | stride=1, padding=1, dilation=1,
47 | deformable_groups=deformable_groups).cuda()
48 |
49 | conv_offset.weight.data.zero_()
50 | conv_offset.bias.data.zero_()
51 | conv_mask.weight.data.zero_()
52 | conv_mask.bias.data.zero_()
53 | conv_identify(dcn_v2.weight, dcn_v2.bias)
54 |
55 | input = torch.randn(N, inC, inH, inW).cuda()
56 | offset = conv_offset(input)
57 | mask = conv_mask(input)
58 | mask = torch.sigmoid(mask)
59 | output = dcn_v2(input, offset, mask)
60 | output *= 2  # zeroed conv_mask gives mask = sigmoid(0) = 0.5, so scale back up
61 | d = (input - output).abs().max()
62 | if d < 1e-10:
63 | print('Zero offset passed')
64 | else:
65 | print('Zero offset failed')
66 | print(input)
67 | print(output)
68 |
69 | def check_gradient_dconv():
70 |
71 | input = torch.rand(N, inC, inH, inW).cuda() * 0.01
72 | input.requires_grad = True
73 |
74 | offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW).cuda() * 2
75 | # offset.data.zero_()
76 | # offset.data -= 0.5
77 | offset.requires_grad = True
78 |
79 | mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW).cuda()
80 | # mask.data.zero_()
81 | mask.requires_grad = True
82 | mask = torch.sigmoid(mask)
83 |
84 | weight = torch.randn(outC, inC, kH, kW).cuda()
85 | weight.requires_grad = True
86 |
87 | bias = torch.rand(outC).cuda()
88 | bias.requires_grad = True
89 |
90 | stride = 1
91 | padding = 1
92 | dilation = 1
93 |
94 | print('check_gradient_dconv: ',
95 | gradcheck(dcn_v2_conv, (input, offset, mask, weight, bias,
96 | stride, padding, dilation, deformable_groups),
97 | eps=1e-3, atol=1e-4, rtol=1e-2))
98 |
99 |
100 | def check_pooling_zero_offset():
101 |
102 | input = torch.randn(2, 16, 64, 64).cuda().zero_()
103 | input[0, :, 16:26, 16:26] = 1.
104 | input[1, :, 10:20, 20:30] = 2.
105 | rois = torch.tensor([
106 | [0, 65, 65, 103, 103],
107 | [1, 81, 41, 119, 79],
108 | ]).cuda().float()
109 | pooling = DCNv2Pooling(spatial_scale=1.0 / 4,
110 | pooled_size=7,
111 | output_dim=16,
112 | no_trans=True,
113 | group_size=1,
114 | trans_std=0.0).cuda()
115 |
116 | out = pooling(input, rois, input.new())
117 | s = ', '.join(['%f' % out[i, :, :, :].mean().item()
118 | for i in range(rois.shape[0])])
119 | print(s)
120 |
121 | dpooling = DCNv2Pooling(spatial_scale=1.0 / 4,
122 | pooled_size=7,
123 | output_dim=16,
124 | no_trans=False,
125 | group_size=1,
126 | trans_std=0.0).cuda()
127 | offset = torch.randn(20, 2, 7, 7).cuda().zero_()
128 | dout = dpooling(input, rois, offset)
129 | s = ', '.join(['%f' % dout[i, :, :, :].mean().item()
130 | for i in range(rois.shape[0])])
131 | print(s)
132 |
133 |
134 | def check_gradient_dpooling():
135 | input = torch.randn(2, 3, 5, 5).cuda() * 0.01
136 | N = 4
137 | batch_inds = torch.randint(2, (N, 1)).cuda().float()
138 | x = torch.rand((N, 1)).cuda().float() * 15
139 | y = torch.rand((N, 1)).cuda().float() * 15
140 | w = torch.rand((N, 1)).cuda().float() * 10
141 | h = torch.rand((N, 1)).cuda().float() * 10
142 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
143 | offset = torch.randn(N, 2, 3, 3).cuda()
144 | input.requires_grad = True
145 | offset.requires_grad = True
146 |
147 | spatial_scale = 1.0 / 4
148 | pooled_size = 3
149 | output_dim = 3
150 | no_trans = 0
151 | group_size = 1
152 | trans_std = 0.0
153 | sample_per_part = 4
154 | part_size = pooled_size
155 |
156 | print('check_gradient_dpooling:',
157 | gradcheck(dcn_v2_pooling, (input, rois, offset,
158 | spatial_scale,
159 | pooled_size,
160 | output_dim,
161 | no_trans,
162 | group_size,
163 | part_size,
164 | sample_per_part,
165 | trans_std),
166 | eps=1e-4))
167 |
168 |
169 | def example_dconv():
170 | input = torch.randn(2, 64, 128, 128).cuda()
171 | # wrap all things (offset and mask) in DCN
172 | dcn = DCN(64, 64, kernel_size=(3, 3), stride=1,
173 | padding=1, deformable_groups=2).cuda()
174 | # print(dcn.weight.shape, input.shape)
175 | output = dcn(input)
176 | target = output.new(*output.size())
177 | target.data.uniform_(-0.01, 0.01)
178 | error = (target - output).mean()
179 | error.backward()
180 | print(output.shape)
181 |
182 |
183 | def example_dpooling():
184 | input = torch.randn(2, 32, 64, 64).cuda()
185 | batch_inds = torch.randint(2, (20, 1)).cuda().float()
186 | x = torch.randint(256, (20, 1)).cuda().float()
187 | y = torch.randint(256, (20, 1)).cuda().float()
188 | w = torch.randint(64, (20, 1)).cuda().float()
189 | h = torch.randint(64, (20, 1)).cuda().float()
190 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
191 | offset = torch.randn(20, 2, 7, 7).cuda()
192 | input.requires_grad = True
193 | offset.requires_grad = True
194 |
195 | # normal roi_align
196 | pooling = DCNv2Pooling(spatial_scale=1.0 / 4,
197 | pooled_size=7,
198 | output_dim=32,
199 | no_trans=True,
200 | group_size=1,
201 | trans_std=0.1).cuda()
202 |
203 | # deformable pooling
204 | dpooling = DCNv2Pooling(spatial_scale=1.0 / 4,
205 | pooled_size=7,
206 | output_dim=32,
207 | no_trans=False,
208 | group_size=1,
209 | trans_std=0.1).cuda()
210 |
211 | out = pooling(input, rois, offset)
212 | dout = dpooling(input, rois, offset)
213 | print(out.shape)
214 | print(dout.shape)
215 |
216 | target_out = out.new(*out.size())
217 | target_out.data.uniform_(-0.01, 0.01)
218 | target_dout = dout.new(*dout.size())
219 | target_dout.data.uniform_(-0.01, 0.01)
220 | e = (target_out - out).mean()
221 | e.backward()
222 | e = (target_dout - dout).mean()
223 | e.backward()
224 |
225 |
226 | def example_mdpooling():
227 | input = torch.randn(2, 32, 64, 64).cuda()
228 | input.requires_grad = True
229 | batch_inds = torch.randint(2, (20, 1)).cuda().float()
230 | x = torch.randint(256, (20, 1)).cuda().float()
231 | y = torch.randint(256, (20, 1)).cuda().float()
232 | w = torch.randint(64, (20, 1)).cuda().float()
233 | h = torch.randint(64, (20, 1)).cuda().float()
234 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
235 |
236 | # modulated deformable pooling (V2)
237 | dpooling = DCNPooling(spatial_scale=1.0 / 4,
238 | pooled_size=7,
239 | output_dim=32,
240 | no_trans=False,
241 | group_size=1,
242 | trans_std=0.1,
243 | deform_fc_dim=1024).cuda()
244 |
245 | dout = dpooling(input, rois)
246 | target = dout.new(*dout.size())
247 | target.data.uniform_(-0.1, 0.1)
248 | error = (target - dout).mean()
249 | error.backward()
250 | print(dout.shape)
251 |
252 |
253 | if __name__ == '__main__':
254 |
255 | example_dconv()
256 | example_dpooling()
257 | example_mdpooling()
258 |
259 | check_pooling_zero_offset()
260 | # zero offset check
261 | if inC == outC:
262 | check_zero_offset()
263 |
264 | check_gradient_dpooling()
265 | check_gradient_dconv()
266 | # """
267 | # ****** Note: the "backward is not reentrant" error may not be a serious problem,
268 | # ****** since the max error is less than 1e-7,
269 | # ****** still looking for what triggers this problem
270 | # """
271 |
--------------------------------------------------------------------------------
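`check_gradient_dconv` above runs `gradcheck` in float with loosened tolerances, while the DCNv2 README reports the check passing fully in double precision. A sketch of that double-precision variant follows; note the im2col declarations in `src/` are float-only, so whether double tensors are actually accepted depends on the build. Treat this as a sketch of the check the README describes, not repository code.

```python
# Double-precision gradient check (sketch; see caveat above).
import torch
from torch.autograd import gradcheck
from dcn_v2 import dcn_v2_conv

dg, N, inC, inH, inW, outC, kH, kW = 1, 2, 2, 4, 4, 2, 3, 3

input  = (torch.rand(N, inC, inH, inW).double().cuda() * 0.01).requires_grad_()
offset = (torch.randn(N, dg * 2 * kH * kW, inH, inW).double().cuda() * 2).requires_grad_()
mask   = torch.sigmoid(torch.rand(N, dg * kH * kW, inH, inW).double().cuda()).requires_grad_()
weight = torch.randn(outC, inC, kH, kW).double().cuda().requires_grad_()
bias   = torch.rand(outC).double().cuda().requires_grad_()

# default (tight) tolerances, unlike the float version above
print(gradcheck(dcn_v2_conv, (input, offset, mask, weight, bias, 1, 1, 1, dg)))
```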
/model/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib import import_module
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.parallel as P
7 | import torch.utils.model_zoo
8 |
9 | class Model(nn.Module):
10 | def __init__(self, args, ckp):
11 | super(Model, self).__init__()
12 | self.args = args
13 | if args.local_rank == 0:
14 | print('Making model...')
15 |
16 | self.scale = args.scale
17 | self.idx_scale = 0
18 | self.input_large = (args.model == 'VDSR')
19 | self.self_ensemble = args.self_ensemble
20 | self.chop = args.chop
21 | self.precision = args.precision
22 | self.cpu = args.cpu
23 | self.device = torch.device('cpu' if args.cpu else 'cuda:%d' % args.local_rank)
24 | self.n_GPUs = args.n_GPUs
25 | self.save_models = args.save_models
26 |
27 | module = import_module('model.' + args.model.lower())
28 | self.model = module.make_model(args).to(self.device)
29 |
30 | if args.precision == 'half':
31 | self.model.half()
32 |
33 | self.load(
34 | ckp.get_path('model'),
35 | pre_train=args.pre_train,
36 | resume=args.resume,
37 | cpu=args.cpu
38 | )
39 |
40 | if args.n_GPUs > 1:
41 | self.model = nn.parallel.DistributedDataParallel(self.model,
42 | device_ids=[args.local_rank],
43 | find_unused_parameters=True
44 | )
45 |
46 | print(self.model, file=ckp.log_file)
47 |
48 | def forward(self, x, idx_scale):
49 | self.idx_scale = idx_scale
50 | if hasattr(self.model, 'set_scale'):
51 | self.model.set_scale(idx_scale)
52 |
53 | if self.training:
54 | # if self.n_GPUs > 1:
55 | return self.model(x)
56 | else:
57 | if self.chop:
58 | forward_function = self.forward_chop
59 | else:
60 | forward_function = self.model.forward
61 |
62 | if self.self_ensemble:
63 | return self.forward_x8(x, forward_function=forward_function)
64 | else:
65 | # return self.model(x)
66 | return forward_function(x)
67 |
68 | def save(self, apath, epoch, is_best=False):
69 | save_dirs = [os.path.join(apath, 'model_latest.pt')]
70 |
71 | if is_best:
72 | save_dirs.append(os.path.join(apath, 'model_best.pt'))
73 | if self.save_models:
74 | save_dirs.append(
75 | os.path.join(apath, 'model_{}.pt'.format(epoch))
76 | )
77 | if self.n_GPUs > 1:
78 | model = self.model.module
79 | else:
80 | model = self.model
81 |
82 | for s in save_dirs:
83 | torch.save(model.state_dict(), s)  # save the unwrapped module when using DDP
84 |
85 | def load(self, apath, pre_train='', resume=-1, cpu=False):
86 | load_from = None
87 | kwargs = {}
88 | if cpu:
89 | kwargs = {'map_location': lambda storage, loc: storage}
90 |
91 | if resume == -1:
92 | load_from = torch.load(
93 | os.path.join(apath, 'model_latest.pt'),
94 | **kwargs
95 | )
96 | elif resume == 0:
97 | if pre_train == 'download':
98 | print('Download the model')
99 | dir_model = os.path.join('..', 'models')
100 | os.makedirs(dir_model, exist_ok=True)
101 | load_from = torch.utils.model_zoo.load_url(
102 | self.model.url,
103 | model_dir=dir_model,
104 | **kwargs
105 | )
106 | elif pre_train:
107 | print('Load the model from {}'.format(pre_train))
108 | map_location = {'cuda:%d' % 0: 'cuda:%d' % self.args.local_rank}
109 | load_from = torch.load(pre_train, map_location=map_location)
110 | else:
111 | load_from = torch.load(
112 | os.path.join(apath, 'model_{}.pt'.format(resume)),
113 | **kwargs
114 | )
115 |
116 | if load_from:
117 | self.model.load_state_dict(load_from)
118 | del load_from
119 |
120 | def forward_chop(self, *args, shave=10, min_size=160000):
121 | scale = 1 if self.input_large else self.scale[self.idx_scale]
122 | n_GPUs = min(self.n_GPUs, 4)
123 | # height, width
124 | h, w = args[0].size()[-2:]
125 |
126 | top = slice(0, h//2 + shave)
127 | bottom = slice(h - h//2 - shave, h)
128 | left = slice(0, w//2 + shave)
129 | right = slice(w - w//2 - shave, w)
130 | x_chops = [torch.cat([
131 | a[..., top, left],
132 | a[..., top, right],
133 | a[..., bottom, left],
134 | a[..., bottom, right]
135 | ]) for a in args]
136 |
137 | y_chops = []
138 | if h * w < 4 * min_size:
139 | for i in range(0, 4, n_GPUs):
140 | x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]
141 | y = P.data_parallel(self.model, *x, range(n_GPUs))
142 | if not isinstance(y, list): y = [y]
143 | if not y_chops:
144 | y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]
145 | else:
146 | for y_chop, _y in zip(y_chops, y):
147 | y_chop.extend(_y.chunk(n_GPUs, dim=0))
148 | else:
149 | for p in zip(*x_chops):
150 | y = self.forward_chop(*p, shave=shave, min_size=min_size)
151 | if not isinstance(y, list): y = [y]
152 | if not y_chops:
153 | y_chops = [[_y] for _y in y]
154 | else:
155 | for y_chop, _y in zip(y_chops, y): y_chop.append(_y)
156 |
157 | h *= scale
158 | w *= scale
159 | top = slice(0, h//2)
160 | bottom = slice(h - h//2, h)
161 | bottom_r = slice(h//2 - h, None)
162 | left = slice(0, w//2)
163 | right = slice(w - w//2, w)
164 | right_r = slice(w//2 - w, None)
165 |
166 | # batch size, number of color channels
167 | b, c = y_chops[0][0].size()[:-2]
168 | y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops]
169 | for y_chop, _y in zip(y_chops, y):
170 | _y[..., top, left] = y_chop[0][..., top, left]
171 | _y[..., top, right] = y_chop[1][..., top, right_r]
172 | _y[..., bottom, left] = y_chop[2][..., bottom_r, left]
173 | _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r]
174 |
175 | if len(y) == 1: y = y[0]
176 |
177 | return y
178 |
179 | def forward_x8(self, *args, forward_function=None):
180 | def _transform(v, op):
181 | if self.precision != 'single': v = v.float()
182 |
183 | v2np = v.data.cpu().numpy()
184 | if op == 'v':
185 | tfnp = v2np[:, :, :, ::-1].copy()
186 | elif op == 'h':
187 | tfnp = v2np[:, :, ::-1, :].copy()
188 | elif op == 't':
189 | tfnp = v2np.transpose((0, 1, 3, 2)).copy()
190 |
191 | ret = torch.Tensor(tfnp).to(self.device)
192 | if self.precision == 'half': ret = ret.half()
193 |
194 | return ret
195 |
196 | list_x = []
197 | for a in args:
198 | x = [a]
199 | for tf in 'v', 'h', 't': x.extend([_transform(_x, tf) for _x in x])
200 |
201 | list_x.append(x)
202 |
203 | list_y = []
204 | for x in zip(*list_x):
205 | y = forward_function(*x)
206 | if not isinstance(y, list): y = [y]
207 | if not list_y:
208 | list_y = [[_y] for _y in y]
209 | else:
210 | for _list_y, _y in zip(list_y, y): _list_y.append(_y)
211 |
212 | for _list_y in list_y:
213 | for i in range(len(_list_y)):
214 | if i > 3:
215 | _list_y[i] = _transform(_list_y[i], 't')
216 | if i % 4 > 1:
217 | _list_y[i] = _transform(_list_y[i], 'h')
218 | if (i % 4) % 2 == 1:
219 | _list_y[i] = _transform(_list_y[i], 'v')
220 |
221 | y = [torch.cat(_y, dim=0).mean(dim=0, keepdim=True) for _y in list_y]
222 | if len(y) == 1: y = y[0]
223 |
224 | return y
225 |
--------------------------------------------------------------------------------
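`forward_x8` implements geometric self-ensembling: each input is expanded into eight variants (identity, vertical flip, horizontal flip, transpose, and their compositions), the model runs on all of them, and the outputs are mapped back through the inverse transforms and averaged. Here is a self-contained illustration of just the transform bookkeeping, with an identity model standing in for the network:

```python
# Illustration of the x8 expansion/inversion used by forward_x8.
import torch

def transform(v, op):
    if op == 'v': return v.flip(-1)           # flip along width
    if op == 'h': return v.flip(-2)           # flip along height
    if op == 't': return v.transpose(-2, -1)  # swap height and width

x = torch.randn(1, 3, 4, 4)
variants = [x]
for op in ('v', 'h', 't'):                    # same expansion order as forward_x8
    variants.extend(transform(v, op) for v in list(variants))

outputs = [v.clone() for v in variants]       # identity "model"
for i in range(len(outputs)):                 # invert transforms in reverse order
    if i > 3:        outputs[i] = transform(outputs[i], 't')
    if i % 4 > 1:    outputs[i] = transform(outputs[i], 'h')
    if (i % 4) % 2:  outputs[i] = transform(outputs[i], 'v')

# With an identity model, the ensemble average recovers the input exactly.
assert torch.allclose(torch.cat(outputs).mean(0, keepdim=True), x)
```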
/model/common.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | def default_conv(in_channels, out_channels, kernel_size, bias=True):
10 | return nn.Conv2d(
11 | in_channels, out_channels, kernel_size,
12 | padding=(kernel_size // 2), bias=bias)
13 |
14 |
15 | class MeanShift(nn.Conv2d):
16 | def __init__(
17 | self, rgb_range,
18 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
19 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
20 | std = torch.Tensor(rgb_std)
21 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
22 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
23 | for p in self.parameters():
24 | p.requires_grad = False
25 |
26 |
27 | class BasicBlock(nn.Sequential):
28 | def __init__(
29 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False,
30 | bn=True, act=nn.ReLU(True)):
31 |
32 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
33 | if bn:
34 | m.append(nn.BatchNorm2d(out_channels))
35 | if act is not None:
36 | m.append(act)
37 |
38 | super(BasicBlock, self).__init__(*m)
39 |
40 |
41 | class ResBlock(nn.Module):
42 | def __init__(
43 | self, conv, n_feats, kernel_size,
44 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
45 |
46 | super(ResBlock, self).__init__()
47 | m = []
48 | for i in range(2):
49 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
50 | if bn:
51 | m.append(nn.BatchNorm2d(n_feats))
52 | if i == 0:
53 | m.append(act)
54 |
55 | self.body = nn.Sequential(*m)
56 | self.res_scale = res_scale
57 |
58 | def forward(self, x):
59 | res = self.body(x).mul(self.res_scale)
60 | res += x
61 |
62 | return res
63 |
64 |
65 | class Upsampler(nn.Sequential):
66 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
67 |
68 | m = []
69 | if (scale & (scale - 1)) == 0: # Is scale = 2^n?
70 | for _ in range(int(math.log(scale, 2))):
71 | m.append(conv(n_feats, 4 * n_feats, 3, bias))
72 | m.append(nn.PixelShuffle(2))
73 | if bn:
74 | m.append(nn.BatchNorm2d(n_feats))
75 | if act == 'relu':
76 | m.append(nn.ReLU(True))
77 | elif act == 'prelu':
78 | m.append(nn.PReLU(n_feats))
79 |
80 | elif scale == 3:
81 | m.append(conv(n_feats, 9 * n_feats, 3, bias))
82 | m.append(nn.PixelShuffle(3))
83 | if bn:
84 | m.append(nn.BatchNorm2d(n_feats))
85 | if act == 'relu':
86 | m.append(nn.ReLU(True))
87 | elif act == 'prelu':
88 | m.append(nn.PReLU(n_feats))
89 | else:
90 | raise NotImplementedError
91 |
92 | super(Upsampler, self).__init__(*m)
93 |
94 |
95 | class UpOnly(nn.Sequential):
96 | def __init__(self, scale):
97 |
98 | m = []
99 | if (scale & (scale - 1)) == 0: # Is scale = 2^n?
100 | for _ in range(int(math.log(scale, 2))):
101 | m.append(nn.PixelShuffle(2))
102 |
103 |
104 | elif scale == 3:
105 |
106 | m.append(nn.PixelShuffle(3))
107 |
108 | else:
109 | raise NotImplementedError
110 |
111 | super(UpOnly, self).__init__(*m)
112 |
113 |
114 | def lanczos_kernel(dx, a=3, N=None, dtype=None, device=None):
115 | '''
116 | Generates 1D Lanczos kernels for translation and interpolation.
117 | Args:
118 | dx : float, tensor (batch_size, 1), the translation in pixels to shift an image.
119 | a : int, number of lobes in the kernel support.
120 | If N is None, then the width is the kernel support (length of all lobes),
121 | S = 2(a + ceil(dx)) + 1.
122 | N : int, width of the kernel.
123 | If smaller than S then N is set to S.
124 | Returns:
125 | k: tensor (batch_size, kernel_width), the Lanczos kernel(s)
126 | '''
127 |
128 | if not torch.is_tensor(dx):
129 | dx = torch.tensor(dx, dtype=dtype, device=device)
130 |
131 | if device is None:
132 | device = dx.device
133 |
134 | if dtype is None:
135 | dtype = dx.dtype
136 |
137 | D = dx.abs().ceil().int()
138 | S = 2 * (a + D) + 1 # width of kernel support
139 |
140 | S_max = S.max() if hasattr(S, 'shape') else S
141 |
142 | if (N is None) or (N < S_max):
143 | N = S
144 |
145 | Z = (N - S) // 2 # width of zeros beyond kernel support
146 |
147 | start = (-(a + D + Z)).min()
148 | end = (a + D + Z + 1).max()
149 | x = torch.arange(start, end, dtype=dtype, device=device).view(1, -1) - dx
150 | px = (np.pi * x) + 1e-3  # small offset avoids division by zero at x = 0
151 |
152 | sin_px = torch.sin(px)
153 | sin_pxa = torch.sin(px / a)
154 |
155 | k = a * sin_px * sin_pxa / px ** 2 # sinc(x) masked by sinc(x/a)
156 |
157 | return k
158 |
159 |
160 | def lanczos_shift(img, shift, p=5, a=3):
161 | '''
162 | Shifts an image by convolving it with a Lanczos kernel.
163 | Lanczos interpolation approximates ideal sinc interpolation
164 | by windowing a sinc kernel with another sinc function that extends over a
165 | small number of its lobes (typically a=3).
166 |
167 | Args:
168 | img : tensor (batch_size, channels, height, width), the images to be shifted
169 | shift : tensor (batch_size, 2) of translation parameters (dy, dx)
170 | p : int, padding width prior to convolution (default=5)
171 | a : int, number of lobes in the Lanczos interpolation kernel (default=3)
172 | Returns:
173 | I_s: tensor (batch_size, channels, height, width), shifted images
174 | '''
175 | img = img.transpose(0, 1)
176 | dtype = img.dtype
177 |
178 | if len(img.shape) == 2:
179 | img = img[None, None].repeat(1, shift.shape[0], 1, 1) # batch of one image
180 | elif len(img.shape) == 3: # one image per shift
181 | assert img.shape[0] == shift.shape[0]
182 | img = img[None,]
183 |
184 | # Apply padding
185 |
186 | padder = torch.nn.ReflectionPad2d(p) # reflect pre-padding
187 | I_padded = padder(img)
188 |
189 | # Create 1D shifting kernels
190 |
191 | y_shift = shift[:, [0]]
192 | x_shift = shift[:, [1]]
193 |
194 | k_y = (lanczos_kernel(y_shift, a=a, N=None, dtype=dtype)
195 | .flip(1) # flip axis of convolution
196 | )[:, None, :, None] # expand dims to get shape (batch, channels, y_kernel, 1)
197 | k_x = (lanczos_kernel(x_shift, a=a, N=None, dtype=dtype)
198 | .flip(1)
199 | )[:, None, None, :] # shape (batch, channels, 1, x_kernel)
200 |
201 | # Apply kernels
202 | # print(I_padded.shape, k_y.shape)
203 | I_s = torch.conv1d(I_padded,
204 | groups=k_y.shape[0],
205 | weight=k_y,
206 | padding=[k_y.shape[2] // 2, 0]) # same padding
207 | I_s = torch.conv1d(I_s,
208 | groups=k_x.shape[0],
209 | weight=k_x,
210 | padding=[0, k_x.shape[3] // 2])
211 |
212 | I_s = I_s[..., p:-p, p:-p] # remove padding
213 |
214 | # print(I_s.shape)
215 | return I_s.transpose(0, 1) # , k.squeeze()
216 |
--------------------------------------------------------------------------------
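`lanczos_shift` translates images by fractional pixel amounts via two separable 1D Lanczos convolutions (one per axis), built from the kernels that `lanczos_kernel` generates. A usage sketch follows, assuming the module is importable as `model.common` from the repository root:

```python
# Subpixel shift of a batch of single-channel images (usage sketch).
import torch
from model.common import lanczos_shift  # import path assumed

img = torch.randn(4, 1, 32, 32)              # (batch, channels, H, W)
shift = torch.tensor([[0.5, -0.25]] * 4)     # per-image (dy, dx) in pixels
shifted = lanczos_shift(img, shift, p=5, a=3)
print(shifted.shape)                         # expected: torch.Size([4, 1, 32, 32])
```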
/model/non_local/network.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | # from non_local_concatenation import NONLocalBlock2D
3 | # from non_local_gaussian import NONLocalBlock2D
4 | from non_local_embedded_gaussian import NONLocalBlock2D
5 | # from non_local_dot_product import NONLocalBlock2D
6 |
7 |
8 | class Network(nn.Module):
9 | def __init__(self):
10 | super(Network, self).__init__()
11 |
12 | self.conv_1 = nn.Sequential(
13 | nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
14 | nn.BatchNorm2d(32),
15 | nn.ReLU(),
16 | nn.MaxPool2d(2),
17 | )
18 |
19 | self.nl_1 = NONLocalBlock2D(in_channels=32)
20 | self.conv_2 = nn.Sequential(
21 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
22 | nn.BatchNorm2d(64),
23 | nn.ReLU(),
24 | nn.MaxPool2d(2),
25 | )
26 |
27 | self.nl_2 = NONLocalBlock2D(in_channels=64)
28 | self.conv_3 = nn.Sequential(
29 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
30 | nn.BatchNorm2d(128),
31 | nn.ReLU(),
32 | nn.MaxPool2d(2),
33 | )
34 |
35 | self.fc = nn.Sequential(
36 | nn.Linear(in_features=128*3*3, out_features=256),
37 | nn.ReLU(),
38 | nn.Dropout(0.5),
39 |
40 | nn.Linear(in_features=256, out_features=10)
41 | )
42 |
43 | def forward(self, x):
44 | batch_size = x.size(0)
45 |
46 | feature_1 = self.conv_1(x)
47 | nl_feature_1 = self.nl_1(feature_1)
48 |
49 | feature_2 = self.conv_2(nl_feature_1)
50 | nl_feature_2 = self.nl_2(feature_2)
51 |
52 | output = self.conv_3(nl_feature_2).view(batch_size, -1)
53 | output = self.fc(output)
54 |
55 | return output
56 |
57 | def forward_with_nl_map(self, x):
58 | batch_size = x.size(0)
59 |
60 | feature_1 = self.conv_1(x)
61 | nl_feature_1, nl_map_1 = self.nl_1(feature_1, return_nl_map=True)
62 |
63 | feature_2 = self.conv_2(nl_feature_1)
64 | nl_feature_2, nl_map_2 = self.nl_2(feature_2, return_nl_map=True)
65 |
66 | output = self.conv_3(nl_feature_2).view(batch_size, -1)
67 | output = self.fc(output)
68 |
69 | return output, [nl_map_1, nl_map_2]
70 |
71 |
72 | if __name__ == '__main__':
73 | import torch
74 |
75 | img = torch.randn(3, 1, 28, 28)
76 | net = Network()
77 | out = net(img)
78 | print(out.size())
79 |
80 |
--------------------------------------------------------------------------------
/model/non_local/non_local_concatenation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 |
6 | class _NonLocalBlockND(nn.Module):
7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
8 | super(_NonLocalBlockND, self).__init__()
9 |
10 | assert dimension in [1, 2, 3]
11 |
12 | self.dimension = dimension
13 | self.sub_sample = sub_sample
14 |
15 | self.in_channels = in_channels
16 | self.inter_channels = inter_channels
17 |
18 | if self.inter_channels is None:
19 | self.inter_channels = in_channels // 2
20 | if self.inter_channels == 0:
21 | self.inter_channels = 1
22 |
23 | if dimension == 3:
24 | conv_nd = nn.Conv3d
25 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
26 | bn = nn.BatchNorm3d
27 | elif dimension == 2:
28 | conv_nd = nn.Conv2d
29 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
30 | bn = nn.BatchNorm2d
31 | else:
32 | conv_nd = nn.Conv1d
33 | max_pool_layer = nn.MaxPool1d(kernel_size=(2))
34 | bn = nn.BatchNorm1d
35 |
36 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
37 | kernel_size=1, stride=1, padding=0)
38 |
39 | if bn_layer:
40 | self.W = nn.Sequential(
41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
42 | kernel_size=1, stride=1, padding=0),
43 | bn(self.in_channels)
44 | )
45 | nn.init.constant_(self.W[1].weight, 0)
46 | nn.init.constant_(self.W[1].bias, 0)
47 | else:
48 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
49 | kernel_size=1, stride=1, padding=0)
50 | nn.init.constant_(self.W.weight, 0)
51 | nn.init.constant_(self.W.bias, 0)
52 |
53 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
54 | kernel_size=1, stride=1, padding=0)
55 |
56 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
57 | kernel_size=1, stride=1, padding=0)
58 |
59 | self.concat_project = nn.Sequential(
60 | nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
61 | nn.ReLU()
62 | )
63 |
64 | if sub_sample:
65 | self.g = nn.Sequential(self.g, max_pool_layer)
66 | self.phi = nn.Sequential(self.phi, max_pool_layer)
67 |
68 | def forward(self, x, return_nl_map=False):
69 | '''
70 | :param x: (b, c, t, h, w)
71 | :param return_nl_map: if True return z, nl_map, else only return z.
72 | :return:
73 | '''
74 |
75 | batch_size = x.size(0)
76 |
77 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
78 | g_x = g_x.permute(0, 2, 1)
79 |
80 | # (b, c, N, 1)
81 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)
82 | # (b, c, 1, N)
83 | phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)
84 |
85 | h = theta_x.size(2)
86 | w = phi_x.size(3)
87 | theta_x = theta_x.repeat(1, 1, 1, w)
88 | phi_x = phi_x.repeat(1, 1, h, 1)
89 |
90 | concat_feature = torch.cat([theta_x, phi_x], dim=1)
91 | f = self.concat_project(concat_feature)
92 | b, _, h, w = f.size()
93 | f = f.view(b, h, w)
94 |
95 | N = f.size(-1)
96 | f_div_C = f / N
97 |
98 | y = torch.matmul(f_div_C, g_x)
99 | y = y.permute(0, 2, 1).contiguous()
100 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
101 | W_y = self.W(y)
102 | z = W_y + x
103 |
104 | if return_nl_map:
105 | return z, f_div_C
106 | return z
107 |
108 |
109 | class NONLocalBlock1D(_NonLocalBlockND):
110 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
111 | super(NONLocalBlock1D, self).__init__(in_channels,
112 | inter_channels=inter_channels,
113 | dimension=1, sub_sample=sub_sample,
114 | bn_layer=bn_layer)
115 |
116 |
117 | class NONLocalBlock2D(_NonLocalBlockND):
118 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
119 | super(NONLocalBlock2D, self).__init__(in_channels,
120 | inter_channels=inter_channels,
121 | dimension=2, sub_sample=sub_sample,
122 | bn_layer=bn_layer)
123 |
124 |
125 | class NONLocalBlock3D(_NonLocalBlockND):
126 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True,):
127 | super(NONLocalBlock3D, self).__init__(in_channels,
128 | inter_channels=inter_channels,
129 | dimension=3, sub_sample=sub_sample,
130 | bn_layer=bn_layer)
131 |
132 |
133 | if __name__ == '__main__':
134 | import torch
135 |
136 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
137 | img = torch.zeros(2, 3, 20)
138 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
139 | out = net(img)
140 | print(out.size())
141 |
142 | img = torch.zeros(2, 3, 20, 20)
143 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
144 | out = net(img)
145 | print(out.size())
146 |
147 | img = torch.randn(2, 3, 8, 20, 20)
148 | net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
149 | out = net(img)
150 | print(out.size())
151 |
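Editor's note: the concatenation variant above materializes the full N x N affinity map by broadcasting `theta_x` (b, c, N, 1) and `phi_x` (b, c, 1, N) to a shared (b, 2c, N, N) tensor before the 1x1 `concat_project` convolution, so memory grows quadratically in the number of positions. A minimal sketch of that shape arithmetic (standalone; `expand` is a view-based equivalent of the `repeat` calls used above):

```
import torch

b, c, n = 2, 4, 9                                  # batch, inter_channels, N = H*W positions
theta = torch.randn(b, c, n, 1)                    # queries laid out as a column
phi = torch.randn(b, c, 1, n)                      # keys laid out as a row
pair = torch.cat([theta.expand(b, c, n, n),        # broadcast to (b, c, N, N)
                  phi.expand(b, c, n, n)], dim=1)  # concatenated: (b, 2c, N, N)
print(pair.shape)                                  # torch.Size([2, 8, 9, 9]); a 1x1 conv then maps 2c -> 1
```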
--------------------------------------------------------------------------------
/model/non_local/non_local_cross_dot_product.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 |
6 | class _NonLocalBlockND(nn.Module):
7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
8 | super(_NonLocalBlockND, self).__init__()
9 |
10 | assert dimension in [1, 2, 3]
11 |
12 | self.dimension = dimension
13 | self.sub_sample = sub_sample
14 |
15 | self.in_channels = in_channels
16 | self.inter_channels = inter_channels
17 |
18 | if self.inter_channels is None:
19 | self.inter_channels = in_channels // 2
20 | if self.inter_channels == 0:
21 | self.inter_channels = 1
22 |
23 | if dimension == 3:
24 | conv_nd = nn.Conv3d
25 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 4, 4))
26 | bn = nn.BatchNorm3d
27 | elif dimension == 2:
28 | conv_nd = nn.Conv2d
29 | max_pool_layer = nn.MaxPool2d(kernel_size=(4, 4))
30 | bn = nn.BatchNorm2d
31 | else:
32 | conv_nd = nn.Conv1d
33 | max_pool_layer = nn.MaxPool1d(kernel_size=(4))
34 | bn = nn.BatchNorm1d
35 |
36 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
37 | kernel_size=1, stride=1, padding=0)
38 |
39 | if bn_layer:
40 | self.W = nn.Sequential(
41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
42 | kernel_size=1, stride=1, padding=0),
43 | bn(self.in_channels)
44 | )
45 | nn.init.constant_(self.W[1].weight, 0)
46 | nn.init.constant_(self.W[1].bias, 0)
47 | else:
48 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
49 | kernel_size=1, stride=1, padding=0)
50 | nn.init.constant_(self.W.weight, 0)
51 | nn.init.constant_(self.W.bias, 0)
52 |
53 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
54 | kernel_size=1, stride=1, padding=0)
55 |
56 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
57 | kernel_size=1, stride=1, padding=0)
58 |
59 | if sub_sample:
60 | self.g = nn.Sequential(self.g, max_pool_layer)
61 | self.phi = nn.Sequential(self.phi, max_pool_layer)
62 |
63 | def forward(self, x, ref, return_nl_map=False):
64 | """
65 | :param x: (b, c, t, h, w); ref: reference features with the same shape as x, supplying the queries
66 | :param return_nl_map: if True return z, nl_map, else only return z.
67 | :return:
68 | """
69 |
70 | batch_size = x.size(0)
71 |
72 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
73 | g_x = g_x.permute(0, 2, 1)
74 |
75 | theta_ref = self.theta(ref).view(batch_size, self.inter_channels, -1)
76 | theta_ref = theta_ref.permute(0, 2, 1)
77 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
78 | f = torch.matmul(theta_ref, phi_x)
79 | N = f.size(-1)
80 | f_div_C = f / N
81 |
82 | y = torch.matmul(f_div_C, g_x)
83 | y = y.permute(0, 2, 1).contiguous()
84 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
85 | W_y = self.W(y)
86 | z = W_y + x
87 |
88 | if return_nl_map:
89 | return z, f_div_C
90 | return z
91 |
92 |
93 | class NONLocalBlock1D(_NonLocalBlockND):
94 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
95 | super(NONLocalBlock1D, self).__init__(in_channels,
96 | inter_channels=inter_channels,
97 | dimension=1, sub_sample=sub_sample,
98 | bn_layer=bn_layer)
99 |
100 |
101 | class NONLocalBlock2D(_NonLocalBlockND):
102 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
103 | super(NONLocalBlock2D, self).__init__(in_channels,
104 | inter_channels=inter_channels,
105 | dimension=2, sub_sample=sub_sample,
106 | bn_layer=bn_layer)
107 |
108 |
109 | class NONLocalBlock3D(_NonLocalBlockND):
110 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
111 | super(NONLocalBlock3D, self).__init__(in_channels,
112 | inter_channels=inter_channels,
113 | dimension=3, sub_sample=sub_sample,
114 | bn_layer=bn_layer)
115 |
116 |
117 | if __name__ == '__main__':
118 | import torch
119 |
120 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
121 | img = torch.zeros(2, 3, 20)
122 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
123 | out = net(img, img)  # the cross variant's forward() needs (x, ref); reuse img as its own reference
124 | print(out.size())
125 |
126 | img = torch.zeros(2, 3, 20, 20)
127 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
128 | out = net(img, img)
129 | print(out.size())
130 |
131 | img = torch.randn(2, 3, 8, 20, 20)
132 | net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
133 | out = net(img, img)
134 | print(out.size())
135 |
136 |
137 |
138 |
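Editor's note: unlike its siblings, this block is a cross-attention: `forward(x, ref)` draws the queries (`theta`) from `ref` and the keys/values (`phi`, `g`) from `x`, so it can aggregate burst-frame features toward a reference frame. A minimal usage sketch (assuming the repository root is on `PYTHONPATH`):

```
import torch
from model.non_local.non_local_cross_dot_product import NONLocalBlock2D

x = torch.randn(2, 64, 32, 32)    # features to be aggregated (keys/values)
ref = torch.randn(2, 64, 32, 32)  # reference features supplying the queries
block = NONLocalBlock2D(64, sub_sample=False)
out = block(x, ref)               # residual output, same shape as x
print(out.shape)                  # torch.Size([2, 64, 32, 32])
```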
--------------------------------------------------------------------------------
/model/non_local/non_local_dot_product.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 |
6 | class _NonLocalBlockND(nn.Module):
7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
8 | super(_NonLocalBlockND, self).__init__()
9 |
10 | assert dimension in [1, 2, 3]
11 |
12 | self.dimension = dimension
13 | self.sub_sample = sub_sample
14 |
15 | self.in_channels = in_channels
16 | self.inter_channels = inter_channels
17 |
18 | if self.inter_channels is None:
19 | self.inter_channels = in_channels // 2
20 | if self.inter_channels == 0:
21 | self.inter_channels = 1
22 |
23 | if dimension == 3:
24 | conv_nd = nn.Conv3d
25 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 4, 4))
26 | bn = nn.BatchNorm3d
27 | elif dimension == 2:
28 | conv_nd = nn.Conv2d
29 | max_pool_layer = nn.MaxPool2d(kernel_size=(4, 4))
30 | bn = nn.BatchNorm2d
31 | else:
32 | conv_nd = nn.Conv1d
33 | max_pool_layer = nn.MaxPool1d(kernel_size=(2))
34 | bn = nn.BatchNorm1d
35 |
36 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
37 | kernel_size=1, stride=1, padding=0)
38 |
39 | if bn_layer:
40 | self.W = nn.Sequential(
41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
42 | kernel_size=1, stride=1, padding=0),
43 | bn(self.in_channels)
44 | )
45 | nn.init.constant_(self.W[1].weight, 0)
46 | nn.init.constant_(self.W[1].bias, 0)
47 | else:
48 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
49 | kernel_size=1, stride=1, padding=0)
50 | nn.init.constant_(self.W.weight, 0)
51 | nn.init.constant_(self.W.bias, 0)
52 |
53 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
54 | kernel_size=1, stride=1, padding=0)
55 |
56 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
57 | kernel_size=1, stride=1, padding=0)
58 |
59 | if sub_sample:
60 | self.g = nn.Sequential(self.g, max_pool_layer)
61 | self.phi = nn.Sequential(self.phi, max_pool_layer)
62 |
63 | def forward(self, x, return_nl_map=False):
64 | """
65 | :param x: (b, c, t, h, w)
66 | :param return_nl_map: if True return z, nl_map, else only return z.
67 | :return:
68 | """
69 |
70 | batch_size = x.size(0)
71 |
72 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
73 | g_x = g_x.permute(0, 2, 1)
74 |
75 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
76 | theta_x = theta_x.permute(0, 2, 1)
77 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
78 | f = torch.matmul(theta_x, phi_x)
79 | N = f.size(-1)
80 | f_div_C = f / N
81 |
82 | y = torch.matmul(f_div_C, g_x)
83 | y = y.permute(0, 2, 1).contiguous()
84 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
85 | W_y = self.W(y)
86 | z = W_y + x
87 |
88 | if return_nl_map:
89 | return z, f_div_C
90 | return z
91 |
92 |
93 | class NONLocalBlock1D(_NonLocalBlockND):
94 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
95 | super(NONLocalBlock1D, self).__init__(in_channels,
96 | inter_channels=inter_channels,
97 | dimension=1, sub_sample=sub_sample,
98 | bn_layer=bn_layer)
99 |
100 |
101 | class NONLocalBlock2D(_NonLocalBlockND):
102 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
103 | super(NONLocalBlock2D, self).__init__(in_channels,
104 | inter_channels=inter_channels,
105 | dimension=2, sub_sample=sub_sample,
106 | bn_layer=bn_layer)
107 |
108 |
109 | class NONLocalBlock3D(_NonLocalBlockND):
110 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
111 | super(NONLocalBlock3D, self).__init__(in_channels,
112 | inter_channels=inter_channels,
113 | dimension=3, sub_sample=sub_sample,
114 | bn_layer=bn_layer)
115 |
116 |
117 | if __name__ == '__main__':
118 | import torch
119 |
120 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
121 | img = torch.zeros(2, 3, 20)
122 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
123 | out = net(img)
124 | print(out.size())
125 |
126 | img = torch.zeros(2, 3, 20, 20)
127 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
128 | out = net(img)
129 | print(out.size())
130 |
131 | img = torch.randn(2, 3, 8, 20, 20)
132 | net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
133 | out = net(img)
134 | print(out.size())
135 |
136 |
137 |
138 |
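Editor's note: the only difference between this variant and the embedded-Gaussian one is the normalization of the pairwise scores: division by the number of positions N instead of a row-wise softmax. A tiny contrast of the two weightings:

```
import torch
import torch.nn.functional as F

f = torch.randn(1, 5, 5)       # raw pairwise scores for N = 5 positions
w_dot = f / f.size(-1)         # dot-product variant: uniform 1/N scaling
w_emb = F.softmax(f, dim=-1)   # embedded-Gaussian variant: rows sum to 1
print(w_emb.sum(-1))           # tensor of ones; w_dot rows generally do not sum to 1
```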
--------------------------------------------------------------------------------
/model/non_local/non_local_embedded_gaussian.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 |
6 | class _NonLocalBlockND(nn.Module):
7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
8 | """
9 | :param in_channels: number of channels in the input feature map
10 | :param inter_channels: channels of the embedded features; defaults to in_channels // 2
11 | :param dimension: 1, 2 or 3, selecting Conv1d/2d/3d
12 | :param sub_sample: if True, max-pool the phi and g branches
13 | :param bn_layer: if True, end the output projection W with a zero-initialized BatchNorm
14 | """
15 |
16 | super(_NonLocalBlockND, self).__init__()
17 |
18 | assert dimension in [1, 2, 3]
19 |
20 | self.dimension = dimension
21 | self.sub_sample = sub_sample
22 |
23 | self.in_channels = in_channels
24 | self.inter_channels = inter_channels
25 |
26 | if self.inter_channels is None:
27 | self.inter_channels = in_channels // 2
28 | if self.inter_channels == 0:
29 | self.inter_channels = 1
30 |
31 | if dimension == 3:
32 | conv_nd = nn.Conv3d
33 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
34 | bn = nn.BatchNorm3d
35 | elif dimension == 2:
36 | conv_nd = nn.Conv2d
37 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
38 | bn = nn.BatchNorm2d
39 | else:
40 | conv_nd = nn.Conv1d
41 | max_pool_layer = nn.MaxPool1d(kernel_size=(2))
42 | bn = nn.BatchNorm1d
43 |
44 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
45 | kernel_size=1, stride=1, padding=0)
46 |
47 | if bn_layer:
48 | self.W = nn.Sequential(
49 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
50 | kernel_size=1, stride=1, padding=0),
51 | bn(self.in_channels)
52 | )
53 | nn.init.constant_(self.W[1].weight, 0)
54 | nn.init.constant_(self.W[1].bias, 0)
55 | else:
56 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
57 | kernel_size=1, stride=1, padding=0)
58 | nn.init.constant_(self.W.weight, 0)
59 | nn.init.constant_(self.W.bias, 0)
60 |
61 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
62 | kernel_size=1, stride=1, padding=0)
63 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
64 | kernel_size=1, stride=1, padding=0)
65 |
66 | if sub_sample:
67 | self.g = nn.Sequential(self.g, max_pool_layer)
68 | self.phi = nn.Sequential(self.phi, max_pool_layer)
69 |
70 | def forward(self, x, return_nl_map=False):
71 | """
72 | :param x: (b, c, t, h, w)
73 | :param return_nl_map: if True return z, nl_map, else only return z.
74 | :return:
75 | """
76 |
77 | batch_size = x.size(0)
78 |
79 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
80 | g_x = g_x.permute(0, 2, 1)
81 |
82 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
83 | theta_x = theta_x.permute(0, 2, 1)
84 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
85 | f = torch.matmul(theta_x, phi_x)
86 | f_div_C = F.softmax(f, dim=-1)
87 |
88 | y = torch.matmul(f_div_C, g_x)
89 | y = y.permute(0, 2, 1).contiguous()
90 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
91 | W_y = self.W(y)
92 | z = W_y + x
93 |
94 | if return_nl_map:
95 | return z, f_div_C
96 | return z
97 |
98 |
99 | class NONLocalBlock1D(_NonLocalBlockND):
100 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
101 | super(NONLocalBlock1D, self).__init__(in_channels,
102 | inter_channels=inter_channels,
103 | dimension=1, sub_sample=sub_sample,
104 | bn_layer=bn_layer)
105 |
106 |
107 | class NONLocalBlock2D(_NonLocalBlockND):
108 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
109 | super(NONLocalBlock2D, self).__init__(in_channels,
110 | inter_channels=inter_channels,
111 | dimension=2, sub_sample=sub_sample,
112 | bn_layer=bn_layer,)
113 |
114 |
115 | class NONLocalBlock3D(_NonLocalBlockND):
116 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
117 | super(NONLocalBlock3D, self).__init__(in_channels,
118 | inter_channels=inter_channels,
119 | dimension=3, sub_sample=sub_sample,
120 | bn_layer=bn_layer,)
121 |
122 |
123 | if __name__ == '__main__':
124 | import torch
125 |
126 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
127 | img = torch.zeros(2, 3, 20)
128 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
129 | out = net(img)
130 | print(out.size())
131 |
132 | img = torch.zeros(2, 3, 20, 20)
133 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
134 | out = net(img)
135 | print(out.size())
136 |
137 | img = torch.randn(2, 3, 8, 20, 20)
138 | net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
139 | out = net(img)
140 | print(out.size())
141 |
142 |
143 |
--------------------------------------------------------------------------------
/model/non_local/non_local_gaussian.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 |
6 | class _NonLocalBlockND(nn.Module):
7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
8 | super(_NonLocalBlockND, self).__init__()
9 |
10 | assert dimension in [1, 2, 3]
11 |
12 | self.dimension = dimension
13 | self.sub_sample = sub_sample
14 |
15 | self.in_channels = in_channels
16 | self.inter_channels = inter_channels
17 |
18 | if self.inter_channels is None:
19 | self.inter_channels = in_channels // 2
20 | if self.inter_channels == 0:
21 | self.inter_channels = 1
22 |
23 | if dimension == 3:
24 | conv_nd = nn.Conv3d
25 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
26 | bn = nn.BatchNorm3d
27 | elif dimension == 2:
28 | conv_nd = nn.Conv2d
29 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
30 | bn = nn.BatchNorm2d
31 | else:
32 | conv_nd = nn.Conv1d
33 | max_pool_layer = nn.MaxPool1d(kernel_size=(2))
34 | bn = nn.BatchNorm1d
35 |
36 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
37 | kernel_size=1, stride=1, padding=0)
38 |
39 | if bn_layer:
40 | self.W = nn.Sequential(
41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
42 | kernel_size=1, stride=1, padding=0),
43 | bn(self.in_channels)
44 | )
45 | nn.init.constant_(self.W[1].weight, 0)
46 | nn.init.constant_(self.W[1].bias, 0)
47 | else:
48 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
49 | kernel_size=1, stride=1, padding=0)
50 | nn.init.constant_(self.W.weight, 0)
51 | nn.init.constant_(self.W.bias, 0)
52 |
53 | if sub_sample:
54 | self.g = nn.Sequential(self.g, max_pool_layer)
55 | self.phi = max_pool_layer
56 |
57 | def forward(self, x, return_nl_map=False):
58 | """
59 | :param x: (b, c, t, h, w)
60 | :param return_nl_map: if True return z, nl_map, else only return z.
61 | :return:
62 | """
63 |
64 | batch_size = x.size(0)
65 |
66 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
67 |
68 | g_x = g_x.permute(0, 2, 1)
69 |
70 | theta_x = x.view(batch_size, self.in_channels, -1)
71 | theta_x = theta_x.permute(0, 2, 1)
72 |
73 | if self.sub_sample:
74 | phi_x = self.phi(x).view(batch_size, self.in_channels, -1)
75 | else:
76 | phi_x = x.view(batch_size, self.in_channels, -1)
77 |
78 | f = torch.matmul(theta_x, phi_x)
79 | f_div_C = F.softmax(f, dim=-1)
80 |
81 | # if self.store_last_batch_nl_map:
82 | # self.nl_map = f_div_C
83 |
84 | y = torch.matmul(f_div_C, g_x)
85 | y = y.permute(0, 2, 1).contiguous()
86 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
87 | W_y = self.W(y)
88 | z = W_y + x
89 |
90 | if return_nl_map:
91 | return z, f_div_C
92 | return z
93 |
94 |
95 | class NONLocalBlock1D(_NonLocalBlockND):
96 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
97 | super(NONLocalBlock1D, self).__init__(in_channels,
98 | inter_channels=inter_channels,
99 | dimension=1, sub_sample=sub_sample,
100 | bn_layer=bn_layer)
101 |
102 |
103 | class NONLocalBlock2D(_NonLocalBlockND):
104 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
105 | super(NONLocalBlock2D, self).__init__(in_channels,
106 | inter_channels=inter_channels,
107 | dimension=2, sub_sample=sub_sample,
108 | bn_layer=bn_layer)
109 |
110 |
111 | class NONLocalBlock3D(_NonLocalBlockND):
112 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
113 | super(NONLocalBlock3D, self).__init__(in_channels,
114 | inter_channels=inter_channels,
115 | dimension=3, sub_sample=sub_sample,
116 | bn_layer=bn_layer)
117 |
118 |
119 | if __name__ == '__main__':
120 | import torch
121 |
122 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
123 | img = torch.zeros(2, 3, 20)
124 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
125 | out = net(img)
126 | print(out.size())
127 |
128 | img = torch.zeros(2, 3, 20, 20)
129 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
130 | out = net(img)
131 | print(out.size())
132 |
133 | img = torch.randn(2, 3, 8, 20, 20)
134 | net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
135 | out = net(img)
136 | print(out.size())
137 |
138 |
139 |
140 |
141 |
142 |
143 |
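Editor's note: this Gaussian variant computes the softmax directly over raw feature correlations (there are no learned `theta`/`phi` embeddings; `phi` only exists as a pooling layer when `sub_sample=True`). All four single-input variants in this folder are shape-preserving residual blocks, which the following smoke test illustrates (assuming the repository root is on `PYTHONPATH`):

```
import importlib

import torch

# Module names taken from this folder; the cross variant is excluded
# because its forward() takes two inputs.
names = ['non_local_gaussian', 'non_local_embedded_gaussian',
         'non_local_dot_product', 'non_local_concatenation']
x = torch.randn(2, 64, 20, 20)
for name in names:
    mod = importlib.import_module('model.non_local.' + name)
    block = mod.NONLocalBlock2D(64, sub_sample=True, bn_layer=True)
    print(name, block(x).shape)  # every variant returns torch.Size([2, 64, 20, 20])
```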
--------------------------------------------------------------------------------
/model/utils/interp_methods.py:
--------------------------------------------------------------------------------
1 | from math import pi
2 |
3 | try:
4 | import torch
5 | except ImportError:
6 | torch = None
7 |
8 | try:
9 | import numpy
10 | except ImportError:
11 | numpy = None
12 |
13 | if numpy is None and torch is None:
14 | raise ImportError("Must have either NumPy or PyTorch, but neither was found")
15 |
16 |
17 | def set_framework_dependencies(x):
18 | if numpy is not None and type(x) is numpy.ndarray:  # guard torch-only installs
19 | to_dtype = lambda a: a
20 | fw = numpy
21 | else:
22 | to_dtype = lambda a: a.to(x.dtype)
23 | fw = torch
24 | eps = fw.finfo(fw.float32).eps
25 | return fw, to_dtype, eps
26 |
27 |
28 | def support_sz(sz):
29 | def wrapper(f):
30 | f.support_sz = sz
31 | return f
32 | return wrapper
33 |
34 | @support_sz(4)
35 | def cubic(x):
36 | fw, to_dtype, eps = set_framework_dependencies(x)
37 | absx = fw.abs(x)
38 | absx2 = absx ** 2
39 | absx3 = absx ** 3
40 | return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) +
41 | (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) *
42 | to_dtype((1. < absx) & (absx <= 2.)))
43 |
44 | @support_sz(4)
45 | def lanczos2(x):
46 | fw, to_dtype, eps = set_framework_dependencies(x)
47 | return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) /
48 | ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2))
49 |
50 | @support_sz(6)
51 | def lanczos3(x):
52 | fw, to_dtype, eps = set_framework_dependencies(x)
53 | return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) /
54 | ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3))
55 |
56 | @support_sz(2)
57 | def linear(x):
58 | fw, to_dtype, eps = set_framework_dependencies(x)
59 | return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) + (1 - x) *
60 | to_dtype((0 <= x) & (x <= 1)))
61 |
62 | @support_sz(1)
63 | def box(x):
64 | fw, to_dtype, eps = set_framework_dependencies(x)
65 | return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1))
66 |
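Editor's note: each kernel is tagged with its support width via the `support_sz` decorator, which `resize_right.py` uses to size the sampling window. A quick numerical check (assuming the repository root is on `PYTHONPATH`):

```
import numpy

from model.utils.interp_methods import cubic, linear

xs = numpy.linspace(-2.0, 2.0, 5)  # [-2, -1, 0, 1, 2]
print(cubic(xs))                   # 1.0 at x = 0, tapering to 0.0 at |x| = 2
print(linear(xs))                  # the triangle (hat) kernel on [-1, 1]
print(cubic.support_sz)            # 4: window width recorded by the decorator
```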
--------------------------------------------------------------------------------
/model/utils/psconv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class PyConv2d(nn.Module):
5 | """PyConv2d with padding (general case). Applies a 2D PyConv over an input signal composed of several input planes.
6 | Args:
7 | in_channels (int): Number of channels in the input image
8 | out_channels (list): Number of channels for each pyramid level produced by the convolution
9 | pyconv_kernels (list): Spatial size of the kernel for each pyramid level
10 | pyconv_groups (list): Number of blocked connections from input channels to output channels for each pyramid level
11 | stride (int or tuple, optional): Stride of the convolution. Default: 1
12 | dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
13 | bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``False``
14 | Example::
15 | >>> # PyConv with two pyramid levels, kernels: 3x3, 5x5
16 | >>> m = PyConv2d(in_channels=64, out_channels=[32, 32], pyconv_kernels=[3, 5], pyconv_groups=[1, 4])
17 | >>> input = torch.randn(4, 64, 56, 56)
18 | >>> output = m(input)
19 | >>> # PyConv with three pyramid levels, kernels: 3x3, 5x5, 7x7
20 | >>> m = PyConv2d(in_channels=64, out_channels=[16, 16, 32], pyconv_kernels=[3, 5, 7], pyconv_groups=[1, 4, 8])
21 | >>> input = torch.randn(4, 64, 56, 56)
22 | >>> output = m(input)
23 | """
24 | def __init__(self, in_channels, out_channels, pyconv_kernels, pyconv_groups, stride=1, dilation=1, bias=False):
25 | super(PyConv2d, self).__init__()
26 |
27 | assert len(out_channels) == len(pyconv_kernels) == len(pyconv_groups)
28 |
29 | self.pyconv_levels = [None] * len(pyconv_kernels)
30 | for i in range(len(pyconv_kernels)):
31 | self.pyconv_levels[i] = nn.Conv2d(in_channels, out_channels[i], kernel_size=pyconv_kernels[i],
32 | stride=stride, padding=pyconv_kernels[i] // 2, groups=pyconv_groups[i],
33 | dilation=dilation, bias=bias)
34 | self.pyconv_levels = nn.ModuleList(self.pyconv_levels)
35 |
36 | def forward(self, x):
37 | out = []
38 | for level in self.pyconv_levels:
39 | out.append(level(x))
40 |
41 | return torch.cat(out, 1)
42 |
43 | ################################################################
44 |
45 | class PSConv2d(nn.Module):
46 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, parts=4, bias=False):
47 | super(PSConv2d, self).__init__()
48 | self.gwconv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, dilation, dilation, groups=parts, bias=bias)
49 | self.gwconv_shift = nn.Conv2d(in_channels, out_channels, kernel_size, stride, 2 * dilation, 2 * dilation, groups=parts, bias=bias)
50 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
51 |
52 | def backward_hook(grad):
53 | out = grad.clone()
54 | out[self.mask] = 0
55 | return out
56 |
57 | self.mask = torch.zeros(self.conv.weight.shape).bool().cuda()  # bool mask; byte-tensor indexing is deprecated
58 | _in_channels = in_channels // parts
59 | _out_channels = out_channels // parts
60 | for i in range(parts):
61 | self.mask[i * _out_channels: (i + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, : , :] = 1
62 | self.mask[(i + parts//2)%parts * _out_channels: ((i + parts//2)%parts + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, :, :] = 1
63 | self.conv.weight.data[self.mask] = 0
64 | self.conv.weight.register_hook(backward_hook)
65 |
66 | self.weight = self.conv.weight
67 | self.bias = self.conv.bias
68 |
69 | def forward(self, x):
70 | x1, x2 = x.chunk(2, dim=1)
71 | x_shift = self.gwconv_shift(torch.cat((x2, x1), dim=1))
72 | return self.gwconv(x) + self.conv(x) + x_shift
73 |
74 |
75 | # PSConv-based Group Convolution
76 | class PSGConv2d(nn.Module):
77 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, parts=4, bias=False):
78 | super(PSGConv2d, self).__init__()
79 | self.gwconv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups=groups * parts, bias=bias)
80 | self.gwconv_shift = nn.Conv2d(in_channels, out_channels, kernel_size, stride, 2 * padding, 2 * dilation, groups=groups * parts, bias=bias)
81 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=bias)
82 |
83 | def backward_hook(grad):
84 | out = grad.clone()
85 | out[self.mask] = 0
86 | return out
87 |
88 | self.mask = torch.zeros(self.conv.weight.shape).bool().cuda()
89 | _in_channels = in_channels // (groups * parts)
90 | _out_channels = out_channels // (groups * parts)
91 | for i in range(parts):
92 | for j in range(groups):
93 | self.mask[(i + j * groups) * _out_channels: (i + j * groups + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, : , :] = 1
94 | self.mask[((i + parts // 2) % parts + j * groups) * _out_channels: ((i + parts // 2) % parts + j * groups + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, :, :] = 1
95 | self.conv.weight.data[self.mask] = 0
96 | self.conv.weight.register_hook(backward_hook)
97 | self.groups = groups
98 |
99 | self.weight = self.conv.weight
100 | self.bias = self.conv.bias
101 |
102 | def forward(self, x):
103 | x_split = (z.chunk(2, dim=1) for z in x.chunk(self.groups, dim=1))
104 | x_merge = torch.cat(tuple(torch.cat((x2, x1), dim=1) for (x1, x2) in x_split), dim=1)
105 | x_shift = self.gwconv_shift(x_merge)
106 | gx = self.gwconv(x)
107 | cx = self.conv(x)
108 | # print(x.shape, gx.shape, cx.shape, x_merge.shape, x_shift.shape)
109 | return gx + cx + x_shift
110 |
111 |
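Editor's note: a minimal usage sketch for `PyConv2d` (CPU-safe; the per-level outputs are concatenated, so the `out_channels` entries must sum to the desired total width). `PSConv2d`/`PSGConv2d` build their gradient masks with `.cuda()` at construction time and therefore require a GPU:

```
import torch
from model.utils.psconv import PyConv2d  # assumes repo root on PYTHONPATH

# Two pyramid levels: a 3x3 conv and a grouped 5x5 conv, 32 channels each.
m = PyConv2d(in_channels=64, out_channels=[32, 32],
             pyconv_kernels=[3, 5], pyconv_groups=[1, 4])
y = m(torch.randn(1, 64, 56, 56))
print(y.shape)  # torch.Size([1, 64, 56, 56]): padding k//2 keeps H and W
```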
--------------------------------------------------------------------------------
/option.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser(description='EBSR')
4 |
5 | parser.add_argument('--n_resblocks', type=int, default=16,
6 | help='number of residual blocks')
7 | parser.add_argument('--n_feats', type=int, default=64,
8 | help='number of feature maps')
9 | parser.add_argument('--n_colors', type=int, default=3,
10 | help='number of color channels to use')
11 | parser.add_argument('--lr', type=float, default=1e-4,
12 | help='learning rate')
13 | parser.add_argument('--burst_size', type=int, default=14,
14 | help='burst size, max 14')
15 | parser.add_argument('--burst_channel', type=int, default=1,
16 | help='number of channels per burst frame')
17 | parser.add_argument('--sift_lr', action='store_true',
18 | help='use sift to pre-align burst frames')
19 | parser.add_argument('--lrcn', action='store_true',
20 | help='use long-range concatenating network')
21 |
22 | # Hardware specifications
23 | parser.add_argument('--n_threads', type=int, default=6,
24 | help='number of threads for data loading')
25 | parser.add_argument('--cpu', action='store_true',
26 | help='use cpu only')
27 | parser.add_argument('--n_GPUs', type=int, default=2,
28 | help='number of GPUs')
29 | parser.add_argument('--seed', type=int, default=1,
30 | help='random seed')
31 | parser.add_argument('--local_rank', type=int, default=-1,
32 | help='proc index')
33 | parser.add_argument('--fp16', action='store_true',
34 | help='use fp16 precision')
35 | parser.add_argument('--load_head', action='store_true',
36 | help='load head from other model')
37 | parser.add_argument('--load_sr', action='store_true',
38 | help='load sr module from other model')
39 | parser.add_argument('--finetune_head', action='store_true',
40 | help='finetune the head module')
41 | parser.add_argument('--finetune_large', action='store_true',
42 | help='finetune the large model variant')
43 | parser.add_argument('--finetune_large_skip', action='store_true',
44 | help='finetune the large model variant, skipping mismatched weights')
45 | parser.add_argument('--finetune_pcd', action='store_true',
46 | help='finetune the PCD alignment module')
47 | parser.add_argument('--use_tree', action='store_true',
48 | help='use the tree-structured fusion variant')
49 |
50 | # Data specifications
51 | parser.add_argument('--root', type=str, default='/data/dataset/ntire21/burstsr/synthetic',
52 | help='dataset directory')
53 | parser.add_argument('--mode', type=str, default='train',
54 | help='run mode, e.g. train')
55 | parser.add_argument('--scale', type=str, default='4',
56 | help='super resolution scale')
57 | parser.add_argument('--patch_size', type=int, default=256,
58 | help='output patch size')
59 | parser.add_argument('--rgb_range', type=int, default=1,
60 | help='maximum value of RGB')
61 |
62 | parser.add_argument('--chop', action='store_true',
63 | help='enable memory-efficient forward')
64 | parser.add_argument('--no_augment', action='store_true',
65 | help='do not use data augmentation')
66 |
67 | # Model specifications
68 | parser.add_argument('--model', default='LRSC_EDVR',
69 | help='model name')
70 |
71 | parser.add_argument('--act', type=str, default='relu',
72 | help='activation function')
73 | parser.add_argument('--pre_train', type=str, default='',
74 | help='pre-trained model directory')
75 | parser.add_argument('--extend', type=str, default='.',
76 | help='pre-trained model directory')
77 |
78 | parser.add_argument('--res_scale', type=float, default=1,
79 | help='residual scaling')
80 | parser.add_argument('--shift_mean', default=True,
81 | help='subtract pixel mean from the input')
82 | parser.add_argument('--dilation', action='store_true',
83 | help='use dilated convolution')
84 | parser.add_argument('--precision', type=str, default='single',
85 | choices=('single', 'half'),
86 | help='FP precision for test (single | half)')
87 |
88 |
89 | # Option for Residual channel attention network (RCAN)
90 | parser.add_argument('--n_resgroups', type=int, default=20,
91 | help='number of residual groups')
92 | parser.add_argument('--reduction', type=int, default=16,
93 | help='number of feature maps reduction')
94 | parser.add_argument('--DA', action='store_true',
95 | help='use Dual Attention')
96 | parser.add_argument('--CA', action='store_true',
97 | help='use Channel Attention')
98 | parser.add_argument('--non_local', action='store_true',
99 | help='use Non-Local attention')
100 |
101 | # Training specifications
102 | parser.add_argument('--reset', action='store_true',
103 | help='reset the training')
104 | parser.add_argument('--test_every', type=int, default=1000,
105 | help='do test per every N batches')
106 | parser.add_argument('--epochs', type=int, default=602,
107 | help='number of epochs to train')
108 | parser.add_argument('--batch_size', type=int, default=8,
109 | help='input batch size for training')
110 | parser.add_argument('--split_batch', type=int, default=1,
111 | help='split the batch into smaller chunks')
112 | parser.add_argument('--self_ensemble', action='store_true',
113 | help='use self-ensemble method for test')
114 | parser.add_argument('--test_only', action='store_true',
115 | help='set this option to test the model')
116 | parser.add_argument('--gan_k', type=int, default=1,
117 | help='k value for adversarial loss')
118 |
119 | # Optimization specifications
120 |
121 | parser.add_argument('--decay', type=str, default='150-250',
122 | help='learning rate decay milestones, e.g. 150-250')
123 | parser.add_argument('--gamma', type=float, default=0.5,
124 | help='learning rate decay factor for step decay')
125 | parser.add_argument('--optimizer', default='ADAM',
126 | choices=('SGD', 'ADAM', 'RMSprop'),
127 | help='optimizer to use (SGD | ADAM | RMSprop)')
128 | parser.add_argument('--momentum', type=float, default=0.9,
129 | help='SGD momentum')
130 | parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
131 | help='ADAM beta')
132 | parser.add_argument('--epsilon', type=float, default=1e-8,
133 | help='ADAM epsilon for numerical stability')
134 | parser.add_argument('--weight_decay', type=float, default=0,
135 | help='weight decay')
136 | parser.add_argument('--gclip', type=float, default=0,
137 | help='gradient clipping threshold (0 = no clipping)')
138 |
139 | # Loss specifications
140 | parser.add_argument('--loss', type=str, default='1*L1',
141 | help='loss function configuration')
142 | parser.add_argument('--skip_threshold', type=float, default=1e8,
143 | help='skip a batch whose error exceeds this threshold')
144 |
145 | # Log specifications
146 | parser.add_argument('--save', type=str, default='test',
147 | help='file name to save')
148 | parser.add_argument('--load', type=str, default='',
149 | help='file name to load')
150 | parser.add_argument('--resume', type=int, default=0,
151 | help='resume from specific checkpoint')
152 | parser.add_argument('--save_models', action='store_true',
153 | help='save all intermediate models')
154 | parser.add_argument('--print_every', type=int, default=1,
155 | help='how many batches to wait before logging training status')
156 | parser.add_argument('--save_results', action='store_true',
157 | help='save output results')
158 | parser.add_argument('--save_gt', action='store_true',
159 | help='save low-resolution and high-resolution images together')
160 |
161 | args = parser.parse_args()
162 |
163 | args.scale = list(map(int, args.scale.split('+')))
164 |
165 | if args.epochs == 0:
166 | args.epochs = 1e8
167 |
168 | for arg in vars(args):
169 | if vars(args)[arg] == 'True':
170 | vars(args)[arg] = True
171 | elif vars(args)[arg] == 'False':
172 | vars(args)[arg] = False
173 |
174 |
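Editor's note: `--scale` is read as a '+'-separated string and converted to a list of ints after parsing, so multi-scale runs can be requested from a single flag. The same transform in isolation:

```
scale = '4+2'                            # as passed via --scale 4+2
print(list(map(int, scale.split('+'))))  # [4, 2]; the default '4' yields [4]
```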
--------------------------------------------------------------------------------
/pwcnet/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-pwc
2 | This is a personal reimplementation of PWC-Net [1] using PyTorch. Should you be making use of this work, please cite the paper accordingly. Also, make sure to adhere to the licensing terms of the authors. Should you be making use of this particular implementation, please acknowledge it appropriately [2].
3 |
4 |
5 |
6 | For the original version of this work, please see: https://github.com/NVlabs/PWC-Net
7 |
8 | Another optical flow implementation from me: https://github.com/sniklaus/pytorch-liteflownet
9 |
10 | And another optical flow implementation from me: https://github.com/sniklaus/pytorch-unflow
11 |
12 | Yet another optical flow implementation from me: https://github.com/sniklaus/pytorch-spynet
13 |
14 | ## background
15 | The authors of PWC-Net are thankfully already providing a reference implementation in PyTorch. However, its initial version did not reach the performance of the original Caffe version. This is why I created this repository, in which I replicated the performance of the official Caffe version by utilizing its weights.
16 |
17 | The official PyTorch implementation has since adopted my approach of using the Caffe weights, which is why they now all perform equally well. Many people have reported CUDA issues when trying to get the official PyTorch version to run, while my reimplementation does not seem to be subject to such problems.
18 |
19 | ## setup
20 | To download the pre-trained models, run `bash download.bash`. These originate from the original authors; I just converted them to PyTorch.
21 |
22 | The correlation layer is implemented in CUDA using CuPy, which is why CuPy is a required dependency. It can be installed using `pip install cupy` or alternatively using one of the provided binary packages as outlined in the CuPy repository.
23 |
24 | ## usage
25 | To run it on your own pair of images, use the following command. You can choose between two models; please see the paper and the code for more details.
26 |
27 | ```
28 | python run.py --model default --first ./images/first.png --second ./images/second.png --out ./out.flo
29 | ```
30 |
31 | I am afraid that I cannot guarantee that this reimplementation is correct. However, it produced results identical to the Caffe implementation of the original authors in the examples that I tried. Please feel free to contribute to this repository by submitting issues and pull requests.
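The resulting `out.flo` file uses the Middlebury flow format: a float32 magic number (202021.25), an int32 width and height, then interleaved per-pixel (u, v) float32 displacements. A minimal reader sketch:

```
import numpy as np

def read_flo(path):
    with open(path, 'rb') as f:
        magic = np.fromfile(f, np.float32, count=1)[0]
        assert magic == 202021.25, 'invalid .flo file'
        w = int(np.fromfile(f, np.int32, count=1)[0])
        h = int(np.fromfile(f, np.int32, count=1)[0])
        return np.fromfile(f, np.float32, count=2 * w * h).reshape(h, w, 2)

flow = read_flo('./out.flo')
print(flow.shape)  # (H, W, 2): horizontal and vertical displacement per pixel
```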
32 |
33 | ## comparison
34 | 
35 |
36 | ## license
37 | As stated in the licensing terms of the authors of the paper, the models are free for non-commercial share-alike purpose. Please make sure to further consult their licensing terms.
38 |
39 | ## references
40 | ```
41 | [1] @inproceedings{Sun_CVPR_2018,
42 | author = {Deqing Sun and Xiaodong Yang and Ming-Yu Liu and Jan Kautz},
43 | title = {{PWC-Net}: {CNNs} for Optical Flow Using Pyramid, Warping, and Cost Volume},
44 | booktitle = {IEEE Conference on Computer Vision and Pattern Recognition},
45 | year = {2018}
46 | }
47 | ```
48 |
49 | ```
50 | [2] @misc{pytorch-pwc,
51 | author = {Simon Niklaus},
52 | title = {A Reimplementation of {PWC-Net} Using {PyTorch}},
53 | year = {2018},
54 | howpublished = {\url{https://github.com/sniklaus/pytorch-pwc}}
55 | }
56 | ```
--------------------------------------------------------------------------------
/pwcnet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/pwcnet/__init__.py
--------------------------------------------------------------------------------
/pwcnet/comparison/comparison.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/pwcnet/comparison/comparison.gif
--------------------------------------------------------------------------------
/pwcnet/comparison/comparison.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import math
4 | import moviepy
5 | import moviepy.editor
6 | import numpy
7 | import PIL
8 | import PIL.Image
9 | import PIL.ImageFont
10 | import PIL.ImageDraw
11 |
12 | intX = 32
13 | intY = 436 - 64
14 |
15 | objImages = [ {
16 | 'strFile': 'official - caffe.png',
17 | 'strText': 'official - Caffe'
18 | }, {
19 | 'strFile': 'this - pytorch.png',
20 | 'strText': 'this - PyTorch'
21 | } ]
22 |
23 | npyImages = []
24 |
25 | for objImage in objImages:
26 | objOutput = PIL.Image.open(objImage['strFile']).convert('RGB')
27 |
28 | for intU in [ intShift - 10 for intShift in range(20) ]:
29 | for intV in [ intShift - 10 for intShift in range(20) ]:
30 | if math.sqrt(math.pow(intU, 2.0) + math.pow(intV, 2.0)) <= 5.0:
31 | PIL.ImageDraw.Draw(objOutput).text((intX + intU, intY + intV), objImage['strText'], (255, 255, 255), PIL.ImageFont.truetype('freefont/FreeSerifBold.ttf', 32))
32 | # end
33 | # end
34 | # end
35 |
36 | PIL.ImageDraw.Draw(objOutput).text((intX, intY), objImage['strText'], (0, 0, 0), PIL.ImageFont.truetype('freefont/FreeSerifBold.ttf', 32))
37 |
38 | npyImages.append(numpy.array(objOutput))
39 | # end
40 |
41 | moviepy.editor.ImageSequenceClip(sequence=npyImages, fps=1).write_gif(filename='comparison.gif', program='ImageMagick', opt='optimizeplus')
--------------------------------------------------------------------------------
/pwcnet/comparison/official - caffe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/pwcnet/comparison/official - caffe.png
--------------------------------------------------------------------------------
/pwcnet/comparison/this - pytorch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/pwcnet/comparison/this - pytorch.png
--------------------------------------------------------------------------------
/pwcnet/correlation/README.md:
--------------------------------------------------------------------------------
1 | This is an adaptation of the FlowNet2 implementation in order to compute cost volumes. Should you be making use of this work, please make sure to adhere to the licensing terms of the original authors. Should you be making use of or modifying this particular implementation, please acknowledge it appropriately.
--------------------------------------------------------------------------------
/pwcnet/correlation/__pycache__/correlation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/pwcnet/correlation/__pycache__/correlation.cpython-37.pyc
--------------------------------------------------------------------------------
/pwcnet/download.bash:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | wget --verbose --continue --timestamping http://content.sniklaus.com/github/pytorch-pwc/network-chairs-things.pytorch
4 | wget --verbose --continue --timestamping http://content.sniklaus.com/github/pytorch-pwc/network-default.pytorch
--------------------------------------------------------------------------------
/pwcnet/images/README.md:
--------------------------------------------------------------------------------
1 | The used example originates from the MPI Sintel dataset: http://sintel.is.tue.mpg.de/
--------------------------------------------------------------------------------
/pwcnet/images/first.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/pwcnet/images/first.png
--------------------------------------------------------------------------------
/pwcnet/images/second.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/pwcnet/images/second.png
--------------------------------------------------------------------------------
/pwcnet/out.flo:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/pwcnet/out.flo
--------------------------------------------------------------------------------
/pwcnet/requirements.txt:
--------------------------------------------------------------------------------
1 | cupy>=5.0.0
2 | numpy>=1.15.0
3 | Pillow>=5.0.0
4 | torch>=1.3.0
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib
2 | imageio
3 | opencv-python
4 | tensorboardX
5 |
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/cal_mean_std.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from tqdm import tqdm
4 |
5 | from datasets.burstsr_dataset import BurstSRDataset, flatten_raw_image
6 | from datasets.synthetic_burst_train_set import SyntheticBurst
7 | from datasets.zurich_raw2rgb_dataset import ZurichRAW2RGB
8 |
9 | def main():
10 | train_zurich_raw2rgb = ZurichRAW2RGB(root='/data/dataset/ntire21/burstsr/synthetic', split='train')
11 | train_data = SyntheticBurst(train_zurich_raw2rgb, burst_size=14, crop_sz=384)
12 | means = []
13 | stds = []
14 |
15 | for data in tqdm(train_data):
16 | print(data.shape)
17 | break
18 |
19 |
20 | if __name__ == '__main__':
21 | # if not args.cpu: torch.cuda.set_device(0)
22 | main()
23 |
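Editor's note: as written, this script is a stub: `means` and `stds` are never filled, and the loop prints one sample's shape and breaks. A hedged completion sketch, under the assumption that each `train_data` sample is a tuple whose first element is the burst tensor of shape (burst_size, channels, H, W):

```
import torch
from tqdm import tqdm

means, stds = [], []
for data in tqdm(train_data):  # train_data as constructed in main() above
    burst = data[0] if isinstance(data, (tuple, list)) else data
    means.append(burst.mean(dim=(0, 2, 3)))  # per-channel statistics
    stds.append(burst.std(dim=(0, 2, 3)))
print(torch.stack(means).mean(0), torch.stack(stds).mean(0))
```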
--------------------------------------------------------------------------------
/scripts/demo.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | rlaunch --cpu=4 --gpu=1 --memory=10240 -- python ./scripts/evaluate_burstsr_val.py
3 |
--------------------------------------------------------------------------------
/scripts/download_burstsr_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import urllib.request
3 | import zipfile
4 | import shutil
5 | import argparse
6 |
7 |
8 | def download_burstsr_dataset(download_path):
9 | out_dir = download_path + '/burstsr_dataset'
10 |
11 | # Download train folders
12 | for i in range(9):
13 | if not os.path.isfile('{}/train_{:02d}.zip'.format(out_dir, i)):
14 | print('Downloading train_{:02d}'.format(i))
15 |
16 | urllib.request.urlretrieve('https://data.vision.ee.ethz.ch/bhatg/BurstSRChallenge/train_{:02d}.zip'.format(i),
17 | '{}/tmp.zip'.format(out_dir))
18 |
19 | os.rename('{}/tmp.zip'.format(out_dir), '{}/train_{:02d}.zip'.format(out_dir, i))
20 |
21 | # Download val folder
22 | if not os.path.isfile('{}/val.zip'.format(out_dir)):
23 | print('Downloading val')
24 |
25 | urllib.request.urlretrieve('https://data.vision.ee.ethz.ch/bhatg/BurstSRChallenge/val.zip',
26 | '{}/tmp.zip'.format(out_dir))
27 |
28 | os.rename('{}/tmp.zip'.format(out_dir), '{}/val.zip'.format(out_dir))
29 |
30 | # Unpack train set
31 | for i in range(9):
32 | print('Unpacking train_{:02d}'.format(i))
33 | with zipfile.ZipFile('{}/train_{:02d}.zip'.format(out_dir, i), 'r') as zip_ref:
34 | zip_ref.extractall('{}'.format(out_dir))
35 |
36 | # Move files to a common directory
37 | os.makedirs('{}/train'.format(out_dir), exist_ok=True)
38 |
39 | for i in range(9):
40 | file_list = os.listdir('{}/train_{:02d}'.format(out_dir, i))
41 |
42 | for b in file_list:
43 | source_dir = '{}/train_{:02d}/{}'.format(out_dir, i, b)
44 | dst_dir = '{}/train/{}'.format(out_dir, b)
45 |
46 | if os.path.isdir(source_dir):
47 | shutil.move(source_dir, dst_dir)
48 |
49 | # Delete individual subsets
50 | for i in range(9):
51 | shutil.rmtree('{}/train_{:02d}'.format(out_dir, i))
52 |
53 | # Unpack val set
54 | print('Unpacking val')
55 | with zipfile.ZipFile('{}/val.zip'.format(out_dir), 'r') as zip_ref:
56 | zip_ref.extractall('{}'.format(out_dir))
57 |
58 |
59 | def main():
60 | parser = argparse.ArgumentParser(description='Downloads and unpacks BurstSR dataset')
61 | parser.add_argument('path', type=str, help='Path where the dataset will be downloaded')
62 |
63 | args = parser.parse_args()
64 |
65 | download_burstsr_dataset(args.path)
66 |
67 |
68 | if __name__ == '__main__':
69 | main()
70 |
71 |
72 |
--------------------------------------------------------------------------------
/scripts/evaluate.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | rlaunch --cpu=4 --gpu=1 --memory=10240 -- python scripts/evaluate_burstsr_val.py
3 |
--------------------------------------------------------------------------------
/scripts/evaluate_burstsr_val.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from datasets.burstsr_dataset import BurstSRDataset
3 | from utils.metrics import AlignedPSNR
4 | from pwcnet.pwcnet import PWCNet
5 |
6 | root = '/data/dataset/ntire21/burstsr/real/NTIRE/burstsr_dataset'
7 |
8 | class SimpleBaseline:
9 | def __init__(self):
10 | pass
11 |
12 | def __call__(self, burst):
13 | burst_rgb = burst[:, 0, [0, 1, 3]]  # take R, G1, B from the RGGB base frame
14 | burst_rgb = burst_rgb.view(-1, *burst_rgb.shape[-3:])
15 | burst_rgb = F.interpolate(burst_rgb, scale_factor=8, mode='bilinear')
16 | return burst_rgb
17 |
18 |
19 | def main():
20 | # Load dataset
21 | dataset = BurstSRDataset(root=root,
22 | split='val', burst_size=14, crop_sz=80, random_flip=False)
23 |
24 | # TODO Set your network here
25 | net = SimpleBaseline()
26 |
27 | device = 'cuda'
28 |
29 | # Load alignment network, used in AlignedPSNR
30 | alignment_net = PWCNet(load_pretrained=True,
31 | weights_path='PATH_TO_PWCNET_WEIGHTS')
32 | alignment_net = alignment_net.to(device)
33 | aligned_psnr_fn = AlignedPSNR(alignment_net=alignment_net, boundary_ignore=40)
34 |
35 | scores_all = []
36 | for idx in range(len(dataset)):
37 | burst, frame_gt, meta_info_burst, meta_info_gt = dataset[idx]
38 | burst = burst.unsqueeze(0).to(device)
39 | frame_gt = frame_gt.unsqueeze(0).to(device)
40 |
41 | net_pred = net(burst)
42 |
43 | # Calculate Aligned PSNR
44 | score = aligned_psnr_fn(net_pred, frame_gt, burst)
45 |
46 | scores_all.append(score)
47 |
48 | mean_psnr = sum(scores_all) / len(scores_all)
49 |
50 | print('Mean PSNR is {:0.3f}'.format(mean_psnr.item()))
51 |
52 |
53 | if __name__ == '__main__':
54 | main()
55 |
--------------------------------------------------------------------------------
/scripts/save_results_synburst_val.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import cv2
3 | from datasets.synthetic_burst_val_set import SyntheticBurstVal
4 | import torch
5 | import numpy as np
6 | import os
7 |
8 |
9 | class SimpleBaseline:
10 | def __init__(self):
11 | pass
12 |
13 | def __call__(self, burst):
14 | burst_rgb = burst[:, 0, [0, 1, 3]]  # take R, G1, B from the RGGB base frame
15 | burst_rgb = burst_rgb.view(-1, *burst_rgb.shape[-3:])
16 | burst_rgb = F.interpolate(burst_rgb, scale_factor=8, mode='bilinear')
17 | return burst_rgb
18 |
19 |
20 | def main():
21 | dataset = SyntheticBurstVal('PATH_TO_SyntheticBurstVal')
22 | out_dir = 'PATH_WHERE_RESULTS_ARE_SAVED'
23 |
24 | # TODO Set your network here
25 | net = SimpleBaseline()
26 |
27 | device = 'cuda'
28 | os.makedirs(out_dir, exist_ok=True)
29 |
30 | for idx in range(len(dataset)):
31 | burst, burst_name = dataset[idx]
32 |
33 | burst = burst.to(device).unsqueeze(0)
34 |
35 | with torch.no_grad():
36 | net_pred = net(burst)
37 |
38 | # Scale to the 14-bit range [0, 2^14] and convert to a uint16 numpy array
39 | net_pred_np = (net_pred.squeeze(0).permute(1, 2, 0).clamp(0.0, 1.0) * 2 ** 14).cpu().numpy().astype(np.uint16)
40 |
41 | # Save predictions as png
42 | cv2.imwrite('{}/{}.png'.format(out_dir, burst_name), net_pred_np)
43 |
44 |
45 | if __name__ == '__main__':
46 | main()
47 |
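Editor's note: predictions are written as 16-bit PNGs holding 14-bit values, so reading them back requires `cv2.IMREAD_UNCHANGED` and undoing the scaling. A round-trip sketch (the file name is illustrative):

```
import cv2

pred_uint16 = cv2.imread('PATH_WHERE_RESULTS_ARE_SAVED/0000.png', cv2.IMREAD_UNCHANGED)
pred = pred_uint16.astype('float32') / 2 ** 14  # back to the [0, 1] range
```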
--------------------------------------------------------------------------------
/scripts/test_burstsr_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import cv2
3 | from datasets.burstsr_dataset import BurstSRDataset
4 | from torch.utils.data.dataloader import DataLoader
5 | from utils.metrics import AlignedPSNR
6 | from utils.postprocessing_functions import BurstSRPostProcess
7 | from utils.data_format_utils import convert_dict
8 | from pwcnet.pwcnet import PWCNet
9 |
10 |
11 | def main():
12 | # Load dataset
13 | dataset = BurstSRDataset(root='PATH_TO_BURST_SR',
14 | split='val', burst_size=3, crop_sz=56, random_flip=False)
15 |
16 | data_loader = DataLoader(dataset, batch_size=2)
17 |
18 | # Load alignment network, used in AlignedPSNR
19 | alignment_net = PWCNet(load_pretrained=True,
20 | weights_path='PATH_TO_PWCNET_WEIGHTS')
21 | alignment_net = alignment_net.to('cuda')
22 |
23 | aligned_psnr_fn = AlignedPSNR(alignment_net=alignment_net, boundary_ignore=40)
24 |
25 | # Postprocessing function to obtain sRGB images
26 | postprocess_fn = BurstSRPostProcess(return_np=True)
27 |
28 | for d in data_loader:
29 | burst, frame_gt, meta_info_burst, meta_info_gt = d
30 |
31 | # A simple baseline which upsamples the base image using bilinear upsampling
32 | burst_rgb = burst[:, 0, [0, 1, 3]]  # take R, G1, B from the RGGB base frame
33 | burst_rgb = burst_rgb.view(-1, *burst_rgb.shape[-3:])
34 | burst_rgb = F.interpolate(burst_rgb, scale_factor=8, mode='bilinear')
35 |
36 | # Calculate Aligned PSNR
37 | score = aligned_psnr_fn(burst_rgb.cuda(), frame_gt.cuda(), burst.cuda())
38 | print('PSNR is {:0.3f}'.format(score))
39 |
40 | meta_info_gt = convert_dict(meta_info_gt, burst.shape[0])
41 |
42 | # Apply simple post-processing to obtain RGB images
43 | pred_0 = postprocess_fn.process(burst_rgb[0], meta_info_gt[0])
44 | gt_0 = postprocess_fn.process(frame_gt[0], meta_info_gt[0])
45 |
46 | pred_0 = cv2.cvtColor(pred_0, cv2.COLOR_RGB2BGR)
47 | gt_0 = cv2.cvtColor(gt_0, cv2.COLOR_RGB2BGR)
48 |
49 | # Visualize input, ground truth
50 | cv2.imshow('Input (Demosaiced + Upsampled)', pred_0)
51 | cv2.imshow('GT', gt_0)
52 |
53 | input_key = cv2.waitKey(0)
54 | if input_key == ord('q'):
55 | return
56 |
57 |
58 | if __name__ == '__main__':
59 | main()
60 |
--------------------------------------------------------------------------------
/scripts/test_synthetic_bursts.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import cv2
3 | from datasets.synthetic_burst_train_set import SyntheticBurst
4 | from torch.utils.data.dataloader import DataLoader
5 | from utils.metrics import PSNR
6 | from utils.postprocessing_functions import SimplePostProcess
7 | from utils.data_format_utils import convert_dict
8 | from datasets.zurich_raw2rgb_dataset import ZurichRAW2RGB
9 |
10 |
11 | def main():
12 | zurich_raw2rgb = ZurichRAW2RGB(root='PATH_TO_ZURICH_RAW_TO_RGB', split='test')
13 | dataset = SyntheticBurst(zurich_raw2rgb, burst_size=3, crop_sz=256)
14 |
15 | data_loader = DataLoader(dataset, batch_size=2)
16 |
17 | # Function to calculate PSNR. Note that the boundary pixels (40 pixels) will be ignored during PSNR computation
18 | psnr_fn = PSNR(boundary_ignore=40)
19 |
20 | # Postprocessing function to obtain sRGB images
21 | postprocess_fn = SimplePostProcess(return_np=True)
22 |
23 | for d in data_loader:
24 | burst, frame_gt, flow_vectors, meta_info = d
25 |
26 | # A simple baseline which upsamples the base image using bilinear upsampling
27 | burst_rgb = burst[:, 0, [0, 1, 3]]  # take R, G1, B from the RGGB base frame
28 | burst_rgb = burst_rgb.view(-1, *burst_rgb.shape[-3:])
29 | burst_rgb = F.interpolate(burst_rgb, scale_factor=8, mode='bilinear')
30 |
31 | # Calculate PSNR
32 | score = psnr_fn(burst_rgb, frame_gt)
33 |
34 | print('PSNR is {:0.3f}'.format(score))
35 |
36 | meta_info = convert_dict(meta_info, burst.shape[0])
37 |
38 | # Apply simple post-processing to obtain RGB images
39 | pred_0 = postprocess_fn.process(burst_rgb[0], meta_info[0])
40 | gt_0 = postprocess_fn.process(frame_gt[0], meta_info[0])
41 |
42 | pred_0 = cv2.cvtColor(pred_0, cv2.COLOR_RGB2BGR)
43 | gt_0 = cv2.cvtColor(gt_0, cv2.COLOR_RGB2BGR)
44 |
45 | # Visualize input, ground truth
46 | cv2.imshow('Input (Demosaiced + Upsampled)', pred_0)
47 | cv2.imshow('GT', gt_0)
48 |
49 | input_key = cv2.waitKey(0)
50 | if input_key == ord('q'):
51 | return
52 |
53 |
54 | if __name__ == '__main__':
55 | main()
56 |
--------------------------------------------------------------------------------
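Aside on the baseline used in both test scripts above: `burst[:, 0, [0, 1, 3]]` assumes the packed RAW frames carry four channels in RGGB order, so indices 0, 1 and 3 pick R, the first green and B. A minimal sketch (not part of the repository) of a variant that averages both green samples instead:

    import torch

    burst = torch.rand(2, 3, 4, 48, 48)        # (batch, frames, RGGB, H, W), synthetic data
    base = burst[:, 0]                         # base frame of each burst
    burst_rgb = torch.stack([
        base[:, 0],                            # R
        0.5 * (base[:, 1] + base[:, 2]),       # average of G1 and G2
        base[:, 3],                            # B
    ], dim=1)                                  # (batch, 3, H, W)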
/test.py:
--------------------------------------------------------------------------------
1 |
2 | import cv2
3 | import torch
4 | import numpy as np
5 | import os
6 | from tqdm import tqdm
7 | import random
8 | import utility
9 | from option import args
10 |
11 | from datasets.synthetic_burst_val_set import SyntheticBurstVal
12 | from datasets.burstsr_dataset import flatten_raw_image_batch
13 | import model
14 |
15 | import torch.multiprocessing as mp
16 | import torch.backends.cudnn as cudnn
17 | import torch.utils.data.distributed
18 | import time
19 |
20 |
21 | checkpoint = utility.checkpoint(args)
22 |
23 | def sample_images(burst_size=14):
24 | _burst_size = 14
25 |
26 | ids = random.sample(range(1, _burst_size), k=burst_size - 1)
27 | ids = [0, ] + ids
28 | return ids
29 |
30 |
31 | def ttaup(burst):
32 | burst0 = burst.clone()
33 | burst0 = flatten_raw_image_batch(burst0.unsqueeze(0)).cuda()
34 |
35 | burst3 = burst0.clone().permute(0, 1, 2, 4, 3).cuda()
36 |
37 | ids = sample_images(burst.shape[0])
38 | burst4 = burst0[:, ids].clone()
39 |
40 | return burst0, burst3, burst4
41 |
42 |
43 | def ttadown(bursts):
44 | burst0 = bursts[0]
45 |
46 | burst3 = bursts[1].permute(0, 1, 3, 2)
47 | burst4 = bursts[2]
48 |
49 | out = (burst0 + burst3 + burst4) / 3
50 | return out
51 |
52 |
53 | def main():
54 | mp.spawn(main_worker, nprocs=1, args=(1, args))
55 |
56 |
57 | def main_worker(local_rank, nprocs, args):
58 |
59 | cudnn.benchmark = True
60 | args.local_rank = local_rank
61 | utility.setup(local_rank, nprocs)
62 | torch.cuda.set_device(local_rank)
63 |
64 |
65 | dataset = SyntheticBurstVal(args.root)
66 | out_dir = 'val'
67 |
68 | _model = model.Model(args, checkpoint)
69 |
70 | os.makedirs(out_dir, exist_ok=True)
71 |
72 | tt = []
73 | for idx in tqdm(range(len(dataset))):
74 | burst, burst_name = dataset[idx]
75 | bursts = ttaup(burst)
76 |
77 | srs = []
78 | with torch.no_grad():
79 | for x in bursts:
80 | tic = time.time()
81 | sr = _model(x, 0)
82 | toc = time.time()
83 | tt.append(toc-tic)
84 | srs.append(sr)
85 |
86 | sr = ttadown(srs)
87 |         # Scale to the [0, 2^14] range and convert to a uint16 numpy array
88 | net_pred_np = (sr.squeeze(0).permute(1, 2, 0).clamp(0.0, 1.0) * 2 ** 14).cpu().numpy().astype(np.uint16)
89 | cv2.imwrite('{}/{}.png'.format(out_dir, burst_name), net_pred_np)
90 |
91 | print('avg time: {:.4f}'.format(np.mean(tt)))
92 | utility.cleanup()
93 |
94 |
95 | if __name__ == '__main__':
96 | main()
97 |
--------------------------------------------------------------------------------
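The test-time augmentation in `ttaup`/`ttadown` above relies on the H/W transpose being self-inverse: the burst is transposed before the model, and the prediction is transposed back before averaging. A minimal sanity check of that invariant:

    import torch

    x = torch.rand(1, 14, 4, 32, 32)                   # (batch, frames, channels, H, W)
    xt = x.permute(0, 1, 2, 4, 3)                      # swap H and W, as in ttaup
    assert torch.equal(xt.permute(0, 1, 2, 4, 3), x)   # swapping back recovers the input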
/test_real.py:
--------------------------------------------------------------------------------
1 |
2 | import cv2
3 | import torch
4 | import numpy as np
5 | import os
6 | from tqdm import tqdm
7 | import random
8 | import utility
9 | from option import args
10 | import torchvision.utils as tvutils
11 | from pwcnet.pwcnet import PWCNet
12 |
13 | from utils.postprocessing_functions import BurstSRPostProcess
14 | from datasets.burstsr_dataset import BurstSRDataset, flatten_raw_image_batch, pack_raw_image
15 | from utils.metrics import AlignedPSNR
16 | from utils.data_format_utils import convert_dict
17 | from data_processing.camera_pipeline import demosaic
18 | import model
19 |
20 | import torch.multiprocessing as mp
21 | import torch.backends.cudnn as cudnn
22 | import torch.utils.data.distributed
23 | import time
24 |
25 | from torchsummaryX import summary
26 |
27 |
28 | checkpoint = utility.checkpoint(args)
29 |
30 |
31 | def main():
32 | mp.spawn(main_worker, nprocs=1, args=(1, args))
33 |
34 |
35 | def main_worker(local_rank, nprocs, args):
36 | cudnn.benchmark = True
37 | args.local_rank = local_rank
38 | utility.setup(local_rank, nprocs)
39 | torch.cuda.set_device(local_rank)
40 |
41 | dataset = BurstSRDataset(root=args.root, burst_size=14, crop_sz=80, split='val')
42 | out_dir = 'val/ebsr_real'
43 |
44 | _model = model.Model(args, checkpoint)
45 |
46 | for param in _model.parameters():
47 | param.requires_grad = False
48 |
49 | alignment_net = PWCNet(load_pretrained=True,
50 | weights_path='./pwcnet/pwcnet-network-default.pth')
51 | alignment_net = alignment_net.to('cuda')
52 | for param in alignment_net.parameters():
53 | param.requires_grad = False
54 |
55 | aligned_psnr_fn = AlignedPSNR(alignment_net=alignment_net, boundary_ignore=40)
56 |
57 | postprocess_fn = BurstSRPostProcess(return_np=True)
58 |
59 | os.makedirs(out_dir, exist_ok=True)
60 |
61 | tt = []
62 | psnrs, ssims, lpipss = [], [], []
63 | for idx in tqdm(range(len(dataset))):
64 | burst_, gt, meta_info_burst, meta_info_gt = dataset[idx]
65 | burst_ = burst_.unsqueeze(0).cuda()
66 | gt = gt.unsqueeze(0).cuda()
67 | burst = flatten_raw_image_batch(burst_)
68 |
69 | with torch.no_grad():
70 | tic = time.time()
71 | sr = _model(burst, 0)
72 | toc = time.time()
73 | tt.append(toc-tic)
74 |
75 | sr_int = (sr.clamp(0.0, 1.0) * 2 ** 14).short()
76 | sr = sr_int.float() / (2 ** 14)
77 |
78 | psnr, ssim, lpips = aligned_psnr_fn(sr, gt, burst_)
79 | psnrs.append(psnr.item())
80 | ssims.append(ssim.item())
81 | lpipss.append(lpips.item())
82 |
83 | os.makedirs(f'{out_dir}/{idx}', exist_ok=True)
84 | sr_ = postprocess_fn.process(sr[0], meta_info_burst)
85 | sr_ = cv2.cvtColor(sr_, cv2.COLOR_RGB2BGR)
86 | cv2.imwrite('{}/{}_sr.png'.format(out_dir, idx), sr_)
87 |
88 | del burst
89 | del sr
90 | del gt
91 |
92 |
93 | print(f'avg PSNR: {np.mean(psnrs):.6f}')
94 | print(f'avg SSIM: {np.mean(ssims):.6f}')
95 | print(f'avg LPIPS: {np.mean(lpipss):.6f}')
96 | print(f' avg time: {np.mean(tt):.6f}')
97 |
98 | # utility.cleanup()
99 |
100 |
101 | if __name__ == '__main__':
102 | main()
103 |
--------------------------------------------------------------------------------
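test_real.py quantizes the network output to 14 bits before computing metrics (`.short()` truncates), so the evaluated image differs from the raw prediction by at most one 14-bit step. A small sketch of that bound:

    import torch

    sr = torch.rand(1, 3, 8, 8)
    sr_q = (sr.clamp(0.0, 1.0) * 2 ** 14).short().float() / (2 ** 14)
    assert (sr - sr_q).abs().max() < 1.0 / 2 ** 14 + 1e-7   # error below ~6.1e-5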
/utility.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 | import datetime
4 | from multiprocessing import Process
5 | from multiprocessing import Queue
6 |
7 | import matplotlib
8 |
9 | matplotlib.use('Agg')  # select the non-interactive backend; must run before pyplot is imported
10 |
11 | import matplotlib.pyplot as plt
12 |
13 | import numpy as np
14 | import imageio
15 | import os
16 | import sys
17 |
18 | import torch
19 | import torch.optim as optim
20 | import torch.optim.lr_scheduler as lrs
21 |
22 | import torch.distributed as dist
23 |
24 | def reduce_mean(tensor, nprocs):
25 | rt = tensor.clone()
26 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
27 | rt /= nprocs
28 | return rt
29 |
30 |
31 | def setup(rank, world_size):
32 | if sys.platform == 'win32':
33 |         # On Windows, the distributed package only covers collective
34 |         # communications with the Gloo backend. init_method may point to a
35 |         # local file, e.g. init_method="file:///f:/libtmp/some_file";
36 |         # here a local TCP endpoint is used instead.
37 |         init_method = "tcp://localhost:1234"
38 |
39 | # initialize the process group
40 | dist.init_process_group(
41 | "gloo",
42 | init_method=init_method,
43 | rank=rank,
44 | world_size=world_size
45 | )
46 | else:
47 | os.environ['MASTER_ADDR'] = 'localhost'
48 | os.environ['MASTER_PORT'] = '12355'
49 |
50 | # initialize the process group
51 | dist.init_process_group("nccl", rank=rank, world_size=world_size)
52 |
53 |
54 | def cleanup():
55 | dist.destroy_process_group()
56 |
57 |
58 | def mkdir(path):
59 | if not os.path.exists(path):
60 | os.makedirs(path)
61 |
62 |
63 | class timer():
64 | def __init__(self):
65 | self.acc = 0
66 | self.tic()
67 |
68 | def tic(self):
69 | self.t0 = time.time()
70 |
71 | def toc(self, restart=False):
72 | diff = time.time() - self.t0
73 | if restart: self.t0 = time.time()
74 | return diff
75 |
76 | def hold(self):
77 | self.acc += self.toc()
78 |
79 | def release(self):
80 | ret = self.acc
81 | self.acc = 0
82 |
83 | return ret
84 |
85 | def reset(self):
86 | self.acc = 0
87 |
88 |
89 | class checkpoint():
90 | def __init__(self, args):
91 | self.args = args
92 | self.ok = True
93 | self.log = torch.Tensor()
94 | now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
95 |
96 | if not args.load:
97 | if not args.save:
98 | args.save = now
99 | self.dir = os.path.join('..', 'experiment', args.save)
100 | else:
101 | self.dir = os.path.join('..', 'experiment', args.load)
102 | if os.path.exists(self.dir):
103 | self.log = torch.load(self.get_path('psnr_log.pt'))
104 | print('Continue from epoch {}...'.format(len(self.log)))
105 | else:
106 | args.load = ''
107 |
108 | if args.reset:
109 | os.system('rm -rf ' + self.dir)
110 | args.load = ''
111 |
112 | os.makedirs(self.dir, exist_ok=True)
113 | os.makedirs(self.get_path('model'), exist_ok=True)
114 | # for d in args.data_test:
115 | # os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True)
116 |
117 | open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w'
118 | self.log_file = open(self.get_path('log.txt'), open_type)
119 | with open(self.get_path('config.txt'), open_type) as f:
120 | f.write(now + '\n\n')
121 | for arg in vars(args):
122 | f.write('{}: {}\n'.format(arg, getattr(args, arg)))
123 | f.write('\n')
124 |
125 | self.n_processes = 8
126 |
127 | def get_path(self, *subdir):
128 | return os.path.join(self.dir, *subdir)
129 |
130 | def save(self, trainer, epoch, is_best=False):
131 | trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
132 | trainer.loss.save(self.dir)
133 | trainer.loss.plot_loss(self.dir, epoch)
134 |
135 | self.plot_psnr(epoch)
136 | trainer.optimizer.save(self.dir)
137 | torch.save(self.log, self.get_path('psnr_log.pt'))
138 |
139 | def add_log(self, log):
140 | self.log = torch.cat([self.log, log])
141 |
142 | def write_log(self, log, refresh=False):
143 | print(log)
144 | self.log_file.write(log + '\n')
145 | if refresh:
146 | self.log_file.close()
147 | self.log_file = open(self.get_path('log.txt'), 'a')
148 |
149 | def done(self):
150 | self.log_file.close()
151 |
152 | def plot_psnr(self, epoch):
153 | axis = np.linspace(1, epoch, epoch)
154 | for idx_data, d in enumerate(self.args.data_test):
155 | label = 'SR on {}'.format(d)
156 | fig = plt.figure()
157 | plt.title(label)
158 | for idx_scale, scale in enumerate(self.args.scale):
159 | plt.plot(
160 | axis,
161 | self.log[:, idx_data, idx_scale].numpy(),
162 | label='Scale {}'.format(scale)
163 | )
164 | plt.legend()
165 | plt.xlabel('Epochs')
166 | plt.ylabel('PSNR')
167 | plt.grid(True)
168 | plt.savefig(self.get_path('test_{}.pdf'.format(d)))
169 | plt.close(fig)
170 |
171 | def begin_background(self):
172 | self.queue = Queue()
173 |
174 | def bg_target(queue):
175 | while True:
176 | if not queue.empty():
177 | filename, tensor = queue.get()
178 | if filename is None: break
179 | imageio.imwrite(filename, tensor.numpy())
180 |
181 | self.process = [
182 | Process(target=bg_target, args=(self.queue,)) \
183 | for _ in range(self.n_processes)
184 | ]
185 |
186 | for p in self.process: p.start()
187 |
188 | def end_background(self):
189 | for _ in range(self.n_processes): self.queue.put((None, None))
190 | while not self.queue.empty(): time.sleep(1)
191 | for p in self.process: p.join()
192 |
193 | def save_results(self, dataset, filename, save_list, scale):
194 | if self.args.save_results:
195 | filename = self.get_path(
196 | 'results-{}'.format(dataset.dataset.name),
197 | '{}_x{}_'.format(filename, scale)
198 | )
199 |
200 | postfix = ('SR', 'LR', 'HR')
201 | for v, p in zip(save_list, postfix):
202 | normalized = v[0].mul(255 / self.args.rgb_range)
203 | tensor_cpu = normalized.byte().permute(1, 2, 0).cpu()
204 | self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu))
205 |
206 |
207 | def quantize(img, rgb_range):
208 | pixel_range = 255 / rgb_range
209 | return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
210 |
211 |
212 | def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
213 | if hr.nelement() == 1: return 0
214 |
215 | diff = (sr - hr) / rgb_range
216 | if dataset and dataset.dataset.benchmark:
217 | shave = scale
218 | if diff.size(1) > 1:
219 | gray_coeffs = [65.738, 129.057, 25.064]
220 | convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
221 | diff = diff.mul(convert).sum(dim=1)
222 | else:
223 | shave = scale + 6
224 |
225 | valid = diff[..., shave:-shave, shave:-shave]
226 | mse = valid.pow(2).mean()
227 |
228 | return -10 * math.log10(mse)
229 |
230 |
231 | def make_optimizer(args, target):
232 | '''
233 | make optimizer and scheduler together
234 | '''
235 | # optimizer
236 | trainable = filter(lambda x: x.requires_grad, target.parameters())
237 | kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay}
238 |
239 | if args.optimizer == 'SGD':
240 | optimizer_class = optim.SGD
241 | kwargs_optimizer['momentum'] = args.momentum
242 | elif args.optimizer == 'ADAM':
243 | optimizer_class = optim.Adam
244 | kwargs_optimizer['betas'] = args.betas
245 | kwargs_optimizer['eps'] = args.epsilon
246 | elif args.optimizer == 'RMSprop':
247 | optimizer_class = optim.RMSprop
248 | kwargs_optimizer['eps'] = args.epsilon
249 |
250 | # scheduler
251 | milestones = list(map(lambda x: int(x), args.decay.split('-')))
252 | kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma}
253 | scheduler_class = lrs.MultiStepLR
254 |
255 | class CustomOptimizer(optimizer_class):
256 | def __init__(self, *args, **kwargs):
257 | super(CustomOptimizer, self).__init__(*args, **kwargs)
258 |
259 | def _register_scheduler(self, scheduler_class, **kwargs):
260 | self.scheduler = scheduler_class(self, **kwargs)
261 |
262 | def save(self, save_dir):
263 | torch.save(self.state_dict(), self.get_dir(save_dir))
264 |
265 | def load(self, load_dir, epoch=1):
266 | self.load_state_dict(torch.load(self.get_dir(load_dir)))
267 | if epoch > 1:
268 | for _ in range(epoch): self.scheduler.step()
269 |
270 | def get_dir(self, dir_path):
271 | return os.path.join(dir_path, 'optimizer.pt')
272 |
273 | def schedule(self):
274 | self.scheduler.step()
275 |
276 | def get_lr(self):
277 | return self.scheduler.get_last_lr()[0]
278 |
279 | def get_last_epoch(self):
280 | return self.scheduler.last_epoch
281 |
282 | optimizer = CustomOptimizer(trainable, **kwargs_optimizer)
283 | optimizer._register_scheduler(scheduler_class, **kwargs_scheduler)
284 | return optimizer
285 |
286 |
287 | def write_gray_to_tfboard(img):
288 | img_debug = img[0, ...].detach().cpu().numpy()
289 |
290 | # img_debug = cv2.normalize(img_debug, None, 0, 255,
291 | # cv2.NORM_MINMAX, cv2.CV_8U)
292 | img_debug = img_debug * 255
293 | img_debug = np.clip(img_debug, 0, 255)
294 | img_debug = img_debug.astype(np.uint8)
295 | return img_debug[0, ...]
296 |
--------------------------------------------------------------------------------
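make_optimizer above bundles the optimizer and its MultiStepLR scheduler into a single object. A minimal usage sketch, assuming the argument fields mirror those defined in option.py (`lr`, `weight_decay`, `optimizer`, `betas`, `epsilon`, `decay`, `gamma`):

    from types import SimpleNamespace
    import torch.nn as nn
    from utility import make_optimizer

    args = SimpleNamespace(lr=1e-4, weight_decay=0, optimizer='ADAM',
                           betas=(0.9, 0.999), epsilon=1e-8,
                           decay='100-200', gamma=0.5)
    opt = make_optimizer(args, nn.Linear(4, 4))
    opt.schedule()                               # steps the attached scheduler
    print(opt.get_lr(), opt.get_last_epoch())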
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Algolzw/EBSR/0ae35fd31c67c0cce1a8bed16788f6b253e83036/utils/__init__.py
--------------------------------------------------------------------------------
/utils/data_format_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import cv2 as cv
4 |
5 |
6 | def numpy_to_torch(a: np.ndarray):
7 | return torch.from_numpy(a).float().permute(2, 0, 1)
8 |
9 |
10 | def torch_to_numpy(a: torch.Tensor):
11 | return a.permute(1, 2, 0).cpu().numpy()
12 |
13 |
14 | def torch_to_npimage(a: torch.Tensor, unnormalize=True):
15 | a_np = torch_to_numpy(a)
16 |
17 | if unnormalize:
18 | a_np = a_np * 255
19 | a_np = a_np.astype(np.uint8)
20 | return cv.cvtColor(a_np, cv.COLOR_RGB2BGR)
21 |
22 |
23 | def npimage_to_torch(a, normalize=True, input_bgr=True):
24 | if input_bgr:
25 | a = cv.cvtColor(a, cv.COLOR_BGR2RGB)
26 | a_t = numpy_to_torch(a)
27 |
28 | if normalize:
29 | a_t = a_t / 255.0
30 |
31 | return a_t
32 |
33 |
34 | def convert_dict(base_dict, batch_sz):
35 | out_dict = []
36 | for b_elem in range(batch_sz):
37 | b_info = {}
38 | for k, v in base_dict.items():
39 | if isinstance(v, (list, torch.Tensor)):
40 | b_info[k] = v[b_elem]
41 | out_dict.append(b_info)
42 |
43 | return out_dict
--------------------------------------------------------------------------------
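convert_dict above turns a batched meta-info dictionary (as returned by a DataLoader) into a list of per-sample dictionaries; note that only list and Tensor values are carried over. A quick illustration:

    import torch
    from utils.data_format_utils import convert_dict

    batched = {'gamma': torch.tensor([True, False]), 'name': ['a', 'b']}
    per_sample = convert_dict(batched, batch_sz=2)
    # per_sample[0] -> {'gamma': tensor(True), 'name': 'a'}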
/utils/debayer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn
3 | import torch.nn.functional
4 |
5 | class Debayer3x3(torch.nn.Module):
6 | '''Demosaicing of Bayer images using 3x3 convolutions.
7 |
8 | Requires BG-Bayer color filter array layout. That is,
9 | the image[1,1]='B', image[1,2]='G'. This corresponds
10 | to OpenCV naming conventions.
11 |
12 | Compared to Debayer2x2 this method does not use upsampling.
13 | Instead, we identify five 3x3 interpolation kernels that
14 | are sufficient to reconstruct every color channel at every
15 | pixel location.
16 |
17 | We convolve the image with these 5 kernels using stride=1
18 | and a one pixel replication padding. Finally, we gather
19 |     the correct channel values for each pixel location. To do so,
20 | we recognize that the Bayer pattern repeats horizontally and
21 | vertically every 2 pixels. Therefore, we define the correct
22 | index lookups for a 2x2 grid cell and then repeat to image
23 | dimensions.
24 |
25 | Note, in every 2x2 grid cell we have red, blue and two greens
26 | (G1,G2). The lookups for the two greens differ.
27 | '''
28 |
29 | def __init__(self):
30 | super(Debayer3x3, self).__init__()
31 |
32 | self.kernels = torch.nn.Parameter(
33 | torch.tensor([
34 | [0,0,0],
35 | [0,1,0],
36 | [0,0,0],
37 |
38 | [0, 0.25, 0],
39 | [0.25, 0, 0.25],
40 | [0, 0.25, 0],
41 |
42 | [0.25, 0, 0.25],
43 | [0, 0, 0],
44 | [0.25, 0, 0.25],
45 |
46 | [0, 0, 0],
47 | [0.5, 0, 0.5],
48 | [0, 0, 0],
49 |
50 | [0, 0.5, 0],
51 | [0, 0, 0],
52 | [0, 0.5, 0],
53 | ]).view(5,1,3,3), requires_grad=False
54 | )
55 |
56 |
57 | self.index = torch.nn.Parameter(
58 | torch.tensor([
59 | # dest channel r
60 | [0, 3], # pixel is R,G1
61 | [4, 2], # pixel is G2,B
62 | # dest channel g
63 | [1, 0], # pixel is R,G1
64 | [0, 1], # pixel is G2,B
65 | # dest channel b
66 | [2, 4], # pixel is R,G1
67 | [3, 0], # pixel is G2,B
68 | ]).view(1,3,2,2), requires_grad=False
69 | )
70 |
71 | def forward(self, x):
72 | '''Debayer image.
73 |
74 | Parameters
75 | ----------
76 | x : Bx1xHxW tensor
77 | Images to debayer
78 |
79 | Returns
80 | -------
81 | rgb : Bx3xHxW tensor
82 | Color images in RGB channel order.
83 | '''
84 | B,C,H,W = x.shape
85 |
86 | x = torch.nn.functional.pad(x, (1,1,1,1), mode='replicate')
87 | c = torch.nn.functional.conv2d(x, self.kernels, stride=1)
88 | rgb = torch.gather(c, 1, self.index.repeat(B,1,H//2,W//2))
89 | return rgb
90 |
91 | class Debayer2x2(torch.nn.Module):
92 | '''Demosaicing of Bayer images using 2x2 convolutions.
93 |
94 | Requires BG-Bayer color filter array layout. That is,
95 | the image[1,1]='B', image[1,2]='G'. This corresponds
96 | to OpenCV naming conventions.
97 | '''
98 |
99 | def __init__(self):
100 | super(Debayer2x2, self).__init__()
101 |
102 | self.kernels = torch.nn.Parameter(
103 | torch.tensor([
104 | [1, 0],
105 | [0, 0],
106 |
107 | [0, 0.5],
108 | [0.5, 0],
109 |
110 | [0, 0],
111 | [0, 1],
112 | ]).view(3,1,2,2), requires_grad=False
113 | )
114 |
115 | def forward(self, x):
116 | '''Debayer image.
117 |
118 | Parameters
119 | ----------
120 | x : Bx1xHxW tensor
121 | Images to debayer
122 |
123 | Returns
124 | -------
125 | rgb : Bx3xHxW tensor
126 | Color images in RGB channel order.
127 | '''
128 |
129 | x = torch.nn.functional.conv2d(x, self.kernels, stride=2)
130 | x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
131 | return x
132 |
133 | class DebayerSplit(torch.nn.Module):
134 | '''Demosaicing of Bayer images using 3x3 green convolution and red,blue upsampling.
135 |
136 | Requires BG-Bayer color filter array layout. That is,
137 | the image[1,1]='B', image[1,2]='G'. This corresponds
138 | to OpenCV naming conventions.
139 | '''
140 | def __init__(self):
141 | super().__init__()
142 |
143 | self.pad = torch.nn.ReflectionPad2d(1)
144 | self.kernel = torch.nn.Parameter(
145 | torch.tensor([
146 | [0,1,0],
147 | [1,0,1],
148 | [0,1,0]
149 | ])[None, None] * 0.25)
150 |
151 | def forward(self, x):
152 | '''Debayer image.
153 |
154 | Parameters
155 | ----------
156 | x : Bx1xHxW tensor
157 | Images to debayer
158 |
159 | Returns
160 | -------
161 | rgb : Bx3xHxW tensor
162 | Color images in RGB channel order.
163 | '''
164 | B,_,H,W = x.shape
165 | red = x[:, :, ::2, ::2]
166 | blue = x[:, :, 1::2, 1::2]
167 |
168 | green = torch.nn.functional.conv2d(self.pad(x), self.kernel)
169 | green[:, :, ::2, 1::2] = x[:, :, ::2, 1::2]
170 | green[:, :, 1::2, ::2] = x[:, :, 1::2, ::2]
171 |
172 | return torch.cat((
173 | torch.nn.functional.interpolate(red, size=(H, W), mode='bilinear', align_corners=False),
174 | green,
175 | torch.nn.functional.interpolate(blue, size=(H, W), mode='bilinear', align_corners=False)),
176 | dim=1)
--------------------------------------------------------------------------------
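A minimal usage sketch for the debayer modules above, on a synthetic BG-Bayer mosaic (even height and width are assumed, since the 2x2 index pattern is tiled across the image):

    import torch
    from utils.debayer import Debayer3x3

    bayer = torch.rand(1, 1, 64, 64)        # Bx1xHxW mosaic, values in [0, 1]
    rgb = Debayer3x3()(bayer)
    assert rgb.shape == (1, 3, 64, 64)      # Bx3xHxW, RGB channel order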
/utils/interp_methods.py:
--------------------------------------------------------------------------------
1 | from math import pi
2 |
3 | try:
4 | import torch
5 | except ImportError:
6 | torch = None
7 |
8 | try:
9 | import numpy
10 | except ImportError:
11 | numpy = None
12 |
13 | if numpy is None and torch is None:
14 |     raise ImportError("Requires either NumPy or PyTorch, but neither was found")
15 |
16 |
17 | def set_framework_dependencies(x):
18 | if type(x) is numpy.ndarray:
19 | to_dtype = lambda a: a
20 | fw = numpy
21 | else:
22 | to_dtype = lambda a: a.to(x.dtype)
23 | fw = torch
24 | eps = fw.finfo(fw.float32).eps
25 | return fw, to_dtype, eps
26 |
27 |
28 | def support_sz(sz):
29 | def wrapper(f):
30 | f.support_sz = sz
31 | return f
32 | return wrapper
33 |
34 | @support_sz(4)
35 | def cubic(x):
36 | fw, to_dtype, eps = set_framework_dependencies(x)
37 | absx = fw.abs(x)
38 | absx2 = absx ** 2
39 | absx3 = absx ** 3
40 | return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) +
41 | (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) *
42 | to_dtype((1. < absx) & (absx <= 2.)))
43 |
44 | @support_sz(4)
45 | def lanczos2(x):
46 | fw, to_dtype, eps = set_framework_dependencies(x)
47 | return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) /
48 | ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2))
49 |
50 | @support_sz(6)
51 | def lanczos3(x):
52 | fw, to_dtype, eps = set_framework_dependencies(x)
53 | return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) /
54 | ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3))
55 |
56 | @support_sz(2)
57 | def linear(x):
58 | fw, to_dtype, eps = set_framework_dependencies(x)
59 | return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) + (1 - x) *
60 | to_dtype((0 <= x) & (x <= 1)))
61 |
62 | @support_sz(1)
63 | def box(x):
64 | fw, to_dtype, eps = set_framework_dependencies(x)
65 | return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1))
66 |
--------------------------------------------------------------------------------
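These kernels are interpolating (value 1 at the origin, 0 at the other integers) and, for `cubic`, form a partition of unity at unit-spaced taps, which is what makes them usable as resizing filters. A quick numpy check:

    import numpy as np
    from utils.interp_methods import cubic

    print(cubic(np.array([-1.0, 0.0, 1.0])))      # [0. 1. 0.]
    taps = 0.3 - np.array([-1.0, 0.0, 1.0, 2.0])  # offsets to the 4 nearest taps
    print(cubic(taps).sum())                      # ~1.0 (partition of unity)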
/utils/postprocessing_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import utils.data_format_utils as df_utils
4 | from data_processing.camera_pipeline import apply_gains, apply_ccm, apply_smoothstep, gamma_compression
5 |
6 |
7 | class SimplePostProcess:
8 | def __init__(self, gains=True, ccm=True, gamma=True, smoothstep=True, return_np=False):
9 | self.gains = gains
10 | self.ccm = ccm
11 | self.gamma = gamma
12 | self.smoothstep = smoothstep
13 | self.return_np = return_np
14 |
15 | def process(self, image, meta_info):
16 | return process_linear_image_rgb(image, meta_info, self.gains, self.ccm, self.gamma,
17 | self.smoothstep, self.return_np)
18 |
19 |
20 | def process_linear_image_rgb(image, meta_info, gains=True, ccm=True, gamma=True, smoothstep=True, return_np=False):
21 | if gains:
22 | image = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain'])
23 |
24 | if ccm:
25 | image = apply_ccm(image, meta_info['cam2rgb'])
26 |
27 | if meta_info['gamma'] and gamma:
28 | image = gamma_compression(image)
29 |
30 | if meta_info['smoothstep'] and smoothstep:
31 | image = apply_smoothstep(image)
32 |
33 | image = image.clamp(0.0, 1.0)
34 |
35 | if return_np:
36 | image = df_utils.torch_to_npimage(image)
37 | return image
38 |
39 |
40 | class BurstSRPostProcess:
41 | def __init__(self, no_white_balance=False, gamma=True, smoothstep=True, return_np=False):
42 | self.no_white_balance = no_white_balance
43 | self.gamma = gamma
44 | self.smoothstep = smoothstep
45 | self.return_np = return_np
46 |
47 | def process(self, image, meta_info, external_norm_factor=None):
48 | return process_burstsr_image_rgb(image, meta_info, external_norm_factor=external_norm_factor,
49 | no_white_balance=self.no_white_balance, gamma=self.gamma,
50 | smoothstep=self.smoothstep, return_np=self.return_np)
51 |
52 |
53 | def process_burstsr_image_rgb(im, meta_info, return_np=False, external_norm_factor=None, gamma=True, smoothstep=True,
54 | no_white_balance=False):
55 | im = im * meta_info.get('norm_factor', 1.0)
56 |
57 | if not meta_info.get('black_level_subtracted', False):
58 | im = (im - torch.tensor(meta_info['black_level'])[[0, 1, -1]].view(3, 1, 1).to(im.device))
59 |
60 |     if not meta_info.get('while_balance_applied', False) and not no_white_balance:  # (sic) key spelling matches the stored meta info
61 | im = im * (meta_info['cam_wb'][[0, 1, -1]].view(3, 1, 1) / meta_info['cam_wb'][1]).to(im.device)
62 |
63 | im_out = im
64 |
65 | if external_norm_factor is None:
66 | im_out = im_out / im_out.max()
67 | else:
68 | im_out = im_out / external_norm_factor
69 |
70 | im_out = im_out.clamp(0.0, 1.0)
71 |
72 | if gamma:
73 | im_out = im_out ** (1.0 / 2.2)
74 |
75 | if smoothstep:
76 | # Smooth curve
77 | im_out = 3 * im_out ** 2 - 2 * im_out ** 3
78 |
79 | if return_np:
80 | im_out = im_out.permute(1, 2, 0).cpu().numpy() * 255.0
81 | im_out = im_out.astype(np.uint8)
82 |
83 | return im_out
84 |
--------------------------------------------------------------------------------
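The BurstSR post-processing ends with a fixed gamma (1/2.2) followed by the smoothstep tone curve 3x^2 - 2x^3; both map [0, 1] onto [0, 1] with fixed points at 0 and 1. A tiny sketch of those two steps in isolation:

    import torch

    x = torch.linspace(0.0, 1.0, 5)
    g = x ** (1.0 / 2.2)           # gamma compression
    s = 3 * g ** 2 - 2 * g ** 3    # smoothstep tone curve
    print(g, s)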
/utils/spatial_color_alignment.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | def gauss_1d(sz, sigma, center, end_pad=0, density=False):
7 | """ Returns a 1-D Gaussian """
8 | k = torch.arange(-(sz-1)/2, (sz+1)/2 + end_pad).reshape(1, -1)
9 | gauss = torch.exp(-1.0/(2*sigma**2) * (k - center.reshape(-1, 1))**2)
10 | if density:
11 | gauss /= math.sqrt(2*math.pi) * sigma
12 | return gauss
13 |
14 |
15 | def gauss_2d(sz, sigma, center, end_pad=(0, 0), density=False):
16 | """ Returns a 2-D Gaussian """
17 | if isinstance(sigma, (float, int)):
18 | sigma = (sigma, sigma)
19 | if isinstance(sz, int):
20 | sz = (sz, sz)
21 |
22 | if isinstance(center, (list, tuple)):
23 | center = torch.tensor(center).view(1, 2)
24 |
25 | return gauss_1d(sz[0], sigma[0], center[:, 0], end_pad[0], density).reshape(center.shape[0], 1, -1) * \
26 | gauss_1d(sz[1], sigma[1], center[:, 1], end_pad[1], density).reshape(center.shape[0], -1, 1)
27 |
28 |
29 | def get_gaussian_kernel(sd):
30 | """ Returns a Gaussian kernel with standard deviation sd """
31 | ksz = int(4 * sd + 1)
32 | assert ksz % 2 == 1
33 | K = gauss_2d(ksz, sd, (0.0, 0.0), density=True)
34 | K = K / K.sum()
35 | return K.unsqueeze(0), ksz
36 |
37 |
38 | def apply_kernel(im, ksz, gauss_kernel):
39 | shape = im.shape
40 | im = im.view(-1, 1, *im.shape[-2:])
41 |
42 | pad = [ksz // 2, ksz // 2, ksz // 2, ksz // 2]
43 | im = F.pad(im, pad, mode='reflect')
44 | im_mean = F.conv2d(im, gauss_kernel).view(shape)
45 | return im_mean
46 |
47 |
48 | def match_colors(im_ref, im_q, im_test, ksz, gauss_kernel):
49 | """ Estimates a color transformation matrix between im_ref and im_q. Applies the estimated transformation to
50 | im_test
51 | """
52 | gauss_kernel = gauss_kernel.to(im_ref.device)
53 | bi = 5
54 |
55 | # Apply Gaussian smoothing
56 | im_ref_mean = apply_kernel(im_ref, ksz, gauss_kernel)[:, :, bi:-bi, bi:-bi].contiguous()
57 | im_q_mean = apply_kernel(im_q, ksz, gauss_kernel)[:, :, bi:-bi, bi:-bi].contiguous()
58 |
59 | im_ref_mean_re = im_ref_mean.view(*im_ref_mean.shape[:2], -1)
60 | im_q_mean_re = im_q_mean.view(*im_q_mean.shape[:2], -1)
61 |
62 | # Estimate color transformation matrix by minimizing the least squares error
63 | c_mat_all = []
64 | for ir, iq in zip(im_ref_mean_re, im_q_mean_re):
65 |         c = torch.lstsq(ir.t(), iq.t())  # deprecated in newer PyTorch; torch.linalg.lstsq(iq.t(), ir.t()) is the modern equivalent
66 | c = c.solution[:3]
67 | c_mat_all.append(c)
68 |
69 | c_mat = torch.stack(c_mat_all, dim=0)
70 | im_q_mean_conv = torch.matmul(im_q_mean_re.permute(0, 2, 1), c_mat).permute(0, 2, 1)
71 | im_q_mean_conv = im_q_mean_conv.view(im_q_mean.shape)
72 |
73 | err = ((im_q_mean_conv - im_ref_mean) * 255.0).norm(dim=1)
74 |
75 | thresh = 20
76 |
77 | # If error is larger than a threshold, ignore these pixels
78 | valid = err < thresh
79 |
80 | pad = (im_q.shape[-1] - valid.shape[-1]) // 2
81 | pad = [pad, pad, pad, pad]
82 | valid = F.pad(valid, pad)
83 |
84 | upsample_factor = im_test.shape[-1] / valid.shape[-1]
85 | valid = F.interpolate(valid.unsqueeze(1).float(), scale_factor=upsample_factor, mode='bilinear', align_corners=False)
86 | valid = valid > 0.9
87 |
88 | # Apply the transformation to test image
89 | im_test_re = im_test.view(*im_test.shape[:2], -1)
90 | im_t_conv = torch.matmul(im_test_re.permute(0, 2, 1), c_mat).permute(0, 2, 1)
91 | im_t_conv = im_t_conv.view(im_test.shape)
92 |
93 | return im_t_conv, valid
94 |
95 |
--------------------------------------------------------------------------------
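get_gaussian_kernel and apply_kernel above implement the Gaussian smoothing used by match_colors; the reflect padding keeps the spatial size unchanged. A minimal usage sketch:

    import torch
    from utils.spatial_color_alignment import get_gaussian_kernel, apply_kernel

    kernel, ksz = get_gaussian_kernel(sd=2.0)   # ksz = 4 * sd + 1 = 9
    im = torch.rand(1, 3, 32, 32)
    im_smooth = apply_kernel(im, ksz, kernel)
    assert im_smooth.shape == im.shape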
/utils/stn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class SpatialTransformer(nn.Module):
7 | """
8 |     [SpatialTransformer] represents a spatial transformation block
9 |     that uses the output from the U-Net to perform a grid_sample
10 | https://pytorch.org/docs/stable/nn.functional.html#grid-sample
11 | """
12 | def __init__(self, size, mode='bilinear'):
13 | """
14 |         Instantiate the block
15 | :param size: size of input to the spatial transformer block
16 | :param mode: method of interpolation for grid_sampler
17 | """
18 |         super(SpatialTransformer, self).__init__()
19 | if isinstance(size, int):
20 | size = (size, size)
21 | # Create sampling grid
22 | vectors = [ torch.arange(0, s) for s in size ]
23 | grids = torch.meshgrid(vectors)
24 | grid = torch.stack(grids) # y, x, z
25 | grid = torch.unsqueeze(grid, 0) #add batch
26 | grid = grid.type(torch.FloatTensor)
27 | self.register_buffer('grid', grid)
28 |
29 | self.mode = mode
30 |
31 | def forward(self, src, flow):
32 | """
33 | Push the src and flow through the spatial transform block
34 | :param src: the original moving image
35 | :param flow: the output from the U-Net
36 | """
37 | new_locs = self.grid + flow
38 |
39 | shape = flow.shape[2:]
40 |
41 | # Need to normalize grid values to [-1, 1] for resampler
42 | for i in range(len(shape)):
43 | new_locs[:,i,...] = 2*(new_locs[:,i,...]/(shape[i]-1) - 0.5)
44 |
45 | if len(shape) == 2:
46 | new_locs = new_locs.permute(0, 2, 3, 1)
47 | new_locs = new_locs[..., [1,0]]
48 | elif len(shape) == 3:
49 | new_locs = new_locs.permute(0, 2, 3, 4, 1)
50 | new_locs = new_locs[..., [2,1,0]]
51 |
52 | return F.grid_sample(src, new_locs, mode=self.mode, align_corners=True)
53 |
--------------------------------------------------------------------------------
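A zero displacement field should leave the input untouched (the sampling grid is built from pixel indices and grid_sample is called with align_corners=True), which gives a quick sanity check for the block above:

    import torch
    from utils.stn import SpatialTransformer

    stn = SpatialTransformer(size=(16, 16))
    src = torch.rand(1, 3, 16, 16)
    flow = torch.zeros(1, 2, 16, 16)
    out = stn(src, flow)
    print((out - src).abs().max())   # ~0: zero flow is the identity warp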
/utils/warp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def warp(feat, flow, mode='bilinear', padding_mode='zeros'):
7 | """
8 | warp an image/tensor (im2) back to im1, according to the optical flow im1 --> im2
9 |
10 | input flow must be in format (x, y) at every pixel
11 | feat: [B, C, H, W] (im2)
12 | flow: [B, 2, H, W] flow (x, y)
13 |
14 | """
15 | B, C, H, W = feat.size()
16 | # print(feat.device, flow.device)
17 |
18 | # mesh grid
19 | rowv, colv = torch.meshgrid([torch.arange(0.5, H + 0.5), torch.arange(0.5, W + 0.5)])
20 | grid = torch.stack((colv, rowv), dim=0).unsqueeze(0).float().to(flow.device)
21 | # print(grid.device, flow.device, feat.device)
22 | # grid = grid.cuda()
23 | grid = grid + flow
24 |
25 | # scale grid to [-1,1]
26 | grid_norm_c = 2.0 * grid[:, 0] / W - 1.0
27 | grid_norm_r = 2.0 * grid[:, 1] / H - 1.0
28 |
29 | grid_norm = torch.stack((grid_norm_c, grid_norm_r), dim=1).to(flow.device)
30 |
31 | grid_norm = grid_norm.permute(0, 2, 3, 1)
32 |
33 | output = F.grid_sample(feat, grid_norm, mode=mode, align_corners=False, padding_mode=padding_mode)
34 |
35 | return output
36 |
--------------------------------------------------------------------------------
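Since warp samples im2 at grid + flow, a constant flow of (+1, 0) reads each pixel from its right-hand neighbour, and the last column falls outside the image and is zero-padded. A minimal check:

    import torch
    from utils.warp import warp

    feat = torch.arange(16.0).view(1, 1, 4, 4)
    flow = torch.zeros(1, 2, 4, 4)
    flow[:, 0] = 1.0                 # +1 pixel displacement in x
    print(warp(feat, flow))          # columns shift left; the last column is zero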