├── .gitignore
├── Examples
│   ├── Banner.png
│   ├── chairs_input.jpg
│   ├── chairs_output.jpg
│   ├── empty_gui_v2.png
│   ├── flowers_in.jpg
│   ├── flowers_out.jpg
│   ├── input_sunflower.jpg
│   ├── output_sunflower.jpg
│   └── progress_bar.png
├── HDRPlus2020.pdf
├── Images
│   └── gallery.jpg
├── Output
│   └── .gitignore
├── README.md
├── align.py
├── finish.py
├── hdr_plus.kv
├── main.py
├── merge.py
├── requirements.txt
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ -------------------------------------------------------------------------------- /Examples/Banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/792x/HDR-Plus-Python/b50daad519ca9c4cb40d38ff4497ee1d6ccb13ed/Examples/Banner.png -------------------------------------------------------------------------------- /Examples/chairs_input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/792x/HDR-Plus-Python/b50daad519ca9c4cb40d38ff4497ee1d6ccb13ed/Examples/chairs_input.jpg -------------------------------------------------------------------------------- /Examples/chairs_output.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/792x/HDR-Plus-Python/b50daad519ca9c4cb40d38ff4497ee1d6ccb13ed/Examples/chairs_output.jpg -------------------------------------------------------------------------------- /Examples/empty_gui_v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/792x/HDR-Plus-Python/b50daad519ca9c4cb40d38ff4497ee1d6ccb13ed/Examples/empty_gui_v2.png -------------------------------------------------------------------------------- /Examples/flowers_in.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/792x/HDR-Plus-Python/b50daad519ca9c4cb40d38ff4497ee1d6ccb13ed/Examples/flowers_in.jpg -------------------------------------------------------------------------------- /Examples/flowers_out.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/792x/HDR-Plus-Python/b50daad519ca9c4cb40d38ff4497ee1d6ccb13ed/Examples/flowers_out.jpg -------------------------------------------------------------------------------- /Examples/input_sunflower.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/792x/HDR-Plus-Python/b50daad519ca9c4cb40d38ff4497ee1d6ccb13ed/Examples/input_sunflower.jpg -------------------------------------------------------------------------------- /Examples/output_sunflower.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/792x/HDR-Plus-Python/b50daad519ca9c4cb40d38ff4497ee1d6ccb13ed/Examples/output_sunflower.jpg -------------------------------------------------------------------------------- /Examples/progress_bar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/792x/HDR-Plus-Python/b50daad519ca9c4cb40d38ff4497ee1d6ccb13ed/Examples/progress_bar.png -------------------------------------------------------------------------------- /HDRPlus2020.pdf: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/792x/HDR-Plus-Python/b50daad519ca9c4cb40d38ff4497ee1d6ccb13ed/HDRPlus2020.pdf
--------------------------------------------------------------------------------
/Images/gallery.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/792x/HDR-Plus-Python/b50daad519ca9c4cb40d38ff4497ee1d6ccb13ed/Images/gallery.jpg
--------------------------------------------------------------------------------
/Output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | */
3 | !.gitignore
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![Image Banner](https://github.com/792x/HDR-Plus-Python/blob/master/Examples/Banner.png)
2 |
3 | # HDR Plus Python
4 | A desktop implementation of Google's HDR+ pipeline in Python, with a GUI, using Halide bindings.
5 |
6 | This repository is provided as-is and is not maintained.
7 |
8 | For our whitepaper, see [https://github.com/792x/HDR-Plus-Python/blob/master/HDRPlus2020.pdf](https://github.com/792x/HDR-Plus-Python/blob/master/HDRPlus2020.pdf)
9 |
10 | For the original paper, see [https://www.hdrplusdata.org/hdrplus.pdf](https://www.hdrplusdata.org/hdrplus.pdf)
11 |
12 | ## Data
13 | This project uses the [HDR+ Burst Photography Dataset](http://www.hdrplusdata.org/dataset.html).
14 | To download the subset of bursts used in this project, install the [Google Cloud SDK](https://cloud.google.com/sdk/docs/#install_the_latest_cloud_sdk_version) and run the following command:
15 | ```
16 | gsutil -m cp -r gs://hdrplusdata/20171106_subset .
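# Each burst unpacks as its own folder of raw frames named payload_N000.dng,
# payload_N001.dng, ... - the naming scheme the loader in main.py expects when
# you select a burst folder in the GUI.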
17 | ```
18 |
19 | ## Prerequisites
20 | * Linux or macOS
21 | * [LLVM](http://llvm.org/releases/download.html)
22 | * [Halide](https://github.com/halide/Halide)
23 | * [python_bindings](https://github.com/halide/Halide/tree/master/python_bindings)
24 |
25 | `pip install -r requirements.txt`
26 |
27 | ## Examples
28 | Input | Output
29 | :-------------------------:|:-------------------------:
30 | ![Image Flowers_In](https://github.com/792x/HDR-Plus-Python/blob/master/Examples/flowers_in.jpg) | ![Image Flowers Out](https://github.com/792x/HDR-Plus-Python/blob/master/Examples/flowers_out.jpg)
31 | ![Image Chairs In](https://github.com/792x/HDR-Plus-Python/blob/master/Examples/chairs_input.jpg) | ![Image Chairs Out](https://github.com/792x/HDR-Plus-Python/blob/master/Examples/chairs_output.jpg)
32 | ![Image Sunflower In](https://github.com/792x/HDR-Plus-Python/blob/master/Examples/input_sunflower.jpg) | ![Image Sunflower Out](https://github.com/792x/HDR-Plus-Python/blob/master/Examples/output_sunflower.jpg)
33 |
34 | ## Graphical User Interface
35 | ![Image GUI](https://github.com/792x/HDR-Plus-Python/blob/master/Examples/empty_gui_v2.png)
36 | ![Image GUI Progress Bar](https://github.com/792x/HDR-Plus-Python/blob/master/Examples/progress_bar.png)
37 |
38 |
39 | ## Footnote
40 | This project was inspired by [https://github.com/timothybrooks/hdr-plus](https://github.com/timothybrooks/hdr-plus)
41 |
--------------------------------------------------------------------------------
/align.py:
--------------------------------------------------------------------------------
1 | import math
2 | from datetime import datetime
3 | import halide as hl
4 | from utils import time_diff, Point, gaussian_down4, box_down2, prev_tile, idx_layer, TILE_SIZE_2, DOWNSAMPLE_RATE
5 |
6 | '''
7 | Determines the best offset for tiles of the image at a given resolution,
8 | provided the offsets for the layer above
9 |
10 | layer : Halide buffer
11 |     The downsampled layer for which the offset needs to be calculated
12 |     This is a layer of the four-level Gaussian pyramid
13 | prev_alignment : Halide function
14 |     Alignment of the previous layer
15 | prev_min : Point
16 |     Minimum offset of the search region
17 | prev_max : Point
18 |     Maximum offset of the search region
19 |
20 | Returns: Halide function representing an alignment of the current layer
21 | '''
22 | def align_layer(layer, prev_alignment, prev_min, prev_max):
23 |     scores = hl.Func(layer.name() + "_scores")
24 |     alignment = hl.Func(layer.name() + "_alignment")
25 |     xi, yi, tx, ty, n = hl.Var("xi"), hl.Var("yi"), hl.Var('tx'), hl.Var('ty'), hl.Var('n')
26 |     rdom0 = hl.RDom([(0, 16), (0, 16)])
27 |     rdom1 = hl.RDom([(-4, 8), (-4, 8)])
28 |
29 |     # Alignment of the previous (coarser) layer, scaled to this (finer) layer
30 |     prev_offset = DOWNSAMPLE_RATE * Point(prev_alignment[prev_tile(tx), prev_tile(ty), n]).clamp(prev_min, prev_max)
31 |
32 |     x0 = idx_layer(tx, rdom0.x)
33 |     y0 = idx_layer(ty, rdom0.y)
34 |     # (x, y) coordinates in the search region, relative to the offset obtained from the alignment of the previous layer
35 |     x = x0 + prev_offset.x + xi
36 |     y = y0 + prev_offset.y + yi
37 |
38 |     ref_val = layer[x0, y0, 0]  # Value of the reference frame (the first frame)
39 |     alt_val = layer[x, y, n]  # Value of the alternate frame
40 |
41 |     # L1 distance between reference frame and alternate frame
42 |     d = hl.abs(hl.cast(hl.Int(32), ref_val) - hl.cast(hl.Int(32), alt_val))
43 |
44 |     scores[xi, yi, tx, ty, n] = hl.sum(d)
45 |
46 |     # Alignment for each tile, where L1 distances are minimum
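    # hl.argmin over the 2-D RDom yields a Tuple (x offset, y offset, minimum score);
    # Point() keeps the two coordinates, i.e. the displacement within the 8x8 search
    # window whose tile score is smallest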
47 |     alignment[tx, ty, n] = Point(hl.argmin(scores[rdom1.x, rdom1.y, tx, ty, n])) + prev_offset
48 |
49 |     scores.compute_at(alignment, tx).vectorize(xi, 8)
50 |
51 |     alignment.compute_root().parallel(ty).vectorize(tx, 16)
52 |
53 |     return alignment
54 |
55 |
56 | '''
57 | Step 1 of HDR+ pipeline: align
58 | Creates a Gaussian pyramid of downsampled images converted to grayscale.
59 | Uses the first frame as reference.
60 |
61 | images : Halide buffer
62 |     The raw burst frames
63 |
64 | Returns: Halide function representing an alignment of the burst frames
65 | '''
66 | def align_images(images):
67 |     print(f'\n{"=" * 30}\nAligning images...\n{"=" * 30}')
68 |     start = datetime.utcnow()
69 |
70 |     alignment_3 = hl.Func("layer_3_alignment")
71 |     alignment = hl.Func("alignment")
72 |
73 |     tx, ty, n = hl.Var('tx'), hl.Var('ty'), hl.Var('n')
74 |
75 |     print('Subsampling image layers...')
76 |     imgs_mirror = hl.BoundaryConditions.mirror_interior(images, [(0, images.width()), (0, images.height())])
77 |     # layer_0 halves the resolution of the input in x and y; each subsequent layer downsamples a further 4x per axis
78 |     layer_0 = box_down2(imgs_mirror, "layer_0")
79 |     layer_1 = gaussian_down4(layer_0, "layer_1")
80 |     layer_2 = gaussian_down4(layer_1, "layer_2")
81 |
82 |     # Search regions
83 |     min_search = Point(-4, -4)
84 |     max_search = Point(3, 3)
85 |
86 |     min_3 = Point(0, 0)
87 |     min_2 = DOWNSAMPLE_RATE * min_3 + min_search
88 |     min_1 = DOWNSAMPLE_RATE * min_2 + min_search
89 |
90 |     max_3 = Point(0, 0)
91 |     max_2 = DOWNSAMPLE_RATE * max_3 + max_search
92 |     max_1 = DOWNSAMPLE_RATE * max_2 + max_search
93 |
94 |     print('Aligning layers...')
95 |     alignment_3[tx, ty, n] = Point(0, 0)  # Initial alignment (0, 0)
96 |
97 |     # Align layers of the Gaussian pyramid from coarse to fine,
98 |     # passing each previous alignment as the initial guess for the next layer
99 |     alignment_2 = align_layer(layer_2, alignment_3, min_3, max_3)
100 |     alignment_1 = align_layer(layer_1, alignment_2, min_2, max_2)
101 |     alignment_0 = align_layer(layer_0, alignment_1, min_1, max_1)
102 |
103 |     num_tx = math.floor(images.width() / TILE_SIZE_2 - 1)  # number of tiles
104 |     num_ty = math.floor(images.height() / TILE_SIZE_2 - 1)
105 |
106 |     alignment[tx, ty, n] = 2 * Point(alignment_0[tx, ty, n])  # scale layer_0 offsets up to the original resolution
107 |
108 |     alignment_repeat = hl.BoundaryConditions.repeat_edge(alignment, [(0, num_tx), (0, num_ty)])
109 |
110 |     print(f'Alignment finished in {time_diff(start)} ms.\n')
111 |     return alignment_repeat
112 |
--------------------------------------------------------------------------------
/finish.py:
--------------------------------------------------------------------------------
1 | import math
2 | import halide as hl
3 | from utils import DENOISE_PASSES, TONE_MAP_PASSES, SHARPEN_STRENGTH
4 |
5 |
6 | def black_white_level(input, black_point, white_point):
7 |     output = hl.Func("black_white_level_output")
8 |
9 |     x, y = hl.Var("x"), hl.Var("y")
10 |
11 |     white_factor = 65535 / (white_point - black_point)
12 |
13 |     output[x, y] = hl.u16_sat((hl.i32(input[x, y]) - black_point) * white_factor)
14 |
15 |     return output
16 |
17 |
18 | def white_balance(input, width, height, white_balance_r, white_balance_g0, white_balance_g1, white_balance_b):
19 |     output = hl.Func("white_balance_output")
20 |
21 |     print(width, height, white_balance_r, white_balance_g0, white_balance_g1, white_balance_b)
22 |
23 |     x, y = hl.Var("x"), hl.Var("y")
24 |
25 |     rdom = hl.RDom([(0, width // 2), (0, height // 2)])  # integer division: RDom extents must be integers
26 |
27 |     output[x, y] = hl.u16(0)
28 |
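    # The four update stages below scale one Bayer plane each: R at (even x, even y),
    # G0 at (odd x, even y), G1 at (even x, odd y) and B at (odd x, odd y) of the RGGB mosaic
29 |     output[rdom.x * 2, rdom.y * 2] = 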
hl.u16_sat(white_balance_r * hl.f32(input[rdom.x * 2, rdom.y * 2])) 30 | output[rdom.x * 2 + 1, rdom.y * 2] = hl.u16_sat(white_balance_g0 * hl.f32(input[rdom.x * 2 + 1, rdom.y * 2])) 31 | output[rdom.x * 2, rdom.y * 2 + 1] = hl.u16_sat(white_balance_g1 * hl.f32(input[rdom.x * 2, rdom.y * 2 + 1])) 32 | output[rdom.x * 2 + 1, rdom.y * 2 + 1] = hl.u16_sat(white_balance_b * hl.f32(input[rdom.x * 2 + 1, rdom.y * 2 + 1])) 33 | 34 | output.compute_root().parallel(y).vectorize(x, 16) 35 | 36 | output.update(0).parallel(rdom.y) 37 | output.update(1).parallel(rdom.y) 38 | output.update(2).parallel(rdom.y) 39 | output.update(3).parallel(rdom.y) 40 | 41 | return output 42 | 43 | 44 | def demosaic(input, width, height): 45 | print(f'width: {width}, height: {height}') 46 | 47 | f0 = hl.Buffer(hl.Int(32), [5, 5], "demosaic_f0") 48 | f1 = hl.Buffer(hl.Int(32), [5, 5], "demosaic_f1") 49 | f2 = hl.Buffer(hl.Int(32), [5, 5], "demosaic_f2") 50 | f3 = hl.Buffer(hl.Int(32), [5, 5], "demosaic_f3") 51 | 52 | f0.translate([-2, -2]) 53 | f1.translate([-2, -2]) 54 | f2.translate([-2, -2]) 55 | f3.translate([-2, -2]) 56 | 57 | d0 = hl.Func("demosaic_0") 58 | d1 = hl.Func("demosaic_1") 59 | d2 = hl.Func("demosaic_2") 60 | d3 = hl.Func("demosaic_3") 61 | 62 | output = hl.Func("demosaic_output") 63 | 64 | x, y, c = hl.Var("x"), hl.Var("y"), hl.Var("c") 65 | rdom0 = hl.RDom([(-2, 5), (-2, 5)]) 66 | # rdom1 = hl.RDom([(0, width / 2), (0, height / 2)]) 67 | 68 | input_mirror = hl.BoundaryConditions.mirror_interior(input, [(0, width), (0, height)]) 69 | 70 | f0.fill(0) 71 | f1.fill(0) 72 | f2.fill(0) 73 | f3.fill(0) 74 | 75 | f0_sum = 8 76 | f1_sum = 16 77 | f2_sum = 16 78 | f3_sum = 16 79 | 80 | f0[0, -2] = -1 81 | f0[0, -1] = 2 82 | f0[-2, 0] = -1 83 | f0[-1, 0] = 2 84 | f0[0, 0] = 4 85 | f0[1, 0] = 2 86 | f0[2, 0] = -1 87 | f0[0, 1] = 2 88 | f0[0, 2] = -1 89 | 90 | f1[0, -2] = 1 91 | f1[-1, -1] = -2 92 | f1[1, -1] = -2 93 | f1[-2, 0] = -2 94 | f1[-1, 0] = 8 95 | f1[0, 0] = 10 96 | f1[1, 0] = 8 97 | f1[2, 0] = -2 98 | f1[-1, 1] = -2 99 | f1[1, 1] = -2 100 | f1[0, 2] = 1 101 | 102 | f2[0, -2] = -2 103 | f2[-1, -1] = -2 104 | f2[0, -1] = 8 105 | f2[1, -1] = -2 106 | f2[-2, 0] = 1 107 | f2[0, 0] = 10 108 | f2[2, 0] = 1 109 | f2[-1, 1] = -2 110 | f2[0, 1] = 8 111 | f2[1, 1] = -2 112 | f2[0, 2] = -2 113 | 114 | f3[0, -2] = -3 115 | f3[-1, -1] = 4 116 | f3[1, -1] = 4 117 | f3[-2, 0] = -3 118 | f3[0, 0] = 12 119 | f3[2, 0] = -3 120 | f3[-1, 1] = 4 121 | f3[1, 1] = 4 122 | f3[0, 2] = -3 123 | 124 | d0[x, y] = hl.u16_sat(hl.sum(hl.i32(input_mirror[x + rdom0.x, y + rdom0.y]) * f0[rdom0.x, rdom0.y]) / f0_sum) 125 | d1[x, y] = hl.u16_sat(hl.sum(hl.i32(input_mirror[x + rdom0.x, y + rdom0.y]) * f1[rdom0.x, rdom0.y]) / f1_sum) 126 | d2[x, y] = hl.u16_sat(hl.sum(hl.i32(input_mirror[x + rdom0.x, y + rdom0.y]) * f2[rdom0.x, rdom0.y]) / f2_sum) 127 | d3[x, y] = hl.u16_sat(hl.sum(hl.i32(input_mirror[x + rdom0.x, y + rdom0.y]) * f3[rdom0.x, rdom0.y]) / f3_sum) 128 | 129 | R_row = y % 2 == 0 130 | B_row = y % 2 != 0 131 | R_col = x % 2 == 0 132 | B_col = x % 2 != 0 133 | at_R = c == 0 134 | at_G = c == 1 135 | at_B = c == 2 136 | 137 | output[x, y, c] = hl.select(at_R & R_row & B_col, d1[x, y], 138 | at_R & B_row & R_col, d2[x, y], 139 | at_R & B_row & B_col, d3[x, y], 140 | at_G & R_row & R_col, d0[x, y], 141 | at_G & B_row & B_col, d0[x, y], 142 | at_B & B_row & R_col, d1[x, y], 143 | at_B & R_row & B_col, d2[x, y], 144 | at_B & R_row & R_col, d3[x, y], 145 | input[x, y]) 146 | 147 | d0.compute_root().parallel(y).vectorize(x, 16) 148 | 
d1.compute_root().parallel(y).vectorize(x, 16)
149 |     d2.compute_root().parallel(y).vectorize(x, 16)
150 |     d3.compute_root().parallel(y).vectorize(x, 16)
151 |
152 |     output.compute_root().parallel(y).align_bounds(x, 2).unroll(x, 2).align_bounds(y, 2).unroll(y, 2).vectorize(x, 16)
153 |
154 |     return output
155 |
156 |
157 | def rgb_to_yuv(input):
158 |     print(' rgb_to_yuv')
159 |
160 |     output = hl.Func("rgb_to_yuv_output")
161 |
162 |     x, y, c = hl.Var("x"), hl.Var("y"), hl.Var("c")
163 |
164 |     r = input[x, y, 0]
165 |     g = input[x, y, 1]
166 |     b = input[x, y, 2]
167 |
168 |     output[x, y, c] = hl.f32(0)
169 |
170 |     output[x, y, 0] = 0.2989 * r + 0.587 * g + 0.114 * b
171 |     output[x, y, 1] = -0.168935 * r - 0.331655 * g + 0.50059 * b
172 |     output[x, y, 2] = 0.499813 * r - 0.418531 * g - 0.081282 * b
173 |
174 |     output.compute_root().parallel(y).vectorize(x, 16)
175 |
176 |     output.update(0).parallel(y).vectorize(x, 16)
177 |     output.update(1).parallel(y).vectorize(x, 16)
178 |     output.update(2).parallel(y).vectorize(x, 16)
179 |
180 |     return output
181 |
182 |
183 | def bilateral_filter(input, width, height):
184 |     print(' bilateral_filter')
185 |
186 |     k = hl.Buffer(hl.Float(32), [7, 7], "gauss_kernel")
187 |     k.translate([-3, -3])
188 |
189 |     weights = hl.Func("bilateral_weights")
190 |     total_weights = hl.Func("bilateral_total_weights")
191 |     bilateral = hl.Func("bilateral")
192 |     output = hl.Func("bilateral_filter_output")
193 |
194 |     x, y, dx, dy, c = hl.Var("x"), hl.Var("y"), hl.Var("dx"), hl.Var("dy"), hl.Var("c")
195 |     rdom = hl.RDom([(-3, 7), (-3, 7)])
196 |
197 |     k.fill(0)
198 |     k[-3, -3] = 0.000690
199 |     k[-2, -3] = 0.002646
200 |     k[-1, -3] = 0.005923
201 |     k[0, -3] = 0.007748
202 |     k[1, -3] = 0.005923
203 |     k[2, -3] = 0.002646
204 |     k[3, -3] = 0.000690
205 |     k[-3, -2] = 0.002646
206 |     k[-2, -2] = 0.010149
207 |     k[-1, -2] = 0.022718
208 |     k[0, -2] = 0.029715
209 |     k[1, -2] = 0.022718
210 |     k[2, -2] = 0.010149
211 |     k[3, -2] = 0.002646
212 |     k[-3, -1] = 0.005923
213 |     k[-2, -1] = 0.022718
214 |     k[-1, -1] = 0.050855
215 |     k[0, -1] = 0.066517
216 |     k[1, -1] = 0.050855
217 |     k[2, -1] = 0.022718
218 |     k[3, -1] = 0.005923
219 |     k[-3, 0] = 0.007748
220 |     k[-2, 0] = 0.029715
221 |     k[-1, 0] = 0.066517
222 |     k[0, 0] = 0.087001
223 |     k[1, 0] = 0.066517
224 |     k[2, 0] = 0.029715
225 |     k[3, 0] = 0.007748
226 |     k[-3, 1] = 0.005923
227 |     k[-2, 1] = 0.022718
228 |     k[-1, 1] = 0.050855
229 |     k[0, 1] = 0.066517
230 |     k[1, 1] = 0.050855
231 |     k[2, 1] = 0.022718
232 |     k[3, 1] = 0.005923
233 |     k[-3, 2] = 0.002646
234 |     k[-2, 2] = 0.010149
235 |     k[-1, 2] = 0.022718
236 |     k[0, 2] = 0.029715
237 |     k[1, 2] = 0.022718
238 |     k[2, 2] = 0.010149
239 |     k[3, 2] = 0.002646
240 |     k[-3, 3] = 0.000690
241 |     k[-2, 3] = 0.002646
242 |     k[-1, 3] = 0.005923
243 |     k[0, 3] = 0.007748
244 |     k[1, 3] = 0.005923
245 |     k[2, 3] = 0.002646
246 |     k[3, 3] = 0.000690
247 |
248 |     input_mirror = hl.BoundaryConditions.mirror_interior(input, [(0, width), (0, height)])
249 |
250 |     dist = hl.cast(hl.Float(32),
251 |                    hl.cast(hl.Int(32), input_mirror[x, y, c]) - hl.cast(hl.Int(32), input_mirror[x + dx, y + dy, c]))
252 |
253 |     sig2 = 100
254 |
255 |     threshold = 25000
256 |
257 |     score = hl.select(hl.abs(input_mirror[x + dx, y + dy, c]) > threshold, 0, hl.exp(-dist * dist / sig2))
258 |
259 |     weights[dx, dy, x, y, c] = k[dx, dy] * score
260 |
261 |     total_weights[x, y, c] = hl.sum(weights[rdom.x, rdom.y, x, y, c])
262 |
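    # The spatial Gaussian k[dx, dy] is modulated per pixel by the range term (score),
    # so dividing by total_weights renormalizes the kernel into a proper weighted
    # average over the 7x7 neighborhood
263 |     bilateral[x, y, c] = hl.sum(input_mirror[x + rdom.x, y + rdom.y, c] * weights[rdom.x, rdom.y, x, y, c]) / \
264 |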
total_weights[x, y, c] 265 | 266 | output[x, y, c] = hl.cast(hl.Float(32), input[x, y, c]) 267 | 268 | output[x, y, 1] = bilateral[x, y, 1] 269 | output[x, y, 2] = bilateral[x, y, 2] 270 | 271 | weights.compute_at(output, y).vectorize(x, 16) 272 | 273 | output.compute_root().parallel(y).vectorize(x, 16) 274 | 275 | output.update(0).parallel(y).vectorize(x, 16) 276 | output.update(1).parallel(y).vectorize(x, 16) 277 | 278 | return output 279 | 280 | 281 | def gauss_15x15(input, name): 282 | print(' gauss_15x15') 283 | 284 | k = hl.Buffer(hl.Float(32), [15], "gauss_15x15") 285 | k.translate([-7]) 286 | 287 | rdom = hl.RDom([(-7, 15)]) 288 | 289 | k.fill(0) 290 | k[-7] = 0.004961 291 | k[-6] = 0.012246 292 | k[-5] = 0.026304 293 | k[-4] = 0.049165 294 | k[-3] = 0.079968 295 | k[-2] = 0.113193 296 | k[-1] = 0.139431 297 | k[0] = 0.149464 298 | k[7] = 0.004961 299 | k[6] = 0.012246 300 | k[5] = 0.026304 301 | k[4] = 0.049165 302 | k[3] = 0.079968 303 | k[2] = 0.113193 304 | k[1] = 0.139431 305 | 306 | return gauss(input, k, rdom, name) 307 | 308 | 309 | def desaturate_noise(input, width, height): 310 | print(' desaturate_noise') 311 | 312 | output = hl.Func("desaturate_noise_output") 313 | 314 | x, y, c = hl.Var("x"), hl.Var("y"), hl.Var("c") 315 | 316 | input_mirror = hl.BoundaryConditions.mirror_image(input, [(0, width), (0, height)]) 317 | 318 | blur = gauss_15x15(gauss_15x15(input_mirror, "desaturate_noise_blur1"), "desaturate_noise_blur_2") 319 | 320 | factor = 1.4 321 | 322 | threshold = 25000 323 | 324 | output[x, y, c] = input[x, y, c] 325 | 326 | output[x, y, 1] = hl.select((hl.abs(blur[x, y, 1]) / hl.abs(input[x, y, 1]) < factor) & 327 | (hl.abs(input[x, y, 1]) < threshold) & (hl.abs(blur[x, y, 1]) < threshold), 328 | 0.7 * blur[x, y, 1] + 0.3 * input[x, y, 1], input[x, y, 1]) 329 | 330 | output[x, y, 2] = hl.select((hl.abs(blur[x, y, 2]) / hl.abs(input[x, y, 2]) < factor) & 331 | (hl.abs(input[x, y, 2]) < threshold) & (hl.abs(blur[x, y, 2]) < threshold), 332 | 0.7 * blur[x, y, 2] + 0.3 * input[x, y, 2], input[x, y, 2]) 333 | 334 | output.compute_root().parallel(y).vectorize(x, 16) 335 | 336 | return output 337 | 338 | 339 | def increase_saturation(input, strength): 340 | print(' increase saturation') 341 | 342 | output = hl.Func("increase_saturation_output") 343 | 344 | x, y, c = hl.Var("x"), hl.Var("y"), hl.Var("c") 345 | 346 | output[x, y, c] = strength * input[x, y, c] 347 | output[x, y, 0] = input[x, y, 0] 348 | 349 | output.compute_root().parallel(y).vectorize(x, 16) 350 | 351 | return output 352 | 353 | 354 | def yuv_to_rgb(input): 355 | print(' yuv_to_rgb') 356 | 357 | output = hl.Func("yuv_to_rgb_output") 358 | 359 | x, y, c = hl.Var("x"), hl.Var("y"), hl.Var("c") 360 | 361 | Y = input[x, y, 0] 362 | U = input[x, y, 1] 363 | V = input[x, y, 2] 364 | 365 | output[x, y, c] = hl.cast(hl.UInt(16), 0) 366 | 367 | output[x, y, 0] = hl.u16_sat(Y + 1.403 * V) 368 | output[x, y, 1] = hl.u16_sat(Y - 0.344 * U - 0.714 * V) 369 | output[x, y, 2] = hl.u16_sat(Y + 1.77 * U) 370 | 371 | output.compute_root().parallel(y).vectorize(x, 16) 372 | 373 | output.update(0).parallel(y).vectorize(x, 16) 374 | output.update(1).parallel(y).vectorize(x, 16) 375 | output.update(2).parallel(y).vectorize(x, 16) 376 | 377 | return output 378 | 379 | 380 | def chroma_denoise(input, width, height, denoise_passes): 381 | print(f'width: {width}, height: {height}, passes: {denoise_passes}') 382 | 383 | output = rgb_to_yuv(input) 384 | 385 | p = 0 386 | 387 | if denoise_passes > 0: 388 | output = 
bilateral_filter(output, width, height) 389 | p += 1 390 | 391 | while p < denoise_passes: 392 | output = desaturate_noise(output, width, height) 393 | p += 1 394 | 395 | if denoise_passes > 2: 396 | output = increase_saturation(output, 1.1) 397 | 398 | return yuv_to_rgb(output) 399 | 400 | 401 | def srgb(input, ccm): 402 | srgb_matrix = hl.Func("srgb_matrix") 403 | output = hl.Func("srgb_output") 404 | 405 | x, y, c = hl.Var("x"), hl.Var("y"), hl.Var("c") 406 | 407 | rdom = hl.RDom([(0, 3)]) 408 | 409 | srgb_matrix[x, y] = hl.f32(0) 410 | 411 | srgb_matrix[0, 0] = hl.f32(ccm[0][0]) 412 | srgb_matrix[1, 0] = hl.f32(ccm[0][1]) 413 | srgb_matrix[2, 0] = hl.f32(ccm[0][2]) 414 | srgb_matrix[0, 1] = hl.f32(ccm[1][0]) 415 | srgb_matrix[1, 1] = hl.f32(ccm[1][1]) 416 | srgb_matrix[2, 1] = hl.f32(ccm[1][2]) 417 | srgb_matrix[0, 2] = hl.f32(ccm[2][0]) 418 | srgb_matrix[1, 2] = hl.f32(ccm[2][1]) 419 | srgb_matrix[2, 2] = hl.f32(ccm[2][2]) 420 | 421 | output[x, y, c] = hl.u16_sat(hl.sum(srgb_matrix[rdom, c] * input[x, y, rdom])) 422 | 423 | return output 424 | 425 | 426 | def gauss(input, k, rdom, name): 427 | blur_x = hl.Func(name + "_x") 428 | output = hl.Func(name) 429 | 430 | x, y, c, xi, yi = hl.Var("x"), hl.Var("y"), hl.Var("c"), hl.Var("xi"), hl.Var("yi") 431 | 432 | val = hl.Expr("val") 433 | 434 | if input.dimensions() == 2: 435 | blur_x[x, y] = hl.sum(input[x + rdom, y] * k[rdom]) 436 | val = hl.sum(blur_x[x, y + rdom] * k[rdom]) 437 | if input.output_types()[0] == hl.UInt(16): 438 | val = hl.u16(val) 439 | output[x, y] = val 440 | else: 441 | blur_x[x, y, c] = hl.sum(input[x + rdom, y, c] * k[rdom]) 442 | val = hl.sum(blur_x[x, y + rdom, c] * k[rdom]) 443 | if input.output_types()[0] == hl.UInt(16): 444 | val = hl.u16(val) 445 | output[x, y, c] = val 446 | 447 | blur_x.compute_at(output, x).vectorize(x, 16) 448 | 449 | output.compute_root().tile(x, y, xi, yi, 256, 128).vectorize(xi, 16).parallel(y) 450 | 451 | return output 452 | 453 | 454 | def gauss_7x7(input, name): 455 | k = hl.Buffer(hl.Float(32), [7], "gauss_7x7_kernel") 456 | k.translate([-3]) 457 | 458 | rdom = hl.RDom([(-3, 7)]) 459 | 460 | k.fill(0) 461 | k[-3] = 0.026267 462 | k[-2] = 0.100742 463 | k[-1] = 0.225511 464 | k[0] = 0.29496 465 | k[1] = 0.225511 466 | k[2] = 0.100742 467 | k[3] = 0.026267 468 | 469 | return gauss(input, k, rdom, name) 470 | 471 | 472 | def diff(im1, im2, name): 473 | output = hl.Func(name) 474 | 475 | x, y, c = hl.Var("x"), hl.Var("y"), hl.Var("c") 476 | 477 | if im1.dimensions() == 2: 478 | output[x, y] = hl.i32(im1[x, y]) - hl.i32(im2[x, y]) 479 | else: 480 | output[x, y, c] = hl.i32(im1[x, y, c]) - hl.i32(im2[x, y, c]) 481 | 482 | return output 483 | 484 | 485 | def combine(im1, im2, width, height, dist): 486 | init_mask1 = hl.Func("mask1_layer_0") 487 | init_mask2 = hl.Func("mask2_layer_0") 488 | accumulator = hl.Func("combine_accumulator") 489 | output = hl.Func("combine_output") 490 | 491 | x, y = hl.Var("x"), hl.Var("y") 492 | 493 | im1_mirror = hl.BoundaryConditions.repeat_edge(im1, [(0, width), (0, height)]) 494 | im2_mirror = hl.BoundaryConditions.repeat_edge(im2, [(0, width), (0, height)]) 495 | 496 | unblurred1 = im1_mirror 497 | unblurred2 = im2_mirror 498 | 499 | blurred1 = gauss_7x7(im1_mirror, "img1_layer_0") 500 | blurred2 = gauss_7x7(im2_mirror, "img2_layer_0") 501 | 502 | weight1 = hl.f32(dist[im1_mirror[x, y]]) 503 | weight2 = hl.f32(dist[im2_mirror[x, y]]) 504 | 505 | init_mask1[x, y] = weight1 / (weight1 + weight2) 506 | init_mask2[x, y] = 1 - init_mask1[x, y] 507 | 508 | mask1 
= init_mask1 509 | mask2 = init_mask2 510 | 511 | num_layers = 2 512 | 513 | accumulator[x, y] = hl.i32(0) 514 | 515 | for i in range(1, num_layers): 516 | print(' layer', i) 517 | 518 | prev_layer_str = str(i - 1) 519 | layer_str = str(i) 520 | 521 | laplace1 = diff(unblurred1, blurred1, "laplace1_layer_" + prev_layer_str) 522 | laplace2 = diff(unblurred2, blurred2, "laplace2_layer_" + layer_str) 523 | 524 | accumulator[x, y] += hl.i32(laplace1[x, y] * mask1[x, y]) + hl.i32(laplace2[x, y] * mask2[x, y]) 525 | 526 | unblurred1 = blurred1 527 | unblurred2 = blurred2 528 | 529 | blurred1 = gauss_7x7(blurred1, "img1_layer_" + layer_str) 530 | blurred2 = gauss_7x7(blurred2, "img2_layer_" + layer_str) 531 | 532 | mask1 = gauss_7x7(mask1, "mask1_layer_" + layer_str) 533 | mask2 = gauss_7x7(mask2, "mask2_layer_" + layer_str) 534 | 535 | accumulator[x, y] += hl.i32(blurred1[x, y] * mask1[x, y]) + hl.i32(blurred2[x, y] * mask2[x, y]) 536 | 537 | output[x, y] = hl.u16_sat(accumulator[x, y]) 538 | 539 | init_mask1.compute_root().parallel(y).vectorize(x, 16) 540 | 541 | accumulator.compute_root().parallel(y).vectorize(x, 16) 542 | 543 | for i in range(num_layers): 544 | accumulator.update(i).parallel(y).vectorize(x, 16) 545 | 546 | return output 547 | 548 | 549 | def combine2(im1, im2, width, height, dist): 550 | init_mask1 = hl.Func("mask1_layer_0") 551 | init_mask2 = hl.Func("mask2_layer_0") 552 | accumulator = hl.Func("combine_accumulator") 553 | output = hl.Func("combine_output") 554 | 555 | x, y = hl.Var("x"), hl.Var("y") 556 | 557 | im1_mirror = hl.BoundaryConditions.repeat_edge(im1, [(0, width), (0, height)]) 558 | im2_mirror = hl.BoundaryConditions.repeat_edge(im2, [(0, width), (0, height)]) 559 | 560 | weight1 = hl.f32(dist[im1_mirror[x, y]]) 561 | weight2 = hl.f32(dist[im2_mirror[x, y]]) 562 | 563 | init_mask1[x, y] = weight1 / (weight1 + weight2) 564 | init_mask2[x, y] = 1 - init_mask1[x, y] 565 | 566 | mask1 = init_mask1 567 | mask2 = init_mask2 568 | 569 | accumulator[x, y] = hl.i32(0) 570 | 571 | accumulator[x, y] += hl.i32(im1_mirror[x, y] * mask1[x, y]) + hl.i32(im2_mirror[x, y] * mask2[x, y]) 572 | 573 | output[x, y] = hl.u16_sat(accumulator[x, y]) 574 | 575 | init_mask1.compute_root().parallel(y).vectorize(x, 16) 576 | 577 | accumulator.compute_root().parallel(y).vectorize(x, 16) 578 | 579 | accumulator.update(0).parallel(y).vectorize(x, 16) 580 | 581 | return output 582 | 583 | 584 | def brighten(input, gain): 585 | output = hl.Func("brighten_output") 586 | 587 | x, y = hl.Var("x"), hl.Var("y") 588 | 589 | output[x, y] = hl.u16_sat(gain * hl.u32(input[x, y])) 590 | 591 | return output 592 | 593 | 594 | def gamma_correct(input): 595 | output = hl.Func("gamma_correct_output") 596 | 597 | x, y, c = hl.Var("x"), hl.Var("y"), hl.Var("c") 598 | 599 | cutoff = 200 600 | gamma_toe = 12.92 601 | gamma_pow = 0.416667 602 | gamma_fac = 680.552897 603 | gamma_con = -3604.425 604 | 605 | if input.dimensions() == 2: 606 | output[x, y] = hl.u16(hl.select(input[x, y] < cutoff, 607 | gamma_toe * input[x, y], 608 | gamma_fac * hl.pow(input[x, y], gamma_pow) + gamma_con)) 609 | else: 610 | output[x, y, c] = hl.u16(hl.select(input[x, y, c] < cutoff, 611 | gamma_toe * input[x, y, c], 612 | gamma_fac * hl.pow(input[x, y, c], gamma_pow) + gamma_con)) 613 | 614 | output.compute_root().parallel(y).vectorize(x, 16) 615 | 616 | return output 617 | 618 | 619 | def gamma_inverse(input): 620 | output = hl.Func("gamma_inverse_output") 621 | 622 | x, y, c = hl.Var("x"), hl.Var("y"), hl.Var("c") 623 | 624 | cutoff 
= 2575 625 | gamma_toe = 0.0774 626 | gamma_pow = 2.4 627 | gamma_fac = 57632.49226 628 | gamma_con = 0.055 629 | 630 | if input.dimensions() == 2: 631 | output[x, y] = hl.u16(hl.select(input[x, y] < cutoff, 632 | gamma_toe * input[x, y], 633 | hl.pow(hl.f32(input[x, y]) / 65535 + gamma_con, gamma_pow) * gamma_fac)) 634 | else: 635 | output[x, y, c] = hl.u16(hl.select(input[x, y, c] < cutoff, 636 | gamma_toe * input[x, y, c], 637 | hl.pow(hl.f32(input[x, y, c]) / 65535 + gamma_con, gamma_pow) * gamma_fac)) 638 | 639 | output.compute_root().parallel(y).vectorize(x, 16) 640 | 641 | return output 642 | 643 | 644 | def tone_map(input, width, height, compression, gain): 645 | print(f'Compression: {compression}, gain: {gain}') 646 | 647 | normal_dist = hl.Func("luma_weight_distribution") 648 | grayscale = hl.Func("grayscale") 649 | output = hl.Func("tone_map_output") 650 | 651 | x, y, c, v = hl.Var("x"), hl.Var("y"), hl.Var("c"), hl.Var("v") 652 | 653 | rdom = hl.RDom([(0, 3)]) 654 | 655 | normal_dist[v] = hl.f32(hl.exp(-12.5 * hl.pow(hl.f32(v) / 65535 - 0.5, 2))) 656 | 657 | grayscale[x, y] = hl.u16(hl.sum(hl.u32(input[x, y, rdom])) / 3) 658 | 659 | dark = grayscale 660 | 661 | comp_const = 1 662 | gain_const = 1 663 | 664 | comp_slope = (compression - comp_const) / (TONE_MAP_PASSES) 665 | gain_slope = (gain - gain_const) / (TONE_MAP_PASSES) 666 | 667 | for i in range(TONE_MAP_PASSES): 668 | print(' pass', i) 669 | 670 | norm_comp = i * comp_slope + comp_const 671 | norm_gain = i * gain_slope + gain_const 672 | 673 | bright = brighten(dark, norm_comp) 674 | 675 | dark_gamma = gamma_correct(dark) 676 | bright_gamma = gamma_correct(bright) 677 | 678 | dark_gamma = combine2(dark_gamma, bright_gamma, width, height, normal_dist) 679 | 680 | dark = brighten(gamma_inverse(dark_gamma), norm_gain) 681 | 682 | output[x, y, c] = hl.u16_sat(hl.u32(input[x, y, c]) * hl.u32(dark[x, y]) / hl.u32(hl.max(1, grayscale[x, y]))) 683 | 684 | grayscale.compute_root().parallel(y).vectorize(x, 16) 685 | 686 | normal_dist.compute_root().vectorize(v, 16) 687 | 688 | return output 689 | 690 | 691 | def shift_bayer_to_rggb(input, cfa_pattern): 692 | print(f'cfa_pattern: {cfa_pattern}') 693 | output = hl.Func("rggb_input") 694 | x, y = hl.Var("x"), hl.Var("y") 695 | 696 | cfa = hl.u16(cfa_pattern) 697 | 698 | output[x, y] = hl.select(cfa == hl.u16(1), input[x, y], 699 | cfa == hl.u16(2), input[x + 1, y], 700 | cfa == hl.u16(4), input[x, y + 1], 701 | cfa == hl.u16(3), input[x + 1, y + 1], 702 | 0) 703 | return output 704 | 705 | 706 | def contrast(input, strength, black_point): 707 | output = hl.Func("contrast_output") 708 | 709 | x, y, c = hl.Var("x"), hl.Var("y"), hl.Var("c") 710 | 711 | scale = strength 712 | 713 | inner_constant = math.pi / (2 * scale) 714 | sin_constant = hl.sin(inner_constant) 715 | slope = 65535 / (2 * sin_constant) 716 | constant = slope * sin_constant 717 | factor = math.pi / (scale * 65535) 718 | 719 | val = factor * hl.cast(hl.Float(32), input[x, y, c]) 720 | 721 | output[x, y, c] = hl.u16_sat(slope * hl.sin(val - inner_constant) + constant) 722 | 723 | white_scale = 65535 / (65535 - black_point) 724 | 725 | output[x, y, c] = hl.u16_sat((hl.cast(hl.Int(32), output[x, y, c]) - black_point) * white_scale) 726 | 727 | output.compute_root().parallel(y).vectorize(x, 16) 728 | 729 | return output 730 | 731 | 732 | def sharpen(input, strength): 733 | output_yuv = hl.Func("sharpen_output") 734 | 735 | x, y, c = hl.Var("x"), hl.Var("y"), hl.Var("c") 736 | 737 | yuv_input = rgb_to_yuv(input) 738 | 739 | 
small_blurred = gauss_7x7(yuv_input, "unsharp_small_blur") 740 | large_blurred = gauss_7x7(small_blurred, "unsharp_large_blur") 741 | 742 | difference_of_gauss = diff(small_blurred, large_blurred, "unsharp_DoG") 743 | 744 | output_yuv[x, y, c] = yuv_input[x, y, c] 745 | output_yuv[x, y, 0] = yuv_input[x, y, 0] + strength * difference_of_gauss[x, y, 0] 746 | 747 | output = yuv_to_rgb(output_yuv) 748 | 749 | output_yuv.compute_root().parallel(y).vectorize(x, 16) 750 | 751 | return output 752 | 753 | 754 | def u8bit_interleave(input): 755 | output = hl.Func("8bit_interleaved_output") 756 | 757 | x, y, c = hl.Var("x"), hl.Var("y"), hl.Var("c") 758 | 759 | output[x, y, c] = hl.u8_sat(input[x, y, c] / 256) 760 | 761 | output.compute_root().parallel(y).vectorize(x, 16) 762 | 763 | return output 764 | 765 | 766 | ''' 767 | Step 3 of HDR+ pipeline: finish 768 | Finishes the merged burst images using the following steps: 769 | 1 : Shift Bayer image to RGGB 770 | 2 : Black- and white-level correction 771 | 3 : White balancing 772 | 4 : Demosaicing 773 | 5 : Chroma denoising 774 | 6 : sRGB color correction 775 | 7 : Tone mapping (global) 776 | 8 : Gamma correction 777 | 9 : Contrast adjustment 778 | 10 : Sharpening 779 | 11 : 8-bit interleaving 780 | 781 | image : Halide buffer 782 | The merged image to be finished 783 | width : Integer 784 | Width of the image 785 | height : Integer 786 | Height of the image 787 | black_point : Integer 788 | Black level of the image to be used for black and white level correction 789 | white_point : Integer 790 | White level of the image to be used for black and white level correction 791 | white_balance_x : Float 792 | White balance value for color X (R, G, G, B) 793 | compression : Float 794 | Compression value to be used for tone mapping 795 | gain : Float 796 | Gain value to be used for tone mapping 797 | contrast_strength : Float 798 | Contrast value to be used for contrast adjustment 799 | cfa_pattern : Integer 800 | Represents the Bayer pattern of the image, used to shift Bayer to RGGB 801 | ccm : numpy.ndarray of shape (3, 4) 802 | Color correction matrix, used for sRGB color correction 803 | 804 | Returns: Halide buffer (finished image) 805 | ''' 806 | def finish_image(image, width, height, black_point, white_point, white_balance_r, white_balance_g0, white_balance_g1, 807 | white_balance_b, compression, gain, contrast_strength, cfa_pattern, ccm): 808 | print(black_point, white_point, white_balance_r, white_balance_g0, white_balance_g1, 809 | white_balance_b, compression, gain) 810 | 811 | print("bayer_to_rggb") 812 | bayer_shifted = shift_bayer_to_rggb(image, cfa_pattern) 813 | 814 | print("black_white_level") 815 | black_white_level_output = black_white_level(bayer_shifted, black_point, white_point) 816 | 817 | print("white_balance") 818 | white_balance_output = white_balance(black_white_level_output, width, height, white_balance_r, white_balance_g0, 819 | white_balance_g1, white_balance_b) 820 | 821 | print("demosaic") 822 | demosaic_output = demosaic(white_balance_output, width, height) 823 | 824 | print('chroma_denoise') 825 | chroma_denoised_output = chroma_denoise(demosaic_output, width, height, DENOISE_PASSES) 826 | 827 | print("srgb") 828 | srgb_output = srgb(chroma_denoised_output, ccm) 829 | 830 | print("tone_map") 831 | tone_map_output = tone_map(srgb_output, width, height, compression, gain) 832 | 833 | print("gamma_correct") 834 | gamma_correct_output = gamma_correct(tone_map_output) 835 | 836 | print('contrast') 837 | contrast_output = 
contrast(gamma_correct_output, contrast_strength, black_point)
838 |
839 |     print('sharpen')
840 |     sharpen_output = sharpen(contrast_output, SHARPEN_STRENGTH)
841 |
842 |     print('u8bit_interleave')
843 |     u8bit_interleave_output = u8bit_interleave(sharpen_output)
844 |
845 |     return u8bit_interleave_output
846 |
--------------------------------------------------------------------------------
/hdr_plus.kv:
--------------------------------------------------------------------------------
1 | #:kivy 1.1.0
2 |
3 | Root:
4 |     BoxLayout:
5 |         orientation: 'vertical'
6 |         BoxLayout:
7 |             orientation: 'horizontal'
8 |             Image:
9 |                 id: image0
10 |                 source: root.original
11 |                 size_hint_x: 1
12 |                 allow_stretch: False
13 |                 keep_data: True
14 |             Image:
15 |                 id: image1
16 |                 source: root.image
17 |                 size_hint_x: 1
18 |                 allow_stretch: False
19 |                 keep_data: True
20 |             BoxLayout:
21 |                 orientation: 'vertical'
22 |                 size_hint_x: 0.4
23 |                 Label:
24 |                     text: 'Compression'
25 |                 BoxLayout:
26 |                     orientation: 'horizontal'
27 |                     Slider:
28 |                         id: compression
29 |                         value: 3.8
30 |                         min: 1
31 |                         max: 6.6
32 |                         step: 0.1
33 |                         orientation: 'horizontal'
34 |                     Label:
35 |                         size_hint_x: 0.3
36 |                         text: str(compression.value)[:3]
37 |                 Label:
38 |                     text: 'Gain'
39 |                 BoxLayout:
40 |                     orientation: 'horizontal'
41 |                     Slider:
42 |                         id: gain
43 |                         value: 1.1
44 |                         min: 0.1
45 |                         max: 2.1
46 |                         step: 0.01
47 |                         orientation: 'horizontal'
48 |                     Label:
49 |                         size_hint_x: 0.3
50 |                         text: str(gain.value)[:4]
51 |                 Label:
52 |                     text: 'Contrast'
53 |                 BoxLayout:
54 |                     orientation: 'horizontal'
55 |                     Slider:
56 |                         id: contrast
57 |                         value: 1.0
58 |                         min: 0.5
59 |                         max: 1.5
60 |                         step: 0.01
61 |                         orientation: 'horizontal'
62 |                     Label:
63 |                         size_hint_x: 0.3
64 |                         text: str(contrast.value)[:4]
65 |                 Label:
66 |                     id: state
67 |                     text: ''
68 |                     size_hint_y: 8
69 |         BoxLayout:
70 |             size_hint_y: None
71 |             height: 30
72 |             Button:
73 |                 text: 'Select burst'
74 |                 on_release: root.show_load()
75 |             Button:
76 |                 text: 'Process'
77 |                 on_release: root.process()
78 |
79 | <LoadDialog>:
80 |     BoxLayout:
81 |         size: root.size
82 |         pos: root.pos
83 |         orientation: "vertical"
84 |         FileChooserListView:
85 |             id: filechooser
86 |
87 |         BoxLayout:
88 |             size_hint_y: None
89 |             height: 30
90 |             Button:
91 |                 text: "Cancel"
92 |                 on_release: root.cancel()
93 |
94 |             Button:
95 |                 text: "Select"
96 |                 on_release: root.load(filechooser.path, filechooser.selection)
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import rawpy
3 | import imageio
4 | import os
5 | import multiprocessing
6 | import halide as hl
7 | from datetime import datetime
8 | import threading
9 | from functools import partial
10 | os.environ['KIVY_NO_CONSOLELOG'] = '1'  # Comment this line out when debugging the UI
11 | from kivy.app import App
12 | from kivy.uix.floatlayout import FloatLayout
13 | from kivy.factory import Factory
14 | from kivy.properties import ObjectProperty
15 | from kivy.uix.popup import Popup
16 | from kivy.uix.label import Label
17 | from kivy.uix.button import Button
18 | from kivy.uix.progressbar import ProgressBar
19 | from kivy.clock import Clock
20 | from utils import time_diff
21 | from align import align_images
22 | from merge import merge_images
23 | from finish import finish_image
24 |
25 | '''
26 | Loads a raw image
27 |
28 | image_path : str
29 |     String representing the path to the image
30 |
31 | Returns: numpy ndarray containing the raw Bayer (RGGB) mosaic, one value per pixel
32 | '''
33 | def load_image(image_path):
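    # rawpy's raw_image_visible exposes the sensor's Bayer mosaic as a 2-D uint16 array;
    # .copy() detaches the data so the buffer remains valid after the rawpy handle closes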
34 |     with rawpy.imread(image_path) as raw:
35 |         image = raw.raw_image_visible.copy()
36 |     return image
37 |
38 |
39 | '''
40 | Decode a raw CFA pattern
41 |
42 | pattern : list of lists of integers (numpy.ndarray)
43 |     RawPy.raw_pattern: the smallest possible Bayer pattern of a raw image
44 |
45 | Returns: Integer in range 1 - 4, where
46 |     1 : RGGB
47 |     2 : GRBG
48 |     3 : BGGR
49 |     4 : GBRG (any other pattern falls through to this case)
50 | '''
51 | def decode_pattern(pattern):
52 |     pattern_str = ""
53 |     for row in pattern:
54 |         for val in row:
55 |             if val == 0:
56 |                 pattern_str += 'R'
57 |             elif val == 1:
58 |                 pattern_str += 'G'
59 |             elif val == 2:
60 |                 pattern_str += 'B'
61 |             else:
62 |                 pattern_str += 'G'
63 |     if pattern_str == 'RGGB':
64 |         return 1
65 |     elif pattern_str == 'GRBG':
66 |         return 2
67 |     elif pattern_str == 'BGGR':
68 |         return 3
69 |     else:
70 |         return 4
71 |
72 |
73 | '''
74 | Loads a burst of images
75 |
76 | burst_path : str
77 |     String representing the path to the folder containing the burst images
78 |
79 | Returns: Halide buffer of raw images, reference image, white balance values for RGGB,
80 |     black level, white level, CFA pattern, color correction matrix
81 | '''
82 | def load_images(burst_path):
83 |     print(f'\n{"=" * 30}\nLoading images...\n{"=" * 30}')
84 |     start = datetime.utcnow()
85 |     images = []
86 |     white_balance_r = 0
87 |     white_balance_g0 = 0
88 |     white_balance_g1 = 0
89 |     white_balance_b = 0
90 |     black_point = 0
91 |     white_point = 0
92 |     cfa_pattern = 0
93 |
94 |     # Create list of paths to the images
95 |     paths = []
96 |     for i in range(100):
97 |         if i < 10:
98 |             filename = f'payload_N00{i}.dng'
99 |         else:
100 |             filename = f'payload_N0{i}.dng'
101 |         file_path = f'{burst_path}/{filename}'
102 |         if os.path.isfile(file_path):
103 |             paths.append(file_path)
104 |         else:
105 |             if i == 0:
106 |                 raise ValueError("Burst format not recognized.")
107 |             break
108 |
109 |     # Load raw images
110 |     print('Loading raw images...')
111 |     p = multiprocessing.Pool(min(multiprocessing.cpu_count() - 1, len(paths)))
112 |     for image in p.imap(load_image, paths):
113 |         images.append(hl.Buffer(image))
114 |
115 |     assert len(images) >= 2, "Burst must consist of at least 2 images"
116 |
117 |     # Get a reference image to compare results
118 |     print('Getting reference image...')
119 |     with rawpy.imread(paths[0]) as raw:
120 |         white_balance = raw.camera_whitebalance
121 |         print('white balance', white_balance)
122 |         white_balance_r = white_balance[0] / white_balance[1]
123 |         white_balance_g0 = 1
124 |         white_balance_g1 = 1
125 |         white_balance_b = white_balance[2] / white_balance[1]
126 |         cfa_pattern = raw.raw_pattern
127 |         cfa_pattern = decode_pattern(cfa_pattern)
128 |         ccm = raw.color_matrix
129 |         black_point = int(raw.black_level_per_channel[0])
130 |         white_point = int(raw.white_level)
131 |
132 |         ref_img = raw.postprocess(output_bps=16)
133 |
134 |     print('Building image buffer...')
135 |     result = hl.Buffer(hl.UInt(16), [images[0].width(), images[0].height(), len(images)])
136 |     for index, image in enumerate(images):
137 |         resultSlice = result.sliced(2, index)
138 |         resultSlice.copy_from(image)
139 |
140 |     print(f'Loading finished in {time_diff(start)} ms.\n')
141 |     return result, ref_img, white_balance_r, white_balance_g0, white_balance_g1, white_balance_b, black_point, white_point, cfa_pattern, ccm
142 |
143 |
144 | '''
145 | Main method of the HDR+ pipeline: align, merge, finish
146 |
147 | burst_path : str
148 |     The path to the folder containing the burst images
149 | compression : float
150 |     Compression to be used in finish step
151 |
gain : float 152 | Gain to be used in finish step 153 | contrast : float 154 | Contrast to be used in finish step 155 | UI : Root(FloatLayout) class object 156 | Kivy object used to update UI elements 157 | 158 | After execution finishes, UI.original and UI.image will be set to the reference frame of the input, 159 | and the result of the burst processed by the HDR+ pipeline, respectively. 160 | 161 | If an error is encountered, these values will instead remain unchanged, and an error will be passed to the UI. 162 | ''' 163 | def HDR(burst_path, compression, gain, contrast, UI): 164 | try: 165 | start = datetime.utcnow() 166 | 167 | print(f'Compression: {compression}, gain: {gain}, contrast: {contrast}') 168 | 169 | # Load the images 170 | images, ref_img, white_balance_r, white_balance_g0, white_balance_g1, white_balance_b, black_point, white_point, cfa_pattern, ccm = load_images( 171 | burst_path) 172 | Clock.schedule_once(partial(UI.update_progress, 20)) 173 | 174 | # dimensions of image should be 3 175 | assert images.dimensions() == 3, f"Incorrect buffer dimensions, expected 3 but got {images.dimensions()}" 176 | assert images.dim(2).extent() >= 2, f"Must have at least one alternate image" 177 | # Save the reference image 178 | imageio.imsave('Output/input.jpg', ref_img) 179 | 180 | # Align the images 181 | alignment = align_images(images) 182 | 183 | # Merge the images 184 | merged = merge_images(images, alignment) 185 | 186 | # Finish the image 187 | print(f'\n{"=" * 30}\nFinishing image...\n{"=" * 30}') 188 | start_finish = datetime.utcnow() 189 | finished = finish_image(merged, images.width(), images.height(), black_point, white_point, white_balance_r, 190 | white_balance_g0, white_balance_g1, white_balance_b, compression, gain, contrast, 191 | cfa_pattern, ccm) 192 | 193 | Clock.schedule_once(partial(UI.update_progress, 30)) 194 | 195 | result = finished.realize(images.width(), images.height(), 3) 196 | 197 | Clock.schedule_once(partial(UI.update_progress, 90)) 198 | 199 | print(f'Finishing finished in {time_diff(start_finish)} ms.\n') 200 | 201 | # If portrait orientation, rotate image 90 degrees clockwise 202 | print(ref_img.shape) 203 | if ref_img.shape[0] > ref_img.shape[1]: 204 | print('Rotating image') 205 | result = np.rot90(result, -1) 206 | 207 | imageio.imsave('Output/output.jpg', result) 208 | 209 | Clock.schedule_once(partial(UI.update_progress, 100)) 210 | 211 | print(f'Processed in: {time_diff(start)} ms') 212 | 213 | # return 'Output/input.jpg', 'Output/output.jpg' 214 | 215 | Clock.schedule_once(partial(UI.update_paths, 'Output/input.jpg', 'Output/output.jpg')) 216 | 217 | Clock.schedule_once(UI.dismiss_progress) 218 | 219 | except Exception as e: 220 | Clock.schedule_once(partial(UI.show_error, e)) 221 | 222 | 223 | class Imglayout(FloatLayout): 224 | def __init__(self, **args): 225 | super(Imglayout, self).__init__(**args) 226 | 227 | with self.canvas.before: 228 | Color(0, 0, 0, 0) 229 | self.rect = Rectangle(size=self.size, pos=self.pos) 230 | 231 | self.bind(size=self.updates, pos=self.updates) 232 | 233 | def updates(self, instance, value): 234 | self.rect.size = instance.size 235 | self.rect.pos = instance.pos 236 | 237 | 238 | class LoadDialog(FloatLayout): 239 | load = ObjectProperty(None) 240 | cancel = ObjectProperty(None) 241 | 242 | 243 | class Root(FloatLayout): 244 | loadfile = ObjectProperty(None) 245 | progress_bar = ObjectProperty() 246 | progress_popup = None 247 | 248 | # Empty gallery images 249 | original = 'Images/gallery.jpg' 250 | image = 
'Images/gallery.jpg'
251 |
252 |     # Path to the burst images
253 |     path = ''
254 |
255 |     cancelled = False
256 |
257 |     compression = 3.8
258 |     gain = 1.1
259 |     contrast = 1.0
260 |
261 |     def build(self):
262 |         c = Imglayout()
263 |         self.add_widget(c)
264 |
265 |     def dismiss_popup(self):
266 |         self._popup.dismiss()
267 |
268 |     def dismiss_progress(self, *largs):
269 |         self.progress_popup.dismiss()
270 |
271 |     def update_progress(self, num, *largs):
272 |         self.progress_bar.value = num
273 |
274 |     def update_paths(self, input_path, output_path, *largs):
275 |         self.original = input_path
276 |         self.image = output_path
277 |
278 |     def reload_images(self, instance):
279 |         self.ids.image0.source = self.original
280 |         self.ids.image0.reload()
281 |         self.ids.image1.source = self.image
282 |         self.ids.image1.reload()
283 |
284 |     def next(self, dt):
285 |         if self.progress_bar.value >= 100:
286 |             return False
287 |         self.progress_bar.value += 1
288 |
289 |     def show_error(self, error, *largs):
290 |         if self.progress_popup:
291 |             self.dismiss_progress()
292 |         txt = '\n'.join(str(error)[i:i + 80] for i in range(0, len(str(error)), 80))
293 |         float_popup = FloatLayout(size_hint=(0.9, .04))
294 |         float_popup.add_widget(Label(text=txt,
295 |                                      size_hint=(0.7, 1),
296 |                                      pos_hint={'x': 0.15, 'y': 12}))
297 |         float_popup.add_widget(Button(text='Close',
298 |                                       on_press=lambda *args: popup.dismiss(),
299 |                                       size_hint=(0.2, 4),
300 |                                       pos_hint={'x': 0.4, 'y': 1}))
301 |         popup = Popup(title='Error',
302 |                       content=float_popup,
303 |                       size_hint=(0.9, 0.4))
304 |         popup.open()
305 |
306 |     # Function to call the HDR+ pipeline
307 |     def process(self):
308 |         try:
309 |             if not self.path:
310 |                 raise ValueError('No burst selected.')
311 |             # Get slider values for compression, gain, and contrast
312 |             self.compression = self.ids.compression.value
313 |             self.gain = self.ids.gain.value
314 |             self.contrast = self.ids.contrast.value
315 |
316 |             self.progress_bar = ProgressBar()
317 |             self.progress_popup = Popup(title=f'Processing {self.path}',
318 |                                         content=self.progress_bar,
319 |                                         size_hint=(0.7, 0.2),
320 |                                         auto_dismiss=False)
321 |             self.progress_popup.bind(on_dismiss=self.reload_images)
322 |             self.progress_bar.value = 1
323 |             self.progress_popup.open()
324 |             Clock.schedule_interval(self.next, 0.1)
325 |
326 |             HDR_thread = threading.Thread(target=HDR,
327 |                                           args=(self.path, self.compression, self.gain, self.contrast, self,))
328 |             HDR_thread.start()
329 |
330 |         except Exception as e:
331 |             self.show_error(e)
332 |
333 |     def show_load(self):
334 |         content = LoadDialog(load=self.load, cancel=self.dismiss_popup)
335 |         self._popup = Popup(title="Select burst image", content=content,
336 |                             size_hint=(0.9, 0.9))
337 |
338 |         self._popup.open()
339 |
340 |     def load(self, path, filename):
341 |         # Set the path to the burst images
342 |         self.path = path
343 |         self.cancelled = False
344 |         self.dismiss_popup()
345 |
346 |     def cancel(self):
347 |         self.cancelled = True
348 |         self.dismiss_popup()
349 |
350 |
351 | class HDR_Plus(App):
352 |     pass
353 |
354 |
355 | Factory.register('Root', cls=Root)
356 | Factory.register('LoadDialog', cls=LoadDialog)
357 |
358 | if __name__ == '__main__':
359 |     HDR_Plus().run()
360 |
--------------------------------------------------------------------------------
/merge.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import halide as hl
3 | from utils import time_diff, Point, box_down2, idx_layer, idx_im, idx_0, idx_1, tile_0, tile_1, TILE_SIZE, \
4
| MINIMUM_OFFSET, MAXIMUM_OFFSET 5 | import math 6 | 7 | ''' 8 | Merges images in the temporal dimension. 9 | Weights different frames based on the similarity between the reference tile 10 | and the alternate tiles, minimizing L1 distances (least absolute deviation). 11 | Distances greater than some threshold (max_distance) are discarded. 12 | 13 | images : Halide buffer 14 | Burst frames to be merged 15 | alignment : Halide function 16 | Calculated alignment of the burst frames 17 | 18 | Returns: Halide buffer (merged image) 19 | ''' 20 | def merge_temporal(images, alignment): 21 | weight = hl.Func("merge_temporal_weights") 22 | total_weight = hl.Func("merge_temporal_total_weights") 23 | output = hl.Func("merge_temporal_output") 24 | 25 | ix, iy, tx, ty, n = hl.Var('ix'), hl.Var('iy'), hl.Var('tx'), hl.Var('ty'), hl.Var('n') 26 | rdom0 = hl.RDom([(0, 16), (0, 16)]) 27 | 28 | rdom1 = hl.RDom([(1, images.dim(2).extent() - 1)]) 29 | 30 | imgs_mirror = hl.BoundaryConditions.mirror_interior(images, [(0, images.width()), (0, images.height())]) 31 | 32 | layer = box_down2(imgs_mirror, "merge_layer") 33 | 34 | offset = Point(alignment[tx, ty, n]).clamp(Point(MINIMUM_OFFSET, MINIMUM_OFFSET), 35 | Point(MAXIMUM_OFFSET, MAXIMUM_OFFSET)) 36 | 37 | al_x = idx_layer(tx, rdom0.x) + offset.x / 2 38 | al_y = idx_layer(ty, rdom0.y) + offset.y / 2 39 | 40 | ref_val = layer[idx_layer(tx, rdom0.x), idx_layer(ty, rdom0.y), 0] 41 | alt_val = layer[al_x, al_y, n] 42 | 43 | factor = 8.0 44 | min_distance = 10 45 | max_distance = 300 # max L1 distance, otherwise the value is not used 46 | 47 | distance = hl.sum(hl.abs(hl.cast(hl.Int(32), ref_val) - hl.cast(hl.Int(32), alt_val))) / 256 48 | 49 | normal_distance = hl.max(1, hl.cast(hl.Int(32), distance) / factor - min_distance / factor) 50 | 51 | # Weight for the alternate frame 52 | weight[tx, ty, n] = hl.select(normal_distance > (max_distance - min_distance), 0.0, 53 | 1.0 / normal_distance) 54 | 55 | total_weight[tx, ty] = hl.sum(weight[tx, ty, rdom1]) + 1 56 | 57 | offset = Point(alignment[tx, ty, rdom1]) 58 | 59 | al_x = idx_im(tx, ix) + offset.x 60 | al_y = idx_im(ty, iy) + offset.y 61 | 62 | ref_val = imgs_mirror[idx_im(tx, ix), idx_im(ty, iy), 0] 63 | alt_val = imgs_mirror[al_x, al_y, rdom1] 64 | 65 | # Sum all values according to their weight, and divide by total weight to obtain average 66 | output[ix, iy, tx, ty] = hl.sum(weight[tx, ty, rdom1] * alt_val / total_weight[tx, ty]) + ref_val / total_weight[ 67 | tx, ty] 68 | 69 | weight.compute_root().parallel(ty).vectorize(tx, 16) 70 | 71 | total_weight.compute_root().parallel(ty).vectorize(tx, 16) 72 | 73 | output.compute_root().parallel(ty).vectorize(ix, 32) 74 | 75 | return output 76 | 77 | 78 | ''' 79 | Merges images in the spatial dimension. 80 | Smoothly blends overlapping tiles using modified raised cosine window. 
81 | 82 | input : Halide buffer 83 | image (burst frames merged in temporal dimension) 84 | 85 | Returns: Halide buffer (merged image) 86 | ''' 87 | def merge_spatial(input): 88 | weight = hl.Func("raised_cosine_weights") 89 | output = hl.Func("merge_spatial_output") 90 | 91 | v, x, y = hl.Var('v'), hl.Var('x'), hl.Var('y') 92 | 93 | # modified raised cosine window 94 | weight[v] = 0.5 - 0.5 * hl.cos(2 * math.pi * (v + 0.5) / TILE_SIZE) 95 | 96 | weight_00 = weight[idx_0(x)] * weight[idx_0(y)] 97 | weight_10 = weight[idx_1(x)] * weight[idx_0(y)] 98 | weight_01 = weight[idx_0(x)] * weight[idx_1(y)] 99 | weight_11 = weight[idx_1(x)] * weight[idx_1(y)] 100 | 101 | val_00 = input[idx_0(x), idx_0(y), tile_0(x), tile_0(y)] 102 | val_10 = input[idx_1(x), idx_0(y), tile_1(x), tile_0(y)] 103 | val_01 = input[idx_0(x), idx_1(y), tile_0(x), tile_1(y)] 104 | val_11 = input[idx_1(x), idx_1(y), tile_1(x), tile_1(y)] 105 | 106 | output[x, y] = hl.cast(hl.UInt(16), weight_00 * val_00 107 | + weight_10 * val_10 108 | + weight_01 * val_01 109 | + weight_11 * val_11) 110 | 111 | weight.compute_root().vectorize(v, 32) 112 | 113 | output.compute_root().parallel(y).vectorize(x, 32) 114 | 115 | return output 116 | 117 | 118 | ''' 119 | Step 2 of HDR+ pipeline: merge 120 | 121 | images : Halide buffer 122 | Burst frames to be merged 123 | alignment : Halide function 124 | Calculated alignment of the burst frames 125 | 126 | Returns: Halide buffer (merged image) 127 | ''' 128 | def merge_images(images, alignment): 129 | print(f'\n{"=" * 30}\nMerging images...\n{"=" * 30}') 130 | start = datetime.utcnow() 131 | merge_temporal_output = merge_temporal(images, alignment) 132 | merge_spatial_output = merge_spatial(merge_temporal_output) 133 | 134 | print(f'Merging finished in {time_diff(start)} ms.\n') 135 | return merge_spatial_output 136 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.17.4 2 | opencv-python==4.2.0.32 3 | rawpy==0.14 4 | imageio==2.6.1 5 | kivy==1.11.1 6 | docutils==0.16 7 | pygments==2.5.2 8 | setuptools==45.1.0 9 | wheel==0.34.2 10 | kivy-deps.glew==0.2.0 11 | kivy-deps.sdl2==0.2.0 12 | pypiwin32==223 13 | kivy-deps.gstreamer==0.2.0 14 | kivy-deps.angle==0.2.0 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import halide as hl 3 | 4 | # Global constants 5 | TILE_SIZE = 32 6 | TILE_SIZE_2 = 16 7 | MINIMUM_OFFSET = -168 8 | MAXIMUM_OFFSET = 126 9 | DOWNSAMPLE_RATE = 4 10 | DENOISE_PASSES = 1 11 | SHARPEN_STRENGTH = 2 12 | TONE_MAP_PASSES = 4 13 | 14 | ''' 15 | Get the difference between a start and end time (or current time if none given) in ms 16 | 17 | start : datetime 18 | Start time 19 | end : datetime 20 | End time 21 | 22 | Returns: int 23 | ''' 24 | def time_diff(start, end=None): 25 | if not end: 26 | end = datetime.utcnow() 27 | return int((end - start).total_seconds() * 1000) 28 | 29 | 30 | ''' 31 | Point object which stores the coordinates x and y 32 | 33 | x : float 34 | x-coordinate 35 | y : float 36 | y-coordinate 37 | ''' 38 | class Point: 39 | def __init__(self, x=None, y=None): 40 | if x is None and y is None: 41 | self.x = hl.cast(hl.Int(16), 0) 42 | self.y = hl.cast(hl.Int(16), 0) 43 | elif x is not None and y is None: 44 | if type(x) is hl.FuncRef: 45 | hl.Tuple(x) 46 | self.x 
= hl.cast(hl.Int(16), x[0]) 47 | self.y = hl.cast(hl.Int(16), x[1]) 48 | elif type(x) is tuple: 49 | self.x = hl.cast(hl.Int(16), x[0]) 50 | self.y = hl.cast(hl.Int(16), x[1]) 51 | else: 52 | self.x = hl.cast(hl.Int(16), x) 53 | self.y = hl.cast(hl.Int(16), y) 54 | 55 | def get_tuple(self): 56 | return self.x, self.y 57 | 58 | def clamp(self, min_p, max_p): 59 | return Point(hl.clamp(self.x, min_p.x, max_p.x), hl.clamp(self.y, min_p.y, max_p.y)) 60 | 61 | def __len__(self): 62 | return 2 63 | 64 | def __getitem__(self, idx): 65 | return (self.x, self.y)[idx] 66 | 67 | # Point addition 68 | def __add__(self, p): 69 | return Point(self.x + p.x, self.y + p.y) 70 | 71 | # Point subtraction 72 | def __sub__(self, p): 73 | return Point(self.x - p.x, self.y - p.y) 74 | 75 | # Scalar multiplication 76 | def __mul__(self, n: int): 77 | return Point(self.x * n, self.y * n) 78 | 79 | # Scalar multiplication with self on the right-hand side 80 | def __rmul__(self, n: int): 81 | return Point(self.x * n, self.y * n) 82 | 83 | # Point negation 84 | def __neg__(self): 85 | return Point(-self.x, -self.y) 86 | 87 | 88 | def gaussian_down4(input, name): 89 | output = hl.Func(name) 90 | k = hl.Func(name + "_filter") 91 | x, y, n = hl.Var("x"), hl.Var("y"), hl.Var('n') 92 | rdom = hl.RDom([(-2, 5), (-2, 5)]) 93 | 94 | k[x, y] = 0 95 | k[-2, -2] = 2 96 | k[-1, -2] = 4 97 | k[0, -2] = 5 98 | k[1, -2] = 4 99 | k[2, -2] = 2 100 | k[-2, -1] = 4 101 | k[-1, -1] = 9 102 | k[0, -1] = 12 103 | k[1, -1] = 9 104 | k[2, -1] = 4 105 | k[-2, 0] = 5 106 | k[-1, 0] = 12 107 | k[0, 0] = 15 108 | k[1, 0] = 12 109 | k[2, 0] = 5 110 | k[-2, 1] = 4 111 | k[-1, 1] = 9 112 | k[0, 1] = 12 113 | k[1, 1] = 9 114 | k[2, 1] = 4 115 | k[-2, 2] = 2 116 | k[-1, 2] = 4 117 | k[0, 2] = 5 118 | k[1, 2] = 4 119 | k[2, 2] = 2 120 | 121 | output[x, y, n] = hl.cast(hl.UInt(16), 122 | hl.sum(hl.cast(hl.UInt(32), input[4 * x + rdom.x, 4 * y + rdom.y, n] * k[rdom.x, rdom.y])) 123 | / 159) 124 | 125 | k.compute_root().parallel(y).parallel(x) 126 | output.compute_root().parallel(y).vectorize(x, 16) 127 | 128 | return output 129 | 130 | 131 | def box_down2(input, name): 132 | output = hl.Func(name) 133 | 134 | x, y, n = hl.Var("x"), hl.Var("y"), hl.Var('n') 135 | rdom = hl.RDom([(0, 2), (0, 2)]) 136 | 137 | output[x, y, n] = hl.cast(hl.UInt(16), hl.sum(hl.cast(hl.UInt(32), input[2 * x + rdom.x, 2 * y + rdom.y, n])) / 4) 138 | 139 | output.compute_root().parallel(y).vectorize(x, 16) 140 | 141 | return output 142 | 143 | 144 | def prev_tile(t): 145 | return (t - 1) / DOWNSAMPLE_RATE 146 | 147 | 148 | def idx_layer(t, i): 149 | return t * TILE_SIZE_2 / 2 + i 150 | 151 | 152 | def idx_im(t, i): 153 | return t * TILE_SIZE_2 + i 154 | 155 | 156 | def idx_0(e): 157 | return e % TILE_SIZE_2 + TILE_SIZE_2 158 | 159 | 160 | def idx_1(e): 161 | return e % TILE_SIZE_2 162 | 163 | 164 | def tile_0(e): 165 | return e / TILE_SIZE_2 - 1 166 | 167 | 168 | def tile_1(e): 169 | return e / TILE_SIZE_2 170 | --------------------------------------------------------------------------------
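
A minimal headless driver, for reference (not part of the repository): it mirrors what `HDR()` in `main.py` does, minus the Kivy progress callbacks. The burst path below is a placeholder, and importing `load_images` from `main` also pulls in Kivy; if that is unwanted, copy `load_images` out of `main.py` instead.

```
import numpy as np
import imageio

from main import load_images
from align import align_images
from merge import merge_images
from finish import finish_image

burst_path = 'bursts/example_burst'  # placeholder: a folder of payload_N###.dng frames

(images, ref_img, wb_r, wb_g0, wb_g1, wb_b,
 black_point, white_point, cfa_pattern, ccm) = load_images(burst_path)

alignment = align_images(images)          # step 1: tile-wise alignment against frame 0
merged = merge_images(images, alignment)  # step 2: temporal, then spatial merge
finished = finish_image(merged, images.width(), images.height(),
                        black_point, white_point, wb_r, wb_g0, wb_g1, wb_b,
                        3.8, 1.1, 1.0,    # the GUI's default compression, gain, contrast
                        cfa_pattern, ccm)

result = finished.realize(images.width(), images.height(), 3)  # step 3: finish

# Portrait bursts are rotated before saving, exactly as in main.HDR()
if ref_img.shape[0] > ref_img.shape[1]:
    result = np.rot90(result, -1)
imageio.imsave('Output/output.jpg', result)
```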