├── .gitignore ├── img ├── mid_grid.png ├── 8_way_smoothing.png ├── 8_way_small_overlap.png ├── no_overlap_gridding.jpg ├── 25p_overlap_gridding.jpg ├── comparison_biomedical.jpg ├── 25p_overlap_gridding_lg.jpg └── no_overlap_gridding_lg.jpg ├── environment.dev.yml ├── LICENSE ├── README.md ├── example.py └── seamless_seg.py /.gitignore: -------------------------------------------------------------------------------- 1 | envs 2 | **/__pycache__ 3 | vis 4 | .idea 5 | -------------------------------------------------------------------------------- /img/mid_grid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Multihuntr/seamless-seg/HEAD/img/mid_grid.png -------------------------------------------------------------------------------- /img/8_way_smoothing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Multihuntr/seamless-seg/HEAD/img/8_way_smoothing.png -------------------------------------------------------------------------------- /img/8_way_small_overlap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Multihuntr/seamless-seg/HEAD/img/8_way_small_overlap.png -------------------------------------------------------------------------------- /img/no_overlap_gridding.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Multihuntr/seamless-seg/HEAD/img/no_overlap_gridding.jpg -------------------------------------------------------------------------------- /img/25p_overlap_gridding.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Multihuntr/seamless-seg/HEAD/img/25p_overlap_gridding.jpg -------------------------------------------------------------------------------- /img/comparison_biomedical.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Multihuntr/seamless-seg/HEAD/img/comparison_biomedical.jpg -------------------------------------------------------------------------------- /img/25p_overlap_gridding_lg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Multihuntr/seamless-seg/HEAD/img/25p_overlap_gridding_lg.jpg -------------------------------------------------------------------------------- /img/no_overlap_gridding_lg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Multihuntr/seamless-seg/HEAD/img/no_overlap_gridding_lg.jpg -------------------------------------------------------------------------------- /environment.dev.yml: -------------------------------------------------------------------------------- 1 | name: seamless 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - nvidia 6 | dependencies: 7 | - numpy 8 | - shapely 9 | - scipy 10 | - scikit-image 11 | - pillow 12 | - pytorch 13 | - torchvision 14 | - pytorch-cuda=11.8 15 | - torchgeo 16 | - timm 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Brandon Victor 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or 
substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `seamless_seg`: Seamless tiled segmentation postprocessing tools for large images 2 | 3 | Typical strategies for segmenting large images involve tiling. Unfortunately this can cause visible, obviously incorrect seams between tiles. This repo provides postprocessing functions which gracefully remove *all* such tiling artifacts for any segmentation task. 4 | 5 | ![Example](img/comparison_biomedical.jpg) 6 | 7 | * :white_check_mark: Optimal! No more tiling artifacts. Guaranteed seamless segmentation for any segmentation model. 8 | * :floppy_disk: Efficient! Needs <1% of image in memory at once for large images (40000x40000 and above). 9 | * :purple_circle: Decoupled! Minimal dependencies. Does not prescribe any IO libraries. 10 | * :zap: Fast! Approx 0.25ms of overhead per tile. 11 | 12 | ## Installation 13 | 14 | Copy `seamless_seg.py` into your project. 15 | 16 | Dependencies: `shapely=2.0` and `scipy`. 17 | 18 | 19 | 20 | ## Getting started 21 | 22 | If you have a very simple input-output pattern, you can use one of the convenience functions. This will automatically read/mix/write tiles with minimal memory overhead. 
23 | 24 | ```python 25 | import torch 26 | 27 | import seamless_seg 28 | 29 | model = # get pytorch model from somewhere 30 | tile_size = # whatever size your model needs 31 | 32 | # ############################### 33 | # Example 1. A simple numpy array 34 | 35 | import numpy as np 36 | 37 | img = # a numpy image shaped [C, H, W] 38 | out = seamless_seg.pytorch_numpy(model, img, tile_size) 39 | # out is an argmax'd segmentation mask shaped [H, W] 40 | 41 | # ################################## 42 | # Example 2. Using rasterio and TIFs 43 | import rasterio 44 | 45 | in_fpath = # get geotiff image from somewhere 46 | out_fpath = # where to save segmentation 47 | 48 | with rasterio.open(in_fpath) as in_tif: 49 | seamless_seg.pytorch_rasterio(model, in_tif, out_fpath, tile_size) 50 | 51 | ``` 52 | 53 | If you need more control of input/output, you can use other `seamless_seg` functions. The general process is: 54 | 55 | 1. Define how to obtain logits and how to write after processing (e.g. disk IO functions). 56 | 2. Create a plan to break the image up into tiles. 57 | 3. Execute plan to read/mix/write those tiles. 58 | 59 | Here's a minimal working example; it can be copy-pasted and run as-is. In this case, the "reading" just generates a tile with a random colour for visualisation, and "writing" is all in-memory. In a real example, you would use the logits of a segmentation algorithm, and write each tile to disk as you reach it. 60 | 61 | ```python 62 | import shapely 63 | import numpy as np 64 | 65 | import seamless_seg 66 | 67 | image_size = (1024, 1024) 68 | tile_size = (224, 224) 69 | overlap = (64, 64) 70 | 71 | # 1.
Define how to obtain logits and how to write after processing 72 | def random_logits(geom: shapely.Geometry): 73 | # Creating fake data; in real use cases, should return model logits 74 | # obtained by running your model on image data within geom 75 | tile = np.ones((*tile_size, 3), dtype=np.uint8) 76 | tile *= np.random.randint(20, 255, (3,), dtype=np.uint8) 77 | return tile 78 | 79 | out_img = np.zeros((*image_size, 3)) 80 | def write_to_numpy_array(geom, tile): 81 | # Writing to the in-memory numpy array above; 82 | # when evaluating on large files, this should write to disk instead 83 | y_slc, x_slc = seamless_seg.shape_to_slices(geom) 84 | out_img[y_slc, x_slc] = tile 85 | 86 | # 2. Create a plan to break the image up into tiles. 87 | def tile_generator(plan, get_logits): 88 | # Important: use seamless_seg.get_plan_logit_geoms to yield the correct order of tiles 89 | for index, geom in seamless_seg.get_plan_logit_geoms(plan): 90 | yield get_logits(geom) 91 | 92 | plan, grid = seamless_seg.plan_regular_grid(image_size, tile_size, overlap) 93 | in_tiles = tile_generator(plan, random_logits) 94 | 95 | # 3. Execute plan to read/mix/write those tiles. 96 | for index, out_geom, out_tile in seamless_seg.run_plan(plan, in_tiles): 97 | # Thanks to generators, this doesn't have to hold all tiles in memory at once 98 | write_to_numpy_array(out_geom, out_tile) 99 | ``` 100 | 101 | In a real example, the `in_tiles` from above should be a generator of logits. For convenience, `seamless_seg` provides a function to obtain logits from a simple pytorch model. 102 | 103 | ```python 104 | plan = # Assume we have a plan already 105 | model = # Assume we have a model from somewhere else 106 | read_tile = # Callable that takes a geometry and returns a numpy array of input data 107 | # e.g.
if you have an in_tif 108 | def read_tile(shp): 109 | return in_tif.read(window=seamless_seg.shape_to_slices(shp)) 110 | 111 | in_tiles = seamless_seg.pytorch_outputs_generator(plan, model, read_tile) 112 | ``` 113 | 114 | In the above examples, `seamless_seg` was responsible for creating the tile geometries, and used these to create the `plan`. You can also bring your own input geometries. 115 | 116 | ```python 117 | import seamless_seg 118 | 119 | # grid must be a np.array shaped [H, W] where each element is a shapely.Geometry 120 | # You can use seamless_seg functions to create this: 121 | # You can create a perfectly regular grid 122 | grid = seamless_seg.regular_grid(image_size, tile_size, overlap) 123 | # or, you can try to coerce a flat list of geometries into a grid shape 124 | grid = seamless_seg.coerce_to_grid(flat_list_of_geometries) 125 | # or, 126 | grid = # whatever you like 127 | 128 | # Regardless, the planning will minimise the memory footprint 129 | plan = seamless_seg.plan_from_grid(grid) 130 | ``` 131 | 132 | At the lowest level, if you just want to blend together tiles that you've loaded yourself, you can use: 133 | 134 | ```python 135 | import seamless_seg 136 | 137 | central_geom = # a shapely.Geometry 138 | central_tile = # a np.ndarray 139 | nearby_geoms = # list of shapely.Geometry 140 | nearby_tiles = # list of np.ndarrays 141 | 142 | weights = seamless_seg.overlap_weights(central_geom, nearby_geoms) 143 | _, out_tile = seamless_seg.apply_weights(central_tile, nearby_tiles, weights) 144 | ``` 145 | 146 | ## Optimisation 147 | 148 | There are many optional parameters for each function. Most of these are optimisation parameters to run faster or with less RAM. 149 | 150 | These are the most likely use cases, and how to do them: 151 | * Run only on a small patch of a large image. 152 | * Pass `area=` where available. 153 | * Extreme RAM requirements: 154 | * Pass `max_tiles=` and `disk_cache_dir=` where available.
155 | * Batching, running models in a separate thread: 156 | * Pytorch segmentation model: Pass `batch_size>1` to either `seamless_seg.pytorch_outputs_generator` or `seamless_seg.pytorch_rasterio`. 157 | * Custom segmentation: Use `seamless_seg.batched_tile_get` and `seamless_seg.threaded_batched_tile_get`. 158 | 159 | ## Explanation - Fixing tiling artifacts 160 | 161 | ### Where do tiling artifacts come from? 162 | 163 | Tiling artifacts are a result of hard boundaries between adjacent tiles. The most naive approach is to select tiles with no overlap, and just let the model predict whatever it will. At the boundary of those tiles, models will often make significantly different predictions. This results in sharp lines in your output segmentation. 164 | 165 | ![No overlap between tiles causes sharp lines in output](img/no_overlap_gridding.jpg) 166 | 167 | This is not a model failure per se. The problem is just that the model is using a different context for pixels on one side of a boundary than for pixels on the other side of that boundary. If it were given a full context around each object, it might still segment it correctly. 168 | 169 | ### Overlapping tiles 170 | 171 | Typical solutions to this always use overlapping tiles in some way. A slightly less naive approach commonly taken is to overlap tiles, and only keep a smaller window of the outputs. That solution *reduces* tiling artifacts, but does not remove them entirely. So long as there are hard boundaries between the tiles, tiling artifacts will appear. 172 | 173 | 174 | 175 | ![Using 25% overlap still results in sharp lines](img/25p_overlap_gridding.jpg) 176 | 177 | In the extreme case, we could evaluate a tile centered on every single pixel independently and only trust that central pixel. But this involves lots of redundant calculation. We need a better solution. 178 | 179 | ### Trusting model outputs 180 | 181 | Pixels at the edge of each tile have a truncated context because of their position within the tile.
This lack of context degrades model performance at the edges of each tile. 182 | 183 | 184 | 185 | In some sense, this means that we should inherently trust the pixels at the edges less than those at the centre. So, we define a map of trustworthiness of each pixel. This is simply a measure of the distance of that pixel to the centre of the tile. 186 | 187 | 188 | 189 | We can use the trust values to determine how much we should use from each overlapping tile. This gives us a distinct weighted sum at each pixel. Using a weighted sum based on distance produces a smooth transition across tiles. Pixels at the centre of an output tile come entirely from the model output for that tile, and pixels halfway between two tiles come 50% from each, etc. 190 | 191 | 192 | 193 | ![Eight-way smoothing with 50% overlap](img/8_way_smoothing.png) 194 | ![Eight-way smoothing with approx 10% overlap](img/8_way_small_overlap.png) 195 | 196 | These weights can be obtained by calling `seamless_seg.overlap_weights`, but you probably don't want to do that. It is recommended to use `seamless_seg.plan_regular_grid` instead. 197 | 198 | ## Tiling plan - optimising read/writes 199 | 200 | Utilising the overlap weights described in the previous section is not inherently linked to using a regular grid of overlapping tiles, but usually that's what we want. To make the typical use case easier, `seamless_seg` includes `seamless_seg.plan_regular_grid` to create a tiling plan in the shape of a grid. 201 | 202 | ![Grid of colour blocks, smoothly transitioning between each](img/mid_grid.png) 203 | 204 | All you need to do is provide an image size, tile size and overlap amount (in pixels). The plan is created entirely within geometric pixel space. That is, the plan is made before any real data is read from disk. This allows you to inspect the plan and optimise reading from disk (e.g. batching, threaded/async). 205 | 206 | Finally, `seamless_seg.run_plan` is provided to actually run the plan. 
To control memory usage, this can optionally be given a maximum number of tiles to keep in RAM at once. 207 | 208 | ### Memory 209 | 210 | Often large segmentation tasks have images that are too large to fit into RAM. So, the tiling plan includes explicit load/unload instructions. Following this plan ensures that tiles are never requested more than once **and** that the minimum number of tiles are kept in memory. For some perspective, given a (40000, 40000) image with a tile size of (256, 256) and an overlap of (64, 64), there will be at most 1.0% of the image held in memory at once. 211 | 212 | If even this is too large, you can use the `max_tiles` and `disk_cache_dir` arguments to hold as few tiles in memory as you need. Tiles beyond this limit will be cached to disk. 213 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | from pathlib import Path 4 | 5 | import shapely 6 | import skimage 7 | import numpy as np 8 | 9 | import seamless_seg 10 | 11 | from PIL import Image 12 | 13 | 14 | VIS_FOLDER = Path('vis') 15 | 16 | 17 | def test_overlap_trim(): 18 | geom_a = shapely.box(0, 0, 100, 100) 19 | geom_b = shapely.box(0, 60, 100, 160) 20 | 21 | out_geom_a, _, _, _, _ = seamless_seg.overlap_weights(geom_a, [geom_b]) 22 | out_geom_b, _, _, _, _ = seamless_seg.overlap_weights(geom_b, [geom_a]) 23 | 24 | print(shapely.area(shapely.intersection(out_geom_a, out_geom_b))) 25 | # is 40*100 = 4000 26 | 27 | out_geom_a, _, _, _, _ = seamless_seg.overlap_weights(geom_a, [geom_b], (0, 0, None, -20)) 28 | out_geom_b, _, _, _, _ = seamless_seg.overlap_weights(geom_b, [geom_a], (0, 20, None, None)) 29 | 30 | print(shapely.area(shapely.intersection(out_geom_a, out_geom_b))) 31 | 32 | 33 | def show_overlap_weights_regular(): 34 | np.random.seed(123459) 35 | # Creating test data 36 | central_geom = shapely.box(200, 300, 
500, 800) 37 | # Regular grid 38 | boxes = [ 39 | shapely.box(-50, -150, 250, 350), #tl 40 | shapely.box(450, -150, 750, 350), #bl 41 | shapely.box(-50, 750, 250, 1250), #tr 42 | shapely.box(450, 750, 750, 1250), #br 43 | shapely.box(200, -150, 500, 350), # l 44 | shapely.box(-50, 300, 250, 800), # t 45 | shapely.box(200, 750, 500, 1250), # r 46 | shapely.box(450, 300, 750, 800), # b 47 | ] 48 | central_tile = np.ones((300, 500, 3), dtype=np.uint8) 49 | central_tile *= np.random.randint(20, 255, (3,), dtype=np.uint8) 50 | N = len(boxes) 51 | nearby_tiles = [] 52 | for i in range(N): 53 | nearby_tiles.append(np.ones((300, 500, 3), dtype=np.uint8)) 54 | nearby_tiles[i] *= np.random.randint(20, 255, (3,), dtype=np.uint8) 55 | 56 | # Run seamless_seg to calculate new central tile 57 | weights = seamless_seg.overlap_weights(central_geom, boxes) 58 | _, out = seamless_seg.apply_weights(central_tile, nearby_tiles, weights) 59 | 60 | # Printing output for inspection 61 | Image.fromarray(out.astype(np.uint8)).save(VIS_FOLDER / 'test_overlap_weights_regular.png') 62 | 63 | def show_overlap_weights_irregular(): 64 | np.random.seed(123459) 65 | # Creating test data 66 | # Irregular 67 | central_geom = shapely.box(200, 300, 500, 800) 68 | boxes = [ 69 | shapely.box(150, 150, 450, 650), 70 | shapely.box(300, 200, 600, 700), 71 | shapely.box(100, 500, 400, 1000), 72 | shapely.box(450, 650, 750, 1150), 73 | ] 74 | central_tile = np.ones((300, 500, 3), dtype=np.uint8) 75 | central_tile *= np.random.randint(20, 255, (3,), dtype=np.uint8) 76 | N = len(boxes) 77 | nearby_tiles = [] 78 | for i in range(N): 79 | nearby_tiles.append(np.ones((300, 500, 3), dtype=np.uint8)) 80 | nearby_tiles[i] *= np.random.randint(20, 255, (3,), dtype=np.uint8) 81 | 82 | # Run seamless_seg to calculate new central tile 83 | weights = seamless_seg.overlap_weights(central_geom, boxes) 84 | _, out = seamless_seg.apply_weights(central_tile, nearby_tiles, weights) 85 | 86 | # Printing output for inspection 87 
| Image.fromarray(out.astype(np.uint8)).save(VIS_FOLDER / 'test_overlap_weights_irregular.png') 88 | 89 | 90 | def _random_tile_gen(shape, length=None): 91 | # Generates infinite fake tiles, where each tile is a single, randomly selected colour 92 | count = 0 93 | while True: 94 | tile = np.ones(shape, dtype=np.uint8) 95 | tile *= np.random.randint(20, 255, (shape[-1],), dtype=np.uint8) 96 | yield tile 97 | count += 1 98 | if length is not None and count >= length: 99 | break 100 | 101 | 102 | def test_coerce_grid_corrupt(): 103 | np.random.seed(2342352) 104 | image_size = (1024, 1024) 105 | image_shape = (*image_size, 3) 106 | tile_size = (48, 52) 107 | tile_shape = (*tile_size, 3) 108 | 109 | # Clean version, no corruption on grid, normal regular grid 110 | area = shapely.Polygon([[540, 125], [180, 690], [730, 565]]) 111 | grid = seamless_seg.regular_grid(image_size, tile_size, (10, 20), area) 112 | plan = seamless_seg.plan_from_grid(grid) 113 | ingeoms = seamless_seg.get_plan_logit_geoms(plan) 114 | in_tiles = [tile for tile in _random_tile_gen(tile_shape, len(ingeoms))] 115 | out_img_clean = np.zeros(image_shape) 116 | for index, out_geom, out_tile in seamless_seg.run_plan(plan, iter(in_tiles)): 117 | y_slc, x_slc = seamless_seg.shape_to_slices(out_geom) 118 | out_img_clean[y_slc, x_slc] = out_tile 119 | 120 | Image.fromarray(out_img_clean.astype(np.uint8)).save(VIS_FOLDER / 'clean_grid.png') 121 | 122 | # Corrupt grid: offset cells in the middle slightly 123 | grid_central = grid[1:-1, 1:-1] 124 | grid_np_coords = shapely.get_coordinates(grid_central) 125 | offset = np.random.randint(-1, 1, (grid_np_coords.shape[0]//5, 2)) 126 | for i in range(5): 127 | grid_np_coords[i::5] += offset 128 | shapely.set_coordinates(grid_central, grid_np_coords) 129 | grid[1:-1, 1:-1] = grid_central 130 | 131 | # Corrupt grid: flatten, randomise order 132 | grid_list = list(grid.flatten()) 133 | random.shuffle(grid_list) 134 | grid_np_flat = np.array(grid_list) 135 | 136 | # Now 
make it work with seamless_seg 137 | boundss = np.array([cell.bounds for cell in grid_np_flat if cell is not None]) 138 | coerced_grid, flat_to_grid_map = seamless_seg.coerce_to_grid(boundss) 139 | 140 | plan = seamless_seg.plan_from_grid(coerced_grid, margin=(3, 8)) 141 | out_img_corrupt = np.zeros(image_shape) 142 | for index, out_geom, out_tile in seamless_seg.run_plan(plan, iter(in_tiles)): 143 | y_slc, x_slc = seamless_seg.shape_to_slices(out_geom) 144 | out_img_corrupt[y_slc, x_slc] = out_tile 145 | 146 | Image.fromarray(out_img_corrupt.astype(np.uint8)).save(VIS_FOLDER / 'corrupt_grid.png') 147 | 148 | 149 | def random_colour_grid( 150 | image_size, 151 | tile_size, 152 | overlap, 153 | area=None, 154 | actually_run=False, 155 | fname='out_img.png' 156 | ): 157 | np.random.seed(123459) 158 | draw_area = area is not None 159 | 160 | print(f'For an image sized {image_size}, with tile {tile_size} and overlap {overlap}') 161 | 162 | start = time.perf_counter() 163 | plan, grid = seamless_seg.plan_regular_grid(image_size, tile_size, overlap, area) 164 | end = time.perf_counter() 165 | print(f'Planning takes: {end-start:4.2f}s') 166 | 167 | max_loaded, load_actions, write_actions = seamless_seg.analyse_plan(plan) 168 | print(f'Plan holds a maximum of {max_loaded} tiles in memory at once.') 169 | print(f'Plan loads {load_actions} tiles and writes {write_actions} tiles.') 170 | print(f'That is, plan holds {max_loaded/load_actions:4.1%} of tiles in memory') 171 | 172 | if not actually_run: 173 | print() 174 | return 175 | 176 | start = time.perf_counter() 177 | 178 | # Create fake random data 179 | in_tile_gen = _random_tile_gen((*tile_size, 3)) 180 | 181 | # Run plan 182 | out_img = np.zeros((*image_size, 3)) 183 | for index, out_geom, out_tile in seamless_seg.run_plan(plan, in_tile_gen): 184 | y_slc, x_slc = seamless_seg.shape_to_slices(out_geom) 185 | out_img[y_slc, x_slc] = out_tile 186 | 187 | end = time.perf_counter() 188 | print(f'Running plan takes: 
{end-start:4.2f}s') 189 | print() 190 | 191 | # Save out_img to disk 192 | if draw_area: 193 | coords = shapely.get_coordinates(area) 194 | rr, cc = skimage.draw.polygon(coords[:, 0], coords[:, 1]) 195 | out_img[rr, cc] = 255 196 | 197 | Image.fromarray(out_img.astype(np.uint8)).save(VIS_FOLDER / fname) 198 | 199 | def test_batched_colour_grid( 200 | image_size, 201 | tile_size, 202 | overlap, 203 | fname='out_img.png' 204 | ): 205 | random_tiles = _random_tile_gen((*tile_size, 3)) 206 | batch_size = 16 207 | def _get_tiles(indexs, geoms): 208 | return np.stack([next(random_tiles) for _ in geoms]) 209 | 210 | def _input_generator(plan): 211 | geoms = seamless_seg.get_plan_logit_geoms(plan) 212 | return seamless_seg.threaded_batched_tile_get(geoms, batch_size, _get_tiles, batch_size*3) 213 | 214 | # Iterate over output tiles; in this case, write directly to a np array 215 | # But in real use cases, you can write the tile to disk (e.g. rasterio/tifffile) 216 | plan, grid = seamless_seg.plan_regular_grid(image_size, tile_size, overlap) 217 | in_tiles = _input_generator(plan) 218 | out_img = np.zeros((*image_size, 3)) 219 | for index, out_geom, out_tile in seamless_seg.run_plan(plan, in_tiles): 220 | y_slc, x_slc = seamless_seg.shape_to_slices(out_geom) 221 | out_img[y_slc, x_slc] = out_tile 222 | 223 | # Save out_img to disk 224 | vis_folder = Path('vis') 225 | vis_folder.mkdir(exist_ok=True) 226 | Image.fromarray(out_img.astype(np.uint8)).save(vis_folder / fname) 227 | 228 | 229 | def random_colour_grid_visualise_cache(image_size, tile_size, overlap, do_print=False): 230 | plan, grid = seamless_seg.plan_regular_grid(image_size, tile_size, overlap) 231 | visualisation = np.zeros((*grid.shape[:2], 3), dtype=np.uint8) 232 | vis_cache_folder = Path('vis/cache-vis') 233 | vis_cache_folder.mkdir(exist_ok=True) 234 | i = 0 235 | 236 | def _on_load(index): 237 | nonlocal i 238 | if do_print: 239 | print(f'{i:04d}: loading {index}') 240 | visualisation[index[0], index[1], 0] 
= 255 241 | Image.fromarray(visualisation).save(vis_cache_folder / f'cache_{i:04d}.png') 242 | i += 1 243 | 244 | def _on_unload(index): 245 | nonlocal i 246 | if do_print: 247 | print(f'{i:04d}: unloading {index}') 248 | visualisation[index[0], index[1], 0] = 0 249 | Image.fromarray(visualisation).save(vis_cache_folder / f'cache_{i:04d}.png') 250 | i += 1 251 | 252 | def _on_disk_evict(index): 253 | nonlocal i 254 | if do_print: 255 | print(f'{i:04d}: evicting {index} to disk') 256 | visualisation[index[0], index[1], 1] = 255 257 | visualisation[index[0], index[1], 0] = 0 258 | Image.fromarray(visualisation).save(vis_cache_folder / f'cache_{i:04d}.png') 259 | i += 1 260 | 261 | def _on_disk_restore(index): 262 | nonlocal i 263 | if do_print: 264 | print(f'{i:04d}: restoring {index} from disk') 265 | visualisation[index[0], index[1], 1] = 0 266 | visualisation[index[0], index[1], 0] = 255 267 | Image.fromarray(visualisation).save(vis_cache_folder / f'cache_{i:04d}.png') 268 | i += 1 269 | 270 | def _on_step(n): 271 | # Image.fromarray(visualisation).save(vis_cache_folder / f'cache_{n:04d}.png') 272 | pass 273 | 274 | in_tile_gen = _random_tile_gen((*tile_size, 3)) 275 | out_tiles = seamless_seg.run_plan( 276 | plan, 277 | in_tile_gen, 278 | 10, 279 | disk_cache_dir=(VIS_FOLDER / 'data-cache'), 280 | on_load=_on_load, 281 | on_unload=_on_unload, 282 | on_disk_evict=_on_disk_evict, 283 | on_disk_restore=_on_disk_restore, 284 | on_step=_on_step, 285 | ) 286 | 287 | for index, out_geom, out_tile in out_tiles: 288 | if do_print: 289 | print(f'{i:04d}: writing {index}') 290 | visualisation[index[0], index[1], 2] = 255 291 | Image.fromarray(visualisation).save(vis_cache_folder / f'cache_{i:04d}.png') 292 | i += 1 293 | 294 | 295 | def main(): 296 | 297 | area = shapely.Polygon([ 298 | [26, 14], [5, 20], [19, 28], [5, 44], [15, 55], [22, 40], 299 | [38, 55], [44, 37], [26, 28], [44, 19], [40, 4], [17, 6] 300 | ]) 301 | 302 | 303 | VIS_FOLDER.mkdir(exist_ok=True) 304 | 
random_colour_grid((48, 64), (5, 5), (2, 2), actually_run=True, fname='small_grid.png') 305 | random_colour_grid((128, 86), (7, 7), (2, 2), area=area, actually_run=True, fname='small_grid_w_area.png') 306 | random_colour_grid((256, 256), (58, 84), (6, 12), actually_run=True, fname='mid_grid.png') 307 | # random_colour_grid((40000, 40000), (256, 256), (64, 64), actually_run=False) 308 | # random_colour_grid_visualise_cache((48, 64), (11,11), (2, 2)) 309 | test_batched_colour_grid((128, 86), (7, 7), (2, 2), fname='small_grid_batched.png') 310 | 311 | 312 | if __name__ == '__main__': 313 | main() 314 | -------------------------------------------------------------------------------- /seamless_seg.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import dataclasses 3 | import queue 4 | import threading 5 | import math 6 | from pathlib import Path 7 | from typing import Sequence, Iterable, Generator 8 | 9 | import numpy as np 10 | import scipy 11 | import shapely 12 | import shapely.affinity 13 | 14 | # Consistently arbitrarily ordered list of 8 directions to look for adjacent tiles 15 | GRID_DIR = np.array([(j, i) for j in (-1, 0, 1) for i in (-1, 0, 1) if not (i == j == 0)]) 16 | 17 | 18 | def shape_to_slices(shp: shapely.Geometry): 19 | ylo, xlo, yhi, xhi = shp.bounds 20 | ylo, xlo = round(ylo), round(xlo) 21 | yhi, xhi = round(yhi), round(xhi) 22 | return slice(ylo, yhi), slice(xlo, xhi) 23 | 24 | 25 | def mk_circle_of_trust(h, w): 26 | trust_coords_T = np.array([(-1, h // 2, h), (-1, w // 2, w)]) 27 | trust_values = [[0, 0, 0], [0, 1, 0], [0, 0, 0]] 28 | interpolator = scipy.interpolate.RegularGridInterpolator(trust_coords_T, trust_values) 29 | 30 | eval_coords = tuple(np.indices((h, w))) 31 | return interpolator(eval_coords) 32 | 33 | 34 | def get_trimmed_bounds(margin: tuple[int, int], dirs: Sequence[tuple[int, int]]): 35 | if margin is None: 36 | return 0, 0, None, None 37 | my, mx = margin 38 | ylo, 
xlo, yhi, xhi = 0, 0, None, None 39 | for j, i in dirs: 40 | if j == -1: 41 | ylo = my 42 | if j == 1: 43 | yhi = -my 44 | if i == -1: 45 | xlo = mx 46 | if i == 1: 47 | xhi = -mx 48 | return ylo, xlo, yhi, xhi 49 | 50 | 51 | def trim_array(arr: np.ndarray, bounds: tuple[int, int, int, int]): 52 | ylo, xlo, yhi, xhi = bounds 53 | return arr[..., ylo:yhi, xlo:xhi] 54 | 55 | 56 | def trim_box(shp: shapely.Geometry, bounds: tuple[int, int, int, int]): 57 | bylo, bxlo, byhi, bxhi = shp.bounds 58 | tylo, txlo, tyhi, txhi = bounds 59 | slices = (slice(tylo, tyhi), slice(txlo, txhi)) 60 | tyhi = 0 if tyhi is None else tyhi 61 | txhi = 0 if txhi is None else txhi 62 | new_box = shapely.box(bylo + tylo, bxlo + txlo, byhi + tyhi, bxhi + txhi) 63 | return new_box, slices 64 | 65 | 66 | def overlap_weights( 67 | central: shapely.Geometry, 68 | nearby: Sequence[shapely.Geometry], 69 | trim_bounds: tuple[int, int, int, int] = None, 70 | ) -> ( 71 | shapely.Geometry, 72 | np.ndarray, 73 | tuple[slice, slice], 74 | np.ndarray, 75 | list[tuple[tuple[slice, slice], tuple[slice, slice]]], 76 | ): 77 | """ 78 | Calculates everything needed to combine a central geometry with N nearby geometries. 79 | The nearby geometries need not be in a regular grid. They can be arbitrarily arranged. 80 | Invoking this does not depend on any real data. 81 | When trim_bounds is provided, it forces the output to be sliced to fit those bounds. 82 | 83 | For simple use cases, use in conjunction with seamless_seg.apply_weights. 84 | 85 | By default, overlap_weights describes the full area of the central geometry. 86 | Thus, using it once each on adjacent tiles describes the overlapping area between 87 | them twice (i.e. in each call). 88 | To account for this, provide a trim_bounds of half the overlapping area. 89 | 90 | e.g. Say we have two tiles 100 pixels wide next to each other, and they overlap 91 | 40 pixels with each other. 
92 | ``` 93 | geom_a = shapely.box(0, 0, 100, 100) 94 | geom_b = shapely.box(0, 60, 100, 160) 95 | 96 | out_geom_a, _, _, _, _ = overlap_weights(geom_a, [geom_b]) 97 | out_geom_b, _, _, _, _ = overlap_weights(geom_b, [geom_a]) 98 | 99 | print(shapely.area(shapely.intersection(out_geom_a, out_geom_b))) 100 | # is 40*100 = 4000 101 | 102 | out_geom_a, _, _, _, _ = overlap_weights(geom_a, [geom_b], (0, 0, None, -20)) 103 | out_geom_b, _, _, _, _ = overlap_weights(geom_b, [geom_a], (0, 20, None, None)) 104 | 105 | print(shapely.area(shapely.intersection(out_geom_a, out_geom_b))) 106 | # is 0 because the overlapping region has been trimmed a bit on each side 107 | ``` 108 | 109 | Returns: 110 | out_geom: defines the space in the output to which the central_weights refers 111 | central_weights: how much to use the data from central per-pixel (0 to 1) 112 | centre_from_tile_slc: slices into tile defined by central to select out_geom 113 | nearby_weights: how much to use each of the nearby geometries 114 | slice_pairs: how to read from central_weights and nearby_weights for combining 115 | """ 116 | # Make circle of trust for the central geom 117 | ylo, xlo, yhi, xhi = central.bounds 118 | h, w = int(yhi - ylo), int(xhi - xlo) 119 | circle_of_trust = mk_circle_of_trust(h, w) 120 | 121 | # Make circles of trust for nearby geoms 122 | nearby_bounds = [n.bounds for n in nearby] 123 | nearby_shp = [(int(b[2] - b[0]), int(b[3] - b[1])) for b in nearby_bounds] 124 | nearby_circles_of_trust = np.stack([mk_circle_of_trust(nh, nw) for nh, nw in nearby_shp]) 125 | 126 | # Initialise trusts to be read from nearby geoms 127 | nearby_trusts = np.zeros((len(nearby), h, w)) 128 | 129 | # If we need to trim the bounds, we trim only the central geom and associated arrays 130 | if trim_bounds is not None: 131 | tylo, txlo, _, _ = trim_bounds 132 | ylo += tylo 133 | xlo += txlo 134 | circle_of_trust = trim_array(circle_of_trust, trim_bounds) 135 | nearby_trusts = trim_array(nearby_trusts, 
trim_bounds) 136 | central, centre_from_tile_slc = trim_box(central, trim_bounds) 137 | else: 138 | centre_from_tile_slc = (slice(None, None), slice(None, None)) 139 | 140 | # Calculate nearby trusts and how to slice these trusts for each nearby geom 141 | overlaps = shapely.intersection(np.array([central]), np.array(nearby)) 142 | slice_pairs = [] 143 | for i, overlap in enumerate(overlaps): 144 | # Get slices into central and nearby 145 | oylo, oxlo, _, _ = nearby[i].bounds 146 | central_slices = shape_to_slices(shapely.affinity.translate(overlap, -ylo, -xlo)) 147 | nearby_slices = shape_to_slices(shapely.affinity.translate(overlap, -oylo, -oxlo)) 148 | slice_pairs.append((central_slices, nearby_slices)) 149 | 150 | # Write just for the overlapping parts 151 | i_c_slices = (i, *central_slices) 152 | i_n_slices = (i, *nearby_slices) 153 | nearby_trusts[i_c_slices] = nearby_circles_of_trust[i_n_slices] 154 | 155 | # Normalise pixel-wise 156 | total = np.concatenate([circle_of_trust[None], nearby_trusts], axis=0).sum(axis=0) 157 | central_weights = circle_of_trust / total 158 | nearby_weights = nearby_trusts / total 159 | 160 | return central, central_weights, centre_from_tile_slc, nearby_weights, slice_pairs 161 | 162 | 163 | def apply_weights(central_tile: np.ndarray, nearby_tiles: list[np.ndarray], weights): 164 | """ 165 | Apply overlap weights to real tile data.
166 | 167 | Example usage: 168 | ``` 169 | # Assuming we have: central_geom, nearby_geoms, central_tile, nearby_tiles 170 | weights = seamless_seg.overlap_weights(central_geom, nearby_geoms) 171 | out_geom, out_tile = seamless_seg.apply_weights(central_tile, nearby_tiles, weights) 172 | ``` 173 | """ 174 | out_geom, central_weights, centre_from_tile_slc, nearby_weights, slice_pairs = weights 175 | out_tile = central_tile[centre_from_tile_slc] * central_weights[..., None] 176 | z = enumerate(zip(nearby_weights, slice_pairs)) 177 | for i, (nearby_weight, (central_slices, nearby_slices)) in z: 178 | vals = nearby_tiles[i][nearby_slices] 179 | val_weights = nearby_weight[central_slices][..., None] 180 | out_tile[central_slices] += vals * val_weights 181 | return out_geom, out_tile 182 | 183 | 184 | def mk_box_grid( 185 | width, height, x_offset=0, y_offset=0, box_width=1, box_height=1, overlap_x=0, overlap_y=0 186 | ): 187 | """ 188 | Create a grid of box geometries, stored in a vectorised Shapely array. 
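Illustration of the spacing rule along one axis: a minimal pure-Python sketch, assuming the same origin arithmetic as this function. `tile_origins` is a hypothetical helper used only for this example and is not part of the module.

```python
# Hypothetical 1-D helper mirroring the origin arithmetic of mk_box_grid.
def tile_origins(length, box, overlap, offset=0):
    # Consecutive boxes start `box - overlap` apart, so neighbours share `overlap` pixels.
    gap = box - overlap
    # Only boxes that fit entirely within `length` are counted.
    count = (length - overlap) // gap
    return [offset + i * gap for i in range(count)]


# Three 40-pixel boxes with 10 pixels of overlap cover 0..100 exactly.
print(tile_origins(100, box=40, overlap=10))
# With length 110 the origins are unchanged, so the last 10 pixels stay uncovered;
# regular_grid deals with that case by adding an extra strip of boxes.
print(tile_origins(110, box=40, overlap=10))
```

Both calls give origins 0, 30 and 60; only the coverage of the final stretch differs.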
189 | """ 190 | gap_width = box_width - overlap_x 191 | gap_height = box_height - overlap_y 192 | xs = np.arange((width - overlap_x) // gap_width) * gap_width 193 | ys = np.arange((height - overlap_y) // gap_height) * gap_height 194 | yss, xss = np.meshgrid(ys, xs) 195 | # fmt: off 196 | coords = np.array([ # Clockwise squares 197 | [xss+x_offset, yss+y_offset], 198 | [xss+x_offset+box_width, yss+y_offset], 199 | [xss+x_offset+box_width, yss+y_offset+box_height], 200 | [xss+x_offset, yss+y_offset+box_height], 201 | ]).transpose((2,3,0,1)) # shapes [4, 2, W, H] -> [W, H, 4, 2] 202 | # fmt: on 203 | return shapely.polygons(coords) 204 | 205 | 206 | def calc_gridcell_needed(grid_mask): 207 | # Calculate which grid cells are needed to calculate grid cells that are in grid_mask 208 | any_masks = [grid_mask] 209 | 210 | # For each direction, grab an offset grid_mask, indicating which cells are needed due 211 | # to there being a needed grid cell in that direction 212 | def _dir_to_slice(v): 213 | if v == -1: 214 | return slice(None, -1), slice(1, None) 215 | elif v == 1: 216 | return slice(1, None), slice(None, -1) 217 | else: 218 | return slice(None), slice(None) 219 | 220 | for j, i in GRID_DIR: 221 | orig_y_slc, out_y_slc = _dir_to_slice(j) 222 | orig_x_slc, out_x_slc = _dir_to_slice(i) 223 | mask = np.zeros_like(grid_mask, dtype=bool) 224 | mask[out_y_slc, out_x_slc] = grid_mask[orig_y_slc, orig_x_slc] 225 | any_masks.append(mask) 226 | return np.any(any_masks, axis=0) 227 | 228 | 229 | def row_by_row_traversal(grid, add_load, add_unload, add_write): 230 | """ 231 | Traverses a grid, deciding when to load/unload/write tiles. 232 | The responsibility of this function is to ensure that for every write action marked, 233 | at that point in the plan, all nearby tiles would be loaded into the cache. 234 | It is not the responsibility of this function to determine if any such tile is in bounds. 
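A minimal sketch of how that guarantee could be checked, assuming a full rectangular grid. `check_traversal` and its three callbacks are hypothetical stand-ins for the ones `plan_from_grid` passes in, and `row_major` restates the row-major branch of this function using plain grid sizes instead of a grid array.

```python
# Hypothetical checker: replays a traversal on a gh x gw grid and verifies that
# every write happens while the tile and all of its in-bounds neighbours are loaded.
def check_traversal(traversal, gh, gw):
    loaded, stats = set(), {"peak": 0, "writes": 0}

    def in_bounds(gy, gx):
        return 0 <= gy < gh and 0 <= gx < gw

    def add_load(gy, gx):
        # Out-of-bounds requests are ignored, as the traversal does not check them.
        if in_bounds(gy, gx):
            loaded.add((gy, gx))
            stats["peak"] = max(stats["peak"], len(loaded))

    def add_unload(gy, gx):
        if in_bounds(gy, gx):
            loaded.remove((gy, gx))

    def add_write(gy, gx):
        needed = {(gy + j, gx + i) for j in (-1, 0, 1) for i in (-1, 0, 1)}
        needed = {idx for idx in needed if in_bounds(*idx)}
        assert needed <= loaded, f"write {(gy, gx)} is missing {needed - loaded}"
        stats["writes"] += 1

    traversal(gh, gw, add_load, add_unload, add_write)
    assert not loaded, "every loaded tile should be unloaded by the end"
    return stats["peak"], stats["writes"]


# Restatement of the row-major branch (the real function receives the grid itself).
def row_major(gh, gw, add_load, add_unload, add_write):
    for gx in range(gw):
        add_load(0, gx)
    for gy in range(gh):
        add_load(gy + 1, 0)
        for gx in range(gw):
            add_load(gy + 1, gx + 1)
            add_write(gy, gx)
            add_unload(gy - 1, gx - 1)
        add_unload(gy - 1, gw - 1)
    for gx in range(gw):
        add_unload(gh - 1, gx)


peak, writes = check_traversal(row_major, gh=4, gw=3)
print(peak, writes)  # prints: 9 12
```

In this 4x3 example the peak is 9 loaded tiles: two full rows plus the three tiles of the next row that are needed before the first unload.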
235 | 236 | This traverses row-by-row, keeping two full rows of tiles in the cache at once. 237 | This will ensure that no tile is read more than once and has a significantly smaller 238 | memory requirement than keeping all tiles in memory at once. 239 | This may not be optimal in all cases. 240 | """ 241 | gh, gw = grid.shape[:2] 242 | if gh >= gw: 243 | for gx in range(gw): 244 | add_load(0, gx) 245 | for gy in range(gh): 246 | # Visualising what is in cache: 247 | # ("|" means the tile is loaded, "." means the tile is not) 248 | # The cache should look like this for the row 249 | # gy-1: |||||||| 250 | # gy: |||||||| 251 | # gy+1: ........ 252 | add_load(gy + 1, 0) 253 | # gy-1: |||||||| 254 | # gy: |||||||| 255 | # gy+1: |....... 256 | for gx in range(gw): 257 | # ||| 258 | # ||| 259 | # ||. 260 | add_load(gy + 1, gx + 1) 261 | add_write(gy, gx) 262 | add_unload(gy - 1, gx - 1) 263 | # .|| 264 | # ||| 265 | # ||| 266 | # gy-1: .......| 267 | # gy: |||||||| 268 | # gy+1: |||||||| 269 | add_unload(gy - 1, gw - 1) 270 | # gy-1: ........ 
271 | # gy: |||||||| 272 | # gy+1: |||||||| 273 | for gx in range(gw): 274 | add_unload(gh - 1, gx) 275 | else: 276 | # As above, but transposed 277 | for gy in range(gh): 278 | add_load(gy, 0) 279 | for gx in range(gw): 280 | add_load(0, gx + 1) 281 | for gy in range(gh): 282 | add_load(gy + 1, gx + 1) 283 | add_write(gy, gx) 284 | add_unload(gy - 1, gx - 1) 285 | add_unload(gh - 1, gx - 1) 286 | for gy in range(gh): 287 | add_unload(gy, gw - 1) 288 | 289 | 290 | def _mk_angle_to_dir_fnc(bounds: tuple[int, int, int, int]): 291 | ylo, xlo, yhi, xhi = bounds 292 | ydif, xdif = (yhi - ylo), (xhi - xlo) 293 | diag_angle = math.atan(ydif / xdif) 294 | angle_to_dir = { 295 | math.pi * 0 / 4: (0, 1), 296 | math.pi * 0 / 4 + diag_angle: (1, 1), 297 | math.pi * 2 / 4: (1, 0), 298 | math.pi * 4 / 4 - diag_angle: (1, -1), 299 | math.pi * 4 / 4: (0, -1), 300 | -math.pi * 4 / 4 + diag_angle: (-1, -1), 301 | -math.pi * 2 / 4: (-1, 0), 302 | -math.pi * 0 / 4 - diag_angle: (-1, 1), 303 | } 304 | key_angles = np.array(list(angle_to_dir.keys())) 305 | 306 | def _calc_dir(ydif, xdif): 307 | angle = math.atan2(ydif, xdif) 308 | adif = np.abs(key_angles - angle) % (2 * math.pi) 309 | min_angle = adif.argmin() 310 | return angle_to_dir[key_angles[min_angle]] 311 | 312 | return _calc_dir 313 | 314 | 315 | def coerce_to_grid(boundss: np.ndarray) -> tuple[np.ndarray, list[tuple[int, int]]]: 316 | """ 317 | Algorithm to coerce a flat list of geometry bounds into a 2D geometry grid. 318 | Not well-optimised. 
319 | 320 | Assumptions: 321 | * scanning by overlapping bounds will discover all boundss 322 | * boundss are all the same size 323 | 324 | Returns: 325 | grid: np.ndarray 326 | 2D grid of shapely geometries shaped [H, W] 327 | mapping: list[tuple[int, int]] 328 | parallel to input flat list, where each geometry ended up in grid 329 | """ 330 | # Ensure order is top-left to bottom-right 331 | boundss = sorted(boundss.tolist()) 332 | boundss = np.asarray(boundss) 333 | 334 | # Get all overlaps 335 | geoms = np.asarray([shapely.box(*b) for b in boundss]) # shaped [N, 4, 2] 336 | overlaps = shapely.intersects(geoms[:, None], geoms[None]) 337 | 338 | # Define how to identify directions 339 | _calc_dir = _mk_angle_to_dir_fnc(boundss[0]) 340 | 341 | # Start from the first box in boundss. Breadth-first search through the boxes. 342 | # Use overlap to identify adjacent boxes. 343 | # Assign a y/x coord to each discovered box. 344 | # Add each found box and y/x coord to grid_list. 345 | open_list = [(0, 0, 0)] 346 | closed_list = [0] 347 | grid_list = [] 348 | mapped = {0: (0, 0)} 349 | closed_set = {(0, 0)} 350 | while len(open_list) > 0: 351 | i, y, x = open_list.pop(0) 352 | grid_list.append((i, y, x)) 353 | iylo, ixlo, iyhi, ixhi = boundss[i] 354 | icy, icx = (iylo + iyhi) / 2, (ixlo + ixhi) / 2 355 | dists = collections.defaultdict(lambda: []) 356 | for j in overlaps[i].nonzero()[0]: 357 | if i == j: 358 | continue 359 | if j not in closed_list: 360 | jylo, jxlo, jyhi, jxhi = boundss[j] 361 | jcy, jcx = (jylo + jyhi) / 2, (jxlo + jxhi) / 2 362 | dy, dx = jcy - icy, jcx - icx 363 | ymod, xmod = _calc_dir(dy, dx) 364 | if (y + ymod, x + xmod) not in closed_set: 365 | dists[(ymod, xmod)].append((j.item(), np.linalg.norm((dy, dx)).item())) 366 | for (ymod, xmod), distlist in dists.items(): 367 | d_np = np.array(distlist) 368 | j = round(d_np[np.argmin(d_np[:, 1])][0].item()) 369 | open_list.append((j, y + ymod, x + xmod)) 370 | mapped[j] = [y + ymod, x + xmod] 371 | 
closed_list.append(j) 372 | closed_set.add((y + ymod, x + xmod)) 373 | 374 | # Create a 2D grid of coordinates, and populate with boxes found in search 375 | grid_list = np.array(grid_list) 376 | ymin = grid_list[:, 1].min() 377 | xmin = grid_list[:, 2].min() 378 | ymax = grid_list[:, 1].max() 379 | xmax = grid_list[:, 2].max() 380 | grid = shapely.empty((ymax - ymin + 1, xmax - xmin + 1)) 381 | 382 | for i, y, x in grid_list: 383 | grid[y - ymin, x - xmin] = shapely.box(*boundss[i]) 384 | 385 | mapping = [ # Same (y - ymin, x - xmin) shift as used to fill the grid above 386 | (round(mapped[j][0] - ymin), round(mapped[j][1] - xmin)) for j in range(len(boundss)) 387 | ] 388 | 389 | return grid, mapping 390 | 391 | 392 | def regular_grid( 393 | image_size: tuple[int, int], 394 | tile_size: tuple[int, int], 395 | overlap: tuple[int, int], 396 | area: shapely.Geometry = None, 397 | ) -> np.ndarray[shapely.Geometry]: 398 | # Unpack sizes 399 | ih, iw = image_size 400 | th, tw = tile_size 401 | if area is None: 402 | area = shapely.box(0, 0, ih, iw) 403 | 404 | ylo, xlo, yhi, xhi = area.bounds 405 | 406 | # If the area is smaller than the image, then we want to include tiles 407 | # just outside the area so we can blend into the area properly 408 | gpylo = max(0, ylo - th) 409 | gpxlo = max(0, xlo - tw) 410 | gpyhi = min(ih, yhi + th) 411 | gpxhi = min(iw, xhi + tw) 412 | 413 | # Make an initial regular grid 414 | gph, gpw = gpyhi - gpylo, gpxhi - gpxlo 415 | grid = mk_box_grid(gph, gpw, gpylo, gpxlo, th, tw, *overlap) 416 | # If the grid doesn't cover the area perfectly (very likely), 417 | # add another layer of boxes along the edges 418 | gbyhi, gbxhi = grid[-1, -1].bounds[-2:] 419 | if gbyhi < yhi: 420 | # Create a new strip of boxes by copying the last one and then offsetting it such 421 | # that it is flush with the area boundary.
422 | gap = int(yhi - gbyhi) 423 | grid_strip = np.array([shapely.affinity.translate(cell, gap, 0) for cell in grid[-1, :]]) 424 | grid = np.concatenate([grid, grid_strip[None]], axis=0) 425 | if gbxhi < xhi: 426 | gap = int(xhi - gbxhi) 427 | grid_strip = np.array([shapely.affinity.translate(cell, 0, gap) for cell in grid[:, -1]]) 428 | grid = np.concatenate([grid, grid_strip[:, None]], axis=1) 429 | 430 | # Remove grid cells outside area 431 | mask = shapely.intersects(grid, area) 432 | grid[~mask] = None 433 | 434 | return grid 435 | 436 | 437 | def _mk_cache_hash(geom, dir_mask, nearby): 438 | # Assuming tiles are always the same size, then 439 | gylo, gxlo, _, _ = geom.bounds 440 | ylos = np.asarray([gylo] + [shp.bounds[0] for shp in nearby]) 441 | xlos = np.asarray([gxlo] + [shp.bounds[1] for shp in nearby]) 442 | return dir_mask.sum().item(), ylos.mean() - gylo, xlos.mean() - gxlo 443 | 444 | 445 | @dataclasses.dataclass 446 | class Step: 447 | action: str 448 | index: tuple[int, int] # grid index (can be used as cache key) 449 | 450 | 451 | @dataclasses.dataclass 452 | class LoadStep(Step): 453 | geom: shapely.Geometry # geometry to load 454 | 455 | 456 | @dataclasses.dataclass 457 | class WriteStep(Step): 458 | geom: shapely.Geometry # reference central geometry 459 | nearby: Sequence[tuple[int, int]] # indexes of geoms defined as nearby 460 | weight: tuple # outputs of overlap_weights 461 | 462 | 463 | def plan_from_grid( 464 | grid: np.ndarray[shapely.Geometry], 465 | margin: tuple[int, int] = None, 466 | area: shapely.Geometry = None, 467 | traversal_fnc: callable = row_by_row_traversal, 468 | ) -> list[Step]: 469 | """ 470 | Create a plan for running on a somewhat arbitrary grid. 471 | 472 | There is a restriction/assumption that must be satisfied: 473 | For each geometry at grid[y, x] the only geoms which overlap a tile are within +-1 474 | e.g. 
for grid[5, 5], the only geoms which overlap it are in the range grid[4:7, 4:7] 475 | 476 | Works for "grids" that aren't perfectly regular: 477 | * can have small offsets (assuming offsets are smaller than (overlap - margin)) 478 | 479 | IMPORTANT: All inputs should be YX, not XY. 480 | 481 | `margin` if provided, will subtract a margin along overlapping edges of each tile; 482 | if not provided, this means that overlapping areas will be written multiple times; 483 | if grid is regular, should be exactly half the overlap between tiles; 484 | if grid is irregular, large values might lead to holes in output. 485 | `area` can be any arbitrary geometry (i.e. need not be a rectangle) 486 | `traversal_fnc` lets you define a custom grid traversal algorithm, a callable with: 487 | traversal_fnc(grid, add_load_step, add_unload_step, add_write_step) 488 | Which decides when to load which tiles, when to unload them, and when to write them. 489 | Doesn't need to worry about whether those grid tiles are actually possible or not. 490 | 491 | Returns: 492 | plan (list[Step]): Describes how to manage the cache, and when/how to write tiles. 493 | Steps can be load, unload or write. 
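A minimal sketch of how a consumer might walk such a plan, dispatching on `step.action`. `MiniStep` is a simplified, hypothetical stand-in for the Step, LoadStep and WriteStep dataclasses (no geometries or weights), and `walk_plan` is not part of this module.

```python
from dataclasses import dataclass, field


# Simplified, hypothetical stand-in for the step dataclasses defined in this module.
@dataclass
class MiniStep:
    action: str
    index: tuple
    nearby: list = field(default_factory=list)


def walk_plan(plan, load_tile):
    # Loaded tiles live in a dict keyed by grid index, as the plan's indexes suggest.
    cache, written = {}, []
    for step in plan:
        if step.action == "load":
            cache[step.index] = load_tile(step.index)
        elif step.action == "unload":
            del cache[step.index]
        elif step.action == "write":
            # A real consumer would blend cache[step.index] with these neighbours.
            neighbours = [cache[idx] for idx in step.nearby]
            written.append((step.index, len(neighbours)))
        else:
            raise ValueError(f"Unknown plan action: {step.action}")
    return written, cache


plan = [
    MiniStep("load", (0, 0)),
    MiniStep("load", (0, 1)),
    MiniStep("write", (0, 0), nearby=[(0, 1)]),
    MiniStep("write", (0, 1), nearby=[(0, 0)]),
    MiniStep("unload", (0, 0)),
    MiniStep("unload", (0, 1)),
]
written, leftover = walk_plan(plan, load_tile=lambda index: f"tile{index}")
print(written, leftover)
```

A well-formed plan leaves the cache empty at the end, and every write finds its neighbours already loaded.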
494 | """ 495 | if area is None: 496 | area = shapely.unary_union(grid) 497 | _, _, gyhi, gxhi = area.bounds 498 | 499 | # Determine grid boundaries and which cells are possible 500 | gh, gw = grid.shape[:2] 501 | grid_in_area = shapely.intersects(grid, area) 502 | gridcell_needed = calc_gridcell_needed(grid_in_area) 503 | 504 | plan = [] 505 | weight_cache = {} 506 | 507 | # By pushing these to helper functions we separate the traversal logic from 508 | # deciding to load/unload/write only for tiles that need it (based on provided area) 509 | def _in_bounds(gy, gx): 510 | return 0 <= gy < gh and 0 <= gx < gw and grid[gy, gx] is not None 511 | 512 | def _add_load_step(gy, gx): 513 | if _in_bounds(gy, gx) and gridcell_needed[gy, gx]: 514 | plan.append(LoadStep(action="load", index=(gy, gx), geom=grid[gy, gx])) 515 | 516 | def _add_unload_step(gy, gx): 517 | if _in_bounds(gy, gx) and gridcell_needed[gy, gx]: 518 | plan.append(Step(action="unload", index=(gy, gx))) 519 | 520 | def _calc_weight(gy, gx, geom, dir_mask): 521 | # Check which directions are within the grid 522 | nearby = [(int(gy + j), int(gx + i)) for j, i in GRID_DIR[dir_mask]] 523 | nearby_geom = np.array([grid[y, x] for y, x in nearby]) 524 | # Based on which directions have a tile, determine how to trim the output 525 | trim_bounds = get_trimmed_bounds(margin, GRID_DIR[dir_mask]) 526 | 527 | # Only create new weights if we have to 528 | cache_hash = _mk_cache_hash(geom, dir_mask, nearby_geom) 529 | if cache_hash in weight_cache: 530 | # All but one of the weights are relative. The absolute output is the out_geom. 531 | # So, here we account for a different input geom after-the-fact. 
532 | (out_geom, a, b, c, d), other_geom = weight_cache[cache_hash] 533 | oylo, oxlo, _, _ = other_geom.bounds 534 | tylo, txlo, _, _ = geom.bounds 535 | out_geom = shapely.affinity.translate(out_geom, tylo - oylo, txlo - oxlo) 536 | return (out_geom, a, b, c, d), nearby 537 | 538 | # Finally calculate the weights for combining this tile with its nearby. 539 | weight = overlap_weights(geom, nearby_geom, trim_bounds) 540 | weight_cache[cache_hash] = (weight, geom) 541 | return weight, nearby 542 | 543 | def _add_write_step(gy, gx): 544 | if grid_in_area[gy, gx]: 545 | geom = grid[gy, gx] 546 | dir_mask = np.asarray([_in_bounds(gy + j, gx + i) for j, i in GRID_DIR]) 547 | weight, nearby = _calc_weight(gy, gx, geom, dir_mask) 548 | base = {"geom": geom, "index": (gy, gx), "weight": weight} 549 | plan.append(WriteStep(action="write", **base, nearby=nearby)) 550 | 551 | traversal_fnc(grid, _add_load_step, _add_unload_step, _add_write_step) 552 | 553 | return plan 554 | 555 | 556 | def plan_regular_grid( 557 | image_size: tuple[int, int], 558 | tile_size: tuple[int, int], 559 | overlap: tuple[int, int], 560 | area: shapely.Geometry = None, 561 | traversal_fnc: callable = row_by_row_traversal, 562 | ) -> tuple[list[Step], np.ndarray[shapely.Geometry]]: 563 | """ 564 | Plans out running segmentation over a single large image by tiling, overlapping 565 | and blending between adjacent tiles in a regular grid. 566 | 567 | IMPORTANT: All inputs should be YX, not XY. 568 | 569 | Does not depend on any real data; merely creates a geometry plan based on size data. 570 | 571 | `area` can be any arbitrary geometry (i.e. need not be a rectangle) 572 | `traversal_fnc` lets you define a custom grid traversal algorithm, a callable with: 573 | traversal_fnc(grid, add_load_step, add_unload_step, add_write_step) 574 | Which decides when to load which tiles, when to unload them, and when to write them. 575 | Doesn't need to worry about whether those grid tiles are actually possible or not. 
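A minimal sketch of a custom traversal with that signature, assuming only that `grid` exposes `.shape`. `load_everything_traversal` is hypothetical and deliberately naive: it keeps the whole grid in memory, giving up the savings of `row_by_row_traversal`. The `SimpleNamespace` grid and the recording lambdas are stand-ins so the example runs without shapely or numpy.

```python
from types import SimpleNamespace


# Hypothetical traversal: load every tile, write every tile, then unload every tile.
def load_everything_traversal(grid, add_load, add_unload, add_write):
    gh, gw = grid.shape[:2]
    cells = [(gy, gx) for gy in range(gh) for gx in range(gw)]
    for gy, gx in cells:
        add_load(gy, gx)
    for gy, gx in cells:
        add_write(gy, gx)
    for gy, gx in cells:
        add_unload(gy, gx)


# Record the requested actions instead of building real plan steps.
events = []
fake_grid = SimpleNamespace(shape=(2, 2))
load_everything_traversal(
    fake_grid,
    lambda gy, gx: events.append(("load", gy, gx)),
    lambda gy, gx: events.append(("unload", gy, gx)),
    lambda gy, gx: events.append(("write", gy, gx)),
)
print([kind for kind, _, _ in events])
```

With the real module it would be passed as `traversal_fnc=load_everything_traversal`. The order of the load calls matters, since tiles must later be supplied in the order the plan loads them.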
576 | 577 | Returns: 578 | plan (list[Step]): Describes how to manage the cache, and when/how to write tiles. 579 | Steps can be load, unload or write. 580 | grid (np.ndarray[shapely.Geometry]): shaped [H, W], a grid of geometries describing 581 | where each tile is placed within the image. 582 | """ 583 | oh, ow = overlap 584 | if oh % 2 != 0 or ow % 2 != 0: 585 | raise ValueError("Overlap must be an even number in both dimensions") 586 | margin = oh // 2, ow // 2 587 | grid = regular_grid(image_size, tile_size, overlap, area) 588 | return plan_from_grid(grid, margin, area, traversal_fnc), grid 589 | 590 | 591 | def batched_tile_get( 592 | geoms: list[tuple[tuple[int, int], shapely.Geometry]], 593 | batch_size: int, 594 | get_tiles_fnc: callable, 595 | ): 596 | """ 597 | Takes a function `get_tiles_fnc` that expects a batch of geoms at once. 598 | Yields individual tiles. 599 | """ 600 | batch_indices = [] 601 | batch_geoms = [] 602 | for index, geom in geoms: 603 | batch_indices.append(index) 604 | batch_geoms.append(geom) 605 | if len(batch_geoms) == batch_size: 606 | tiles = get_tiles_fnc(batch_indices, batch_geoms) 607 | for past_index, tile in zip(batch_indices, tiles): 608 | yield tile 609 | batch_indices = [] 610 | batch_geoms = [] 611 | tiles = get_tiles_fnc(batch_indices, batch_geoms) if batch_geoms else [] # skip an empty final batch 612 | for past_index, tile in zip(batch_indices, tiles): 613 | yield tile 614 | 615 | 616 | def threaded_batched_tile_get( 617 | geoms: list[tuple[tuple[int, int], shapely.Geometry]], 618 | batch_size: int, 619 | get_tiles_fnc: callable, 620 | max_prefetched: int, 621 | ) -> Generator[tuple[tuple[int, int], np.ndarray], None, None]: 622 | """ 623 | Takes a function `get_tiles_fnc` that expects a batch of geoms at once. 624 | Executes that function in a thread, prefetching those tiles before they are needed.
625 | Yields individual tiles 626 | """ 627 | out_queue = queue.Queue(max_prefetched) 628 | 629 | def _wrap_queue(): 630 | for tile in batched_tile_get(geoms, batch_size, get_tiles_fnc): 631 | out_queue.put(tile) 632 | 633 | thread = threading.Thread(target=_wrap_queue) 634 | thread.start() 635 | for _ in geoms: 636 | yield out_queue.get() 637 | 638 | 639 | def analyse_plan(plan: list[Step]) -> tuple[int, int, int]: 640 | """Counts maximum tiles loaded at once, total tiles loaded, and total write calls.""" 641 | loaded = 0 642 | total_loaded = 0 643 | max_loaded = 0 644 | write = 0 645 | for step in plan: 646 | if step.action == "load": 647 | loaded += 1 648 | total_loaded += 1 649 | elif step.action == "unload": 650 | loaded -= 1 651 | if loaded > max_loaded: 652 | max_loaded = loaded 653 | if step.action == "write": 654 | write += 1 655 | return max_loaded, total_loaded, write 656 | 657 | 658 | def get_plan_logit_geoms(plan): 659 | return [(step.index, step.geom) for step in plan if step.action == "load"] 660 | 661 | 662 | def simple_logit_generator(plan, get_logits): 663 | for index, geom in get_plan_logit_geoms(plan): # same module, so no seamless_seg. prefix 664 | yield get_logits(geom) 665 | 666 | 667 | def _check_plan_doesnt_exceed(plan, max_tiles): 668 | if max_tiles is None: 669 | # No maximum set 670 | return 671 | 672 | max_loaded, _, _ = analyse_plan(plan) 673 | if max_loaded > max_tiles: 674 | raise Exception("Traversal method in plan would hold more than max tiles in memory") 675 | 676 | 677 | def noop(*args, **kwargs): 678 | pass 679 | 680 | 681 | def serialise_index(index): 682 | return f"{index[0]}-{index[1]}.npy" 683 | 684 | 685 | def run_plan( 686 | plan: list[Step], 687 | tiles: Iterable, 688 | max_tiles: int = None, 689 | disk_cache_dir: Path = None, 690 | on_load: callable = noop, 691 | on_unload: callable = noop, 692 | on_step: callable = noop, 693 | on_disk_evict: callable = noop, 694 | on_disk_restore: callable = noop, 695 | ) -> Generator[tuple[tuple[int, int],
shapely.Geometry, np.ndarray], None, None]: 696 | """ 697 | Executes a previously created plan to read model logits, and blend them together seamlessly. 698 | 699 | Yields output geometries and tiles. 700 | 701 | The on_* hooks are provided indexes into the grid used to generate the plan. 702 | 703 | Args: 704 | plan (list[Step]): 705 | List of steps to execute. 706 | tiles (Iterator[np.ndarray]): Iterator (e.g. a generator) of tiles containing model logits; consumed with next(). 707 | Order must be as specified by seamless_seg.get_plan_logit_geoms 708 | max_tiles (int): 709 | Maximum number of tiles to keep in memory at once. If provided, must be greater than 8. 710 | disk_cache_dir (Path): 711 | If the plan would load more than `max_tiles` tiles, the excess tiles are stored on disk in this directory. 712 | on_load (callable[tuple[int, int]->None]): 713 | Called after a new tile is loaded into memory. 714 | on_unload (callable[tuple[int, int]->None]): 715 | Called after a tile is removed from memory. 716 | on_step (callable[int->None]): 717 | Called after each Step is executed. Is given step number, not grid index. 718 | on_disk_evict (callable[tuple[int, int]->None]): 719 | Called when a tile is stored to disk cache. 720 | on_disk_restore (callable[tuple[int, int]->None]): 721 | Called when a tile is restored from disk cache. 722 | 723 | Yields: 724 | index: tuple[int, int], out_geom: shapely.Geometry, out_tile: np.ndarray 725 | """ 726 | cache = collections.OrderedDict() 727 | disk_cache = {} 728 | 729 | if max_tiles is not None and max_tiles <= 8: 730 | raise ValueError("If provided, max_tiles must be greater than 8") 731 | 732 | if disk_cache_dir is None: 733 | _check_plan_doesnt_exceed(plan, max_tiles) 734 | else: 735 | if max_tiles is None: 736 | raise ValueError("If disk_cache_dir is set, then max_tiles should be set") 737 | disk_cache_dir.mkdir(exist_ok=True, parents=True) 738 | 739 | # Two-level cache management functions; evicting to disk and restoring from disk.
740 | def _evict_oldest(): 741 | oldest_index, oldest_tile = cache.popitem(False) 742 | on_disk_evict(oldest_index) 743 | fpath = disk_cache_dir / serialise_index(oldest_index) 744 | np.save(fpath, oldest_tile) 745 | disk_cache[oldest_index] = fpath 746 | 747 | def _resolve_restore(index): 748 | if index in cache: 749 | cache.move_to_end(index) 750 | return cache[index] 751 | if len(cache) == max_tiles: 752 | _evict_oldest() 753 | cache[index] = np.load(disk_cache[index]) 754 | on_disk_restore(index) 755 | del disk_cache[index] 756 | return cache[index] 757 | 758 | # Run plan 759 | for n, step in enumerate(plan): 760 | if step.action == "load": 761 | # Put tile into cache 762 | if disk_cache_dir is not None: 763 | if len(cache) == max_tiles: 764 | _evict_oldest() 765 | cache[step.index] = next(tiles) 766 | on_load(step.index) 767 | elif step.action == "unload": 768 | # Remove tile from cache (it may already have been evicted to disk) 769 | cache.pop(step.index, None) 770 | on_unload(step.index) 771 | elif step.action == "write": 772 | # Collect nearby tiles 773 | nearby_tiles = [] 774 | for index in step.nearby: 775 | if disk_cache_dir is None: 776 | tile = cache[index] 777 | else: 778 | tile = _resolve_restore(index) 779 | nearby_tiles.append(tile) 780 | 781 | # Collect central tile 782 | if disk_cache_dir is None: 783 | central_tile = cache[step.index] 784 | else: 785 | central_tile = _resolve_restore(step.index) 786 | 787 | # Apply weights from plan to create final output tile 788 | out_geom, out_tile = apply_weights(central_tile, nearby_tiles, step.weight) 789 | yield step.index, out_geom, out_tile 790 | else: 791 | raise Exception("Unknown plan action") 792 | on_step(n) 793 | 794 | 795 | def pytorch_outputs_generator(plan, model, read_tile, batch_size: int = None, device: str = None): 796 | import torch 797 | 798 | if device is None: 799 | if isinstance(model, torch.nn.Module): 800 | device = next(model.parameters()).device 801 | elif getattr(model, "device", None) is not None: 802 | device = getattr(model,
"device") 803 | else: 804 | device = "cpu" 805 | else: 806 | device = device 807 | 808 | if batch_size is not None and batch_size >= 1: 809 | 810 | def _run_tiles(_, geoms): 811 | """A function which takes a batch of geoms and returns model outputs for those geoms""" 812 | # Load all images for batch 813 | imgs = [read_tile(in_geom) for in_geom in geoms] 814 | 815 | # Push batch through model 816 | img_th = torch.as_tensor(np.stack(imgs)).to(device) 817 | out_th = model(img_th) 818 | out = out_th.detach().cpu().numpy() 819 | 820 | # model output is in BCHW, yield model outputs in BHWC 821 | return out.transpose((0, 2, 3, 1)) 822 | 823 | def _input_generator(plan): 824 | geoms = get_plan_logit_geoms(plan) 825 | return threaded_batched_tile_get(geoms, batch_size, _run_tiles, batch_size * 3) 826 | 827 | else: 828 | 829 | def _input_generator(plan): 830 | for index, in_geom in get_plan_logit_geoms(plan): 831 | # Read image data 832 | img = read_tile(in_geom) 833 | 834 | # Push image data through model (don't forget batch dimension) 835 | img_th = torch.as_tensor(img[None]).to(device) 836 | out_th = model(img_th) 837 | out = out_th[0].detach().cpu().numpy() 838 | 839 | # Yield model outputs in HWC 840 | yield out.transpose((1, 2, 0)) 841 | 842 | return _input_generator(plan) 843 | 844 | 845 | def run_plan_pytorch( 846 | plan: list[Step], 847 | model: callable, 848 | read_tile: callable, 849 | write_tile: callable, 850 | batch_size: int = None, 851 | max_tiles: int = None, 852 | disk_cache_dir: Path = None, 853 | device: str = None, 854 | ): 855 | in_tiles = pytorch_outputs_generator(plan, model, read_tile, batch_size, device) 856 | out_tiles = run_plan(plan, in_tiles, max_tiles=max_tiles, disk_cache_dir=disk_cache_dir) 857 | for index, out_geom, out_tile in out_tiles: 858 | write_tile(out_geom, out_tile) 859 | 860 | 861 | def pytorch_rasterio( 862 | model: callable, 863 | in_tif, # rasterio.Dataset 864 | out_fname: str, 865 | tile_size: tuple[int, int], 866 | overlap: 
tuple[int, int] = None, 867 | batch_size: int = None, 868 | area: shapely.Geometry = None, 869 | area_in_crs: bool = True, 870 | max_tiles: int = None, 871 | disk_cache_dir: Path = None, 872 | device: str = None, 873 | ): 874 | """ 875 | Create a seamless segmentation in the file `out_fname`. 876 | Takes image data from `in_tif`, runs it through `model` to produce logits, 877 | uses seamless_seg to create segmentation and writes it to `out_fname`. 878 | 879 | Args: 880 | in_tif: rasterio.Dataset 881 | out_fname: str 882 | Path of the output raster; written as a single-band uint8 segmentation mask 883 | tile_size: int | tuple[int, int] 884 | Size of input to model 885 | model: callable[torch.Tensor -> torch.Tensor] 886 | Takes batch of image data, returns logits for the same shape 887 | batch_size: int, Optional 888 | If provided and at least 1, runs model in batches of this size 889 | overlap: int | tuple[int, int], Optional 890 | Pixel overlap between tiles; larger overlap causes more gradual change, but is more expensive. 891 | Optional: default is a quarter of the tile size, to balance speed and performance. 892 | area: shapely.Geometry, Optional 893 | Only run the model on a subset of the in_tif 894 | area_in_crs: bool, Optional 895 | If True (default) assumes `area` is in CRS of `in_tif`. 896 | If False assumes `area` is in pixels. 897 | max_tiles: int, Optional 898 | To control memory footprint, you can set a maximum number of tiles to load at once (must be greater than 8). 899 | disk_cache_dir: Path, Optional 900 | When used in conjunction with max_tiles, will cache logits to disk during computation. 901 | device: str, Optional 902 | If provided, puts tiles onto device. Else attempts to read device from model. Else crashes.
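The size handling described above can be sketched as follows. This is a minimal illustration, assuming the defaults used in this module (overlap of a quarter tile, margin of half the overlap); `resolve_tiling` is a hypothetical helper, not a function of the module.

```python
# Hypothetical helper: normalise tile_size/overlap and derive the trimming margin.
def resolve_tiling(tile_size, overlap=None):
    # A single int means a square tile or an equal overlap on both axes (Y, X).
    if isinstance(tile_size, int):
        tile_size = (tile_size,) * 2
    if isinstance(overlap, int):
        overlap = (overlap,) * 2
    # Default overlap: a quarter of the tile size on each axis.
    if overlap is None:
        overlap = tile_size[0] // 4, tile_size[1] // 4
    # Each tile gives up half of the overlap, so the overlap must split evenly.
    if overlap[0] % 2 != 0 or overlap[1] % 2 != 0:
        raise ValueError("Overlap must be an even number in both dimensions")
    margin = overlap[0] // 2, overlap[1] // 2
    return tile_size, overlap, margin


print(resolve_tiling(512))
print(resolve_tiling((256, 512), overlap=32))
```

Note that a default derived from an awkward tile size can be odd (a tile size of 100 gives an overlap of 25), in which case an explicit even overlap is needed.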
903 | 904 | """ 905 | import rasterio 906 | 907 | profile = { 908 | **in_tif.profile, 909 | "dtype": np.uint8, 910 | "count": 1, 911 | "PHOTOMETRIC": "MINISBLACK", 912 | "COMPRESS": "PACKBITS", 913 | } 914 | with rasterio.open(out_fname, "w", **profile) as out_tif: 915 | 916 | if isinstance(tile_size, int): 917 | tile_size = (tile_size,) * 2 918 | if isinstance(overlap, int): 919 | overlap = (overlap,) * 2 920 | 921 | def read_tile(shp): 922 | img = in_tif.read(window=shape_to_slices(shp)) 923 | return img 924 | 925 | def write_tile(shp, tile): 926 | # Convert logits to segmentation mask 927 | seg = tile.argmax(axis=-1)[None] 928 | # Write segmentation mask to disk 929 | out_tif.write(seg, window=shape_to_slices(shp)) 930 | 931 | if overlap is None: 932 | overlap = tile_size[0] // 4, tile_size[1] // 4 933 | if area is not None and area_in_crs: 934 | coords = shapely.get_coordinates(area) 935 | in_tif.transform.itransform(coords) 936 | area = shapely.set_coordinates(area, coords) 937 | 938 | plan, grid = plan_regular_grid(in_tif.shape, tile_size, overlap, area=area) 939 | run_plan_pytorch( 940 | plan, model, read_tile, write_tile, batch_size, max_tiles, disk_cache_dir, device 941 | ) 942 | 943 | 944 | def pytorch_numpy( 945 | model: callable, 946 | img: np.ndarray, 947 | tile_size: int | tuple[int, int], 948 | overlap: int | tuple[int, int] = None, 949 | batch_size: int = None, 950 | max_tiles: int = None, 951 | disk_cache_dir: Path = None, 952 | device: str = None, 953 | ): 954 | """ 955 | Create a seamless segmentation of `img` using `model`. 956 | Takes tiles from `img`, runs it through `model` to produce logits, and 957 | uses seamless_seg to create segmentation, returning the img array. 
958 | 959 | Args: 960 | model: callable[torch.Tensor -> torch.Tensor] 961 | Takes batch of image data, returns logits for the same shape 962 | img: np.ndarray 963 | Shaped [C, H, W] 964 | tile_size: int | tuple[int, int] 965 | Size of input to model (H, W) 966 | batch_size: int 967 | If provided and at least 1, runs model in batches of this size 968 | overlap: int | tuple[int, int] 969 | Pixel overlap between tiles; larger overlap causes more gradual change, but is more expensive. 970 | Optional: default is a quarter of the tile size, to balance speed and performance. 976 | max_tiles: int 977 | To control memory footprint, you can set a maximum number of tiles to load at once (must be greater than 8). 978 | disk_cache_dir: Path 979 | When used in conjunction with max_tiles, will cache logits to disk during computation. 980 | device: str, Optional 981 | If provided, processes tiles on device. Else attempts to read device from model. Else crashes.
982 | 983 | """ 984 | out = np.zeros(img.shape[1:], dtype=np.int32) 985 | 986 | if isinstance(tile_size, int): 987 | tile_size = (tile_size,) * 2 988 | if isinstance(overlap, int): 989 | overlap = (overlap,) * 2 990 | 991 | def read_tile(shp): 992 | full_slice = (slice(None), *shape_to_slices(shp)) 993 | return img[full_slice] 994 | 995 | def write_tile(shp, tile): 996 | slc = shape_to_slices(shp) 997 | # Convert logits to segmentation mask and write to out 998 | out[slc] = tile.argmax(axis=-1) 999 | 1000 | if overlap is None: 1001 | overlap = tile_size[0] // 4, tile_size[1] // 4 1002 | 1003 | plan, grid = plan_regular_grid(img.shape[1:], tile_size, overlap) 1004 | run_plan_pytorch( 1005 | plan, model, read_tile, write_tile, batch_size, max_tiles, disk_cache_dir, device 1006 | ) 1007 | 1008 | return out 1009 | --------------------------------------------------------------------------------