├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── align.py
├── demo.ipynb
└── illustrations
    ├── before_and_after.jpg
    ├── burst_sample.jpg
    └── merged_image.jpg

/.gitattributes:
--------------------------------------------------------------------------------
*.ipynb linguist-documentation
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.DS_Store
.ipynb_checkpoints
__pycache__
deprecated
demo_drive.ipynb
playground.ipynb
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Martin Marek

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# HDR+ PyTorch

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/martin-marek/hdr-plus-pytorch/blob/main/demo.ipynb)

This is a simplified PyTorch implementation of HDR+, the backbone of computational photography in Google Pixel phones, described in [Burst photography for high dynamic range and low-light imaging on mobile cameras](http://static.googleusercontent.com/media/www.hdrplusdata.org/en//hdrplus.pdf). Using a free Colab GPU, aligning 20MP RAW images takes ~200 ms/frame.

If you would like to use HDR+ in practice (rather than research), please check out my open-source Mac app [Burst Photo](https://burst.photo). It has a GUI, supports robust merge, and uses the Adobe DNG SDK (instead of LibRaw), significantly improving image quality.

# Example

I took a burst of 35 images at ISO 12,800 on a Sony RX100 V and boosted it by +2 EV. Here's a comparison of a [single image](illustrations/burst_sample.jpg) from the burst vs. a [merge of all the images](illustrations/merged_image.jpg).

![Single burst frame vs. merged image](illustrations/before_and_after.jpg)

# Usage

Here's a minimal example to align and merge a burst of raw images. For more, see the [Colab Notebook](https://colab.research.google.com/github/martin-marek/hdr-plus-pytorch/blob/main/demo.ipynb).

```python
import torch
import align

images = torch.zeros([5, 1, 1000, 1000])  # burst of shape (num_frames, channels, height, width)
merged_image = align.align_and_merge(images)
```

# Implementation details

The implementation relies heavily on [first-class dimensions](https://github.com/facebookresearch/torchdim), which allow vectorized code to be written so that it reads like explicit loops. [Previous versions](https://github.com/martin-marek/hdr-plus-pytorch/blob/322c6039393074cefd9c5082006b509d5121aad1/align.py) of this repo used standard NumPy-style broadcasting, which was slower, harder to read, and required more lines of code.

# Features
- [x] JPEG support
- [x] RAW support
- [x] simple merge
- [ ] robust merge
- [x] tile comparison in pixel space
- [ ] tile comparison in Fourier space
- [x] CUDA support
- [x] CPU support (very slow)
- [ ] color post-processing
- [ ] automatic selection of the reference image

# Citation

```bibtex
@article{hasinoff2016burst,
  title={Burst photography for high dynamic range and low-light imaging on mobile cameras},
  author={Hasinoff, Samuel W and Sharlet, Dillon and Geiss, Ryan and Adams, Andrew and Barron, Jonathan T and Kainz, Florian and Chen, Jiawen and Levoy, Marc},
  journal={ACM Transactions on Graphics (ToG)},
  volume={35},
  number={6},
  pages={1--12},
  year={2016},
  publisher={ACM New York, NY, USA}
}
```
--------------------------------------------------------------------------------
/align.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn.functional as F
from functorch.dim import dims
from torch import Tensor
from typing import List


def upscale_previous_alignment(alignment: Tensor,
                               downscale_factor: int,
                               w: int, h: int
                               ) -> Tensor:
    """
    When layers in an image pyramid are iteratively compared,
    the absolute pixel distances in each layer represent different
    relative distances. This function interpolates an optical flow
    from one resolution to another, taking care to scale the values.
    `w` and `h` are the width and height of the output grid.
18 | """ 19 | alignment = alignment[None].float() 20 | alignment = downscale_factor * F.interpolate(alignment, size=(h, w), mode='nearest') 21 | alignment = alignment[0].int() 22 | return alignment 23 | 24 | 25 | def build_pyramid(image: Tensor, 26 | downscale_factor_list: List[int], 27 | ) -> List[Tensor]: 28 | """ 29 | Create an image pyramid from a single image. 30 | """ 31 | # if the input image has multiple channels (e.g. RGB), average them to obtain a single-channel image 32 | layer = torch.mean(image, 0, keepdim=True) 33 | 34 | # iteratively build each level in the image pyramid 35 | pyramid = [] 36 | for downscale_factor in downscale_factor_list: 37 | layer = F.avg_pool2d(layer, downscale_factor) 38 | pyramid.append(layer) 39 | return pyramid 40 | 41 | 42 | def align_layers(ref_layer: Tensor, 43 | comp_layer: Tensor, 44 | prev_alignment: Tensor, 45 | tile_size: int, 46 | search_dist: int, 47 | downscale_factor: int = 1 48 | ) -> Tensor: 49 | """ 50 | Estimates the optical flow between layers of two distinct image pyramids. 

    Args:
        ref_layer: the reference layer, of shape (channels, height, width)
        comp_layer: the layer to be aligned to `ref_layer`
        prev_alignment: alignment from a coarser pyramid layer
        tile_size: side length of the square tiles; neighboring tiles overlap by roughly half a tile
        search_dist: maximum shift (in pixels of this layer) tested in each direction
        downscale_factor: scaling factor between the previous layer and current layer, only required if `prev_alignment` is not zeros

    Returns:
        alignment of shape [2, n_tiles_y, n_tiles_x] holding the (dx, dy) shift of each tile
    """
    device = ref_layer.device

    # compute number of tiles in a layer such that they overlap by about half a tile
    n_channels, layer_height, layer_width = ref_layer.shape
    n_tiles_y = layer_height // (tile_size // 2) - 1
    n_tiles_x = layer_width // (tile_size // 2) - 1

    # upscale previous alignment
    prev_alignment = upscale_previous_alignment(prev_alignment, downscale_factor, n_tiles_x, n_tiles_y)

    # get reference image tiles (no shift)
    channel, tile_idx_y, tile_idx_x, tile_h, tile_w = dims(sizes=[None, n_tiles_y, n_tiles_x, tile_size, tile_size])
    x_min = torch.linspace(0, layer_width-tile_size, n_tiles_x, dtype=torch.int32, device=device)[tile_idx_x]
    y_min = torch.linspace(0, layer_height-tile_size, n_tiles_y, dtype=torch.int32, device=device)[tile_idx_y]
    x = x_min + tile_w
    y = y_min + tile_h
    ref_tiles = ref_layer[channel, y, x]

    # get comparison image tiles (shifted)
    shift_x, shift_y = dims(sizes=[1+2*search_dist, 1+2*search_dist])
    x = x + prev_alignment[0, tile_idx_y, tile_idx_x] + (shift_x - search_dist)
    y = y + prev_alignment[1, tile_idx_y, tile_idx_x] + (shift_y - search_dist)
    comp_tiles = comp_layer[channel, y.clip(0, layer_height-1), x.clip(0, layer_width-1)]

    # compute the L1 distance (sum of absolute differences) between the reference and comparison tiles
    diff = (ref_tiles - comp_tiles).abs().sum([channel, tile_w, tile_h])
    diff = diff.order(tile_idx_y, tile_idx_x, (shift_y, shift_x))

    # set the difference value for shifted tiles that extend outside of the frame to infinity
    tile_is_outside_layer = ((x<0)^(x>=layer_width)).sum(tile_w) + ((y<0)^(y>=layer_height)).sum(tile_h) > 0
    tile_is_outside_layer = tile_is_outside_layer.order(tile_idx_y, tile_idx_x, (shift_y, shift_x))
    diff[tile_is_outside_layer] = float('inf')

    # find which shift (dx, dy) between the reference and comparison tiles yields the lowest loss
    min_idx = torch.argmin(diff, -1)
    dy = min_idx // (2*search_dist+1) - search_dist
    dx = min_idx % (2*search_dist+1) - search_dist

    # save the current alignment
    alignment = torch.stack([dx, dy], 0) # [2, n_tiles_y, n_tiles_x]

    # combine the current alignment with the previous alignment
    alignment += prev_alignment

    return alignment


def warp_image(image: Tensor, alignment: Tensor) -> Tensor:
    """
    Warps `image` using a per-pixel optical flow `alignment` of shape [2, H, W]:
    each output pixel (y, x) is sampled from image[y + dy, x + dx], clamped to the frame.
    """
    dx, dy = alignment
    C, H, W = image.shape
    channel, y, x = dims(sizes=[C, H, W])
    warped = image[channel, (y + dy[y, x]).clamp(0, H-1), (x + dx[y, x]).clamp(0, W-1)]
    return warped.order(channel, y, x)


def align_and_merge(images: Tensor,
                    ref_idx: int = 0,
                    device: torch.device = torch.device('cpu'),
                    min_layer_res: int = 64,
                    tile_size: int = 16,
                    search_dist: int = 2,
                    ) -> Tensor:
    """
    Align and merge a burst of images. The input and output tensors live on the CPU; frames are moved to `device` one at a time to reduce GPU memory requirements.

    Args:
        images: burst of shape (num_frames, channels, height, width)
        ref_idx: index of the reference image (all images are aligned to this image)
        device: the PyTorch device to use (either 'cpu' or 'cuda')
        min_layer_res: approximate size (in pixels) of the shorter side of the smallest pyramid layer
        tile_size: size of tiles in each pyramid layer
        search_dist: maximum shift (in pixels) searched in each direction at every pyramid layer
    """

    # check the shape of the burst
    N, C, H, W = images.shape

    # build a pyramid from the reference image
    n_layers = math.ceil(math.log2(min(H, W) / min_layer_res))
    downscale_factor_list = n_layers*[2]
    ref_idx = torch.tensor(ref_idx)
    ref_image = images[ref_idx].to(device)
    ref_pyramid = build_pyramid(ref_image, downscale_factor_list)

    # iterate through the comparison images
    merged_image = ref_image.clone() / N
    comp_idxs = torch.arange(N)[torch.arange(N)!=ref_idx]
    for comp_idx in comp_idxs:

        # build a pyramid from the comparison image
        comp_image = images[comp_idx].to(device)
        comp_pyramid = build_pyramid(comp_image, downscale_factor_list)

        # start off with default alignment (no shift between images)
        alignment = torch.zeros([2, 1, 1], dtype=torch.int32, device=device)

        # iteratively improve the alignment in each pyramid layer, from coarsest to finest
        for layer_idx in reversed(range(len(ref_pyramid))):
            downscale_factor = downscale_factor_list[min(layer_idx+1, len(ref_pyramid)-1)]
            alignment = align_layers(ref_pyramid[layer_idx], comp_pyramid[layer_idx],
                                     alignment, tile_size, search_dist, downscale_factor)

        # scale the alignment to the resolution of the original image
        alignment = upscale_previous_alignment(alignment, downscale_factor_list[0], W, H)

        # warp the comparison image based on the computed alignment
        comp_image_aligned = warp_image(comp_image, alignment)

        # add the aligned image to the output
        merged_image += comp_image_aligned / N

    merged_image = merged_image.cpu()

    return merged_image
--------------------------------------------------------------------------------
/illustrations/before_and_after.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martin-marek/hdr-plus-pytorch/e7091c33b0e3417f84e72d70ad2081b66daee56d/illustrations/before_and_after.jpg
--------------------------------------------------------------------------------
/illustrations/burst_sample.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martin-marek/hdr-plus-pytorch/e7091c33b0e3417f84e72d70ad2081b66daee56d/illustrations/burst_sample.jpg
--------------------------------------------------------------------------------
/illustrations/merged_image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martin-marek/hdr-plus-pytorch/e7091c33b0e3417f84e72d70ad2081b66daee56d/illustrations/merged_image.jpg
--------------------------------------------------------------------------------
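To see what `align_layers` in `align.py` computes at a single pyramid level without PyTorch or torchdim, here is a minimal pure-Python sketch of the same exhaustive tile search. It assumes a single-channel layer stored as nested lists, a zero prior alignment, and integer tile origins obtained by truncating evenly spaced positions (an assumption: `torch.linspace` with an integer dtype may round differently, although the positions are exact integers whenever the layer size is a multiple of `tile_size // 2`). The helpers `tile_origins`, `tile_cost`, and `align_single_level` are hypothetical and not part of the repo.

```python
import random
from typing import List, Tuple

Image = List[List[float]]


def tile_origins(length: int, tile_size: int) -> List[int]:
    """Start coordinates of overlapping tiles along one axis (hypothetical helper).

    Uses the tile count from align_layers, n = length // (tile_size // 2) - 1,
    and spaces the tiles evenly between 0 and length - tile_size.
    """
    n = length // (tile_size // 2) - 1
    if n < 1:
        raise ValueError("layer must be at least one tile long")
    if n == 1:
        return [0]
    step = (length - tile_size) / (n - 1)
    return [int(i * step) for i in range(n)]


def tile_cost(ref: Image, comp: Image, y0: int, x0: int,
              dy: int, dx: int, tile_size: int) -> float:
    """Sum of absolute differences between the reference tile at (y0, x0) and
    the comparison tile shifted by (dy, dx); infinite if the shifted tile
    extends outside the frame."""
    h, w = len(comp), len(comp[0])
    if not (0 <= y0 + dy <= h - tile_size and 0 <= x0 + dx <= w - tile_size):
        return float("inf")
    return sum(
        abs(ref[y0 + ty][x0 + tx] - comp[y0 + dy + ty][x0 + dx + tx])
        for ty in range(tile_size)
        for tx in range(tile_size)
    )


def align_single_level(ref: Image, comp: Image, tile_size: int = 8,
                       search_dist: int = 2) -> List[List[Tuple[int, int]]]:
    """Exhaustive per-tile search over shifts in [-search_dist, search_dist]^2.

    Returns a grid of (dx, dy) per tile such that comp[y + dy][x + dx] best
    matches ref[y][x] inside that tile.
    """
    grid = []
    for y0 in tile_origins(len(ref), tile_size):
        row = []
        for x0 in tile_origins(len(ref[0]), tile_size):
            best, best_cost = (0, 0), float("inf")
            # dy in the outer loop plus a strict "<" keeps the first minimum,
            # like argmin over the flattened (shift_y, shift_x) axis
            for dy in range(-search_dist, search_dist + 1):
                for dx in range(-search_dist, search_dist + 1):
                    cost = tile_cost(ref, comp, y0, x0, dy, dx, tile_size)
                    if cost < best_cost:
                        best, best_cost = (dx, dy), cost
            row.append(best)
        grid.append(row)
    return grid


# Toy burst: two 32x32 crops of one random canvas, offset so that
# comp[y][x] == ref[y - 1][x + 2] (content moved down 1 px and left 2 px).
rng = random.Random(0)
canvas = [[rng.random() for _ in range(36)] for _ in range(36)]
ref = [row[2:34] for row in canvas[2:34]]
comp = [row[4:36] for row in canvas[1:33]]
grid = align_single_level(ref, comp, tile_size=8, search_dist=2)
```

Every tile whose true match stays inside the frame recovers `(dx, dy) == (-2, 1)`. Tiles near the left and bottom edges cannot, because the true shift would move their comparison tile out of the frame and receives infinite cost, just as `align_layers` masks it with `float('inf')`.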
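The coarse-to-fine loop in `align_and_merge` also bounds how large a motion it can represent. The sketch below uses hypothetical pure-Python helpers, assumes the default `downscale_factor_list` of all 2s, and ignores frame boundaries. It replays the pyramid-depth formula and the way `upscale_previous_alignment` doubles each level's estimate before `align_layers` adds up to `search_dist` more.

```python
import math


def num_pyramid_layers(height: int, width: int, min_layer_res: int = 64) -> int:
    """Number of 2x downscaling steps, using the same formula as align_and_merge."""
    return math.ceil(math.log2(min(height, width) / min_layer_res))


def max_reachable_shift(n_layers: int, search_dist: int = 2, factor: int = 2) -> int:
    """Largest per-axis displacement (in full-resolution pixels) that the
    coarse-to-fine search can represent, ignoring frame boundaries."""
    shift = 0  # the coarsest layer starts from a zero alignment
    for _ in range(n_layers):
        # upscale_previous_alignment multiplies the previous estimate by `factor`,
        # then align_layers adds at most `search_dist` on top of it
        shift = factor * shift + search_dist
    # the finest pyramid layer is already downscaled once, so scale up one more time
    return factor * shift


n = num_pyramid_layers(3648, 5472)  # a hypothetical ~20 MP frame
print(n, max_reachable_shift(n))  # prints: 6 252
```

For a factor of 2 the loop has the closed form 2 · search_dist · (2^n_layers − 1), so with the defaults (`min_layer_res=64`, `search_dist=2`) such a frame can be aligned over shifts of up to about 252 px per axis; larger motion would need a larger `search_dist` or a smaller `min_layer_res`.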