├── .flake8 ├── .github └── workflows │ └── python-app.yml ├── .gitignore ├── .readthedocs.yaml ├── CreateDistribution.txt ├── LICENSE.txt ├── MANIFEST.in ├── README.md ├── doc ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── ReleaseNotes.rst │ ├── conf.py │ ├── index.rst │ ├── pyshepseg_shepseg.rst │ ├── pyshepseg_subset.rst │ ├── pyshepseg_tiling.rst │ ├── pyshepseg_tilingstats.rst │ └── pyshepseg_utils.rst ├── parallel_examples ├── README.md └── awsbatch │ ├── Dockerfile │ ├── Makefile │ ├── README.md │ ├── create-stack.sh │ ├── delete-stack.sh │ ├── do_prepare.py │ ├── do_stitch.py │ ├── do_tile.py │ ├── modify-stack.sh │ ├── submit-pyshepseg-job.py │ └── template │ └── template.yaml ├── pyproject.toml └── pyshepseg ├── __init__.py ├── cmdline ├── __init__.py ├── pyshepseg_segmentationworkercmd.py ├── run_seg.py ├── runtests.py ├── subset.py ├── tiling.py └── variograms.py ├── guardeddecorators.py ├── shepseg.py ├── subset.py ├── tiling.py ├── tilingstats.py ├── timinghooks.py └── utils.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = .git,__pycache__,build,dist 3 | # Error codes we should ignore 4 | # W291 trailing whitespace 5 | # W293 blank line contains whitespace 6 | # W391 blank line at end of file 7 | # W504 line break after binary operator 8 | # E128 continuation line under-indented for visual indent 9 | # E225 missing whitespace around operator 10 | # E228 missing whitespace around modulo operator 11 | # E501 line too long (??? > 79 characters) 12 | ignore = W291,W293,W391,W504,E128,E225,E228,E501 13 | -------------------------------------------------------------------------------- /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python application 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - uses: conda-incubator/setup-miniconda@v2 20 | with: 21 | miniforge-version: latest 22 | - name: Install dependencies 23 | shell: bash -l {0} 24 | run: | 25 | conda install flake8 numba scikit-learn gdal libgdal-kea 26 | - name: Lint with flake8 27 | shell: bash -l {0} 28 | run: | 29 | flake8 30 | - name: Test with pyshepseg_runtests 31 | shell: bash -l {0} 32 | run: | 33 | pip install . 
34 | pyshepseg_runtests 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.o 3 | *.so 4 | *~ 5 | *.png 6 | *.img 7 | *.pdf 8 | *.gdb 9 | *.egg-info 10 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | build: 9 | os: ubuntu-22.04 10 | tools: 11 | python: "3.12" 12 | 13 | # Build documentation in the doc/ directory with Sphinx 14 | sphinx: 15 | configuration: doc/source/conf.py 16 | 17 | python: 18 | install: 19 | - requirements: doc/requirements.txt 20 | -------------------------------------------------------------------------------- /CreateDistribution.txt: -------------------------------------------------------------------------------- 1 | How to create a distribution of pyshepseg. You will need the 'build' module 2 | installed ("conda install build" or "pip install build"). 3 | 4 | 1. Ensure that you have fetched and committed everything which needs to go in. 5 | 2. Change the version number in the pyshepseg/__init__.py. Version number 6 | is of the form a.b.c, as discussed below. 7 | 3. Update the release notes page doc/source/ReleaseNotes.rst, by going through the 8 | change logs since the last release, and noting what has been done. 9 | DON'T FORGET TO COMMIT THIS, BEFORE THE NEXT STEP!!!! 10 | 4. Push the changes to github with "git push". 11 | 12 | In practice, steps 2-3 are usually done as a single pull request, and 13 | merged, rather than pushed directly, but I am skipping all the detail 14 | of how to do a PR. 15 | 5. Check out a clean copy of the repository into /tmp or 16 | somewhere similar and 'cd' into it. 17 | 6. Create the distribution tar.gz, using 18 | python -m build . 19 | This creates a tar.gz file, under a subdirectory called dist 20 | e.g. pyshepseg-1.2.3.tar.gz 21 | 7. Create a checksum of this, e.g. 22 | sha256sum pyshepseg-1.2.3.tar.gz > pyshepseg-1.2.3.tar.gz.sha256 23 | 8. Go to the https://github.com/ubarsc/pyshepseg/releases page, and create a 24 | new release by pressing "Draft a new release". 25 | You should fill in the following: 26 | Tag version: pyshepseg-A.B.C 27 | Release Title: Version A.B.C 28 | Description: Add a brief description (a few lines at most) explaining 29 | the key points about this release. 30 | Upload files: Add the tar.gz file and the checksum. 31 | Click "Publish release" 32 | 33 | 34 | 35 | Version Numbers. 36 | The pyshepseg version number is structured as A.B.C. We follow the conventions 37 | outlined in Semantic Versioning [https://semver.org] 38 | - The A number should change for major alterations, most particularly those 39 | which break backward compatibility, or which involve major restructuring of 40 | code or data structures. 41 | - The B number should change for introduction of significant new features 42 | - The C number should change for bug fixes or very minor changes. 43 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright 2021 Neil Flood and Sam Gillingham. All rights reserved. 
2 | 3 | Permission is hereby granted, free of charge, to any person 4 | obtaining a copy of this software and associated documentation 5 | files (the "Software"), to deal in the Software without restriction, 6 | including without limitation the rights to use, copy, modify, 7 | merge, publish, distribute, sublicense, and/or sell copies of the 8 | Software, and to permit persons to whom the Software is furnished 9 | to do so, subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 16 | OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 17 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR 18 | ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF 19 | CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE.txt 2 | include CHANGES.txt 3 | include README.md 4 | include parallel_examples/* parallel_examples/awsbatch/* 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pyshepseg 2 | Python implementation of image segmentation algorithm of 3 | Shepherd et al (2019). Operational Large-Scale Segmentation of Imagery 4 | Based on Iterative Elimination. [Remote Sensing 11(6)](https://www.mdpi.com/2072-4292/11/6/658). 5 | 6 | This package is a tool for Python programmers to implement the segmentation 7 | algorithm. It is not a stand-alone solution for people with no Python 8 | experience. 9 | 10 | We thank the authors of the paper for their algorithm. This implementation 11 | was created independently of them, and they are in no way to blame for 12 | any mistakes we have made. 13 | 14 | Please see full documentation at [https://www.pyshepseg.org](https://www.pyshepseg.org). 15 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /doc/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /doc/requirements.txt: -------------------------------------------------------------------------------- 1 | numpydoc 2 | numba 3 | -------------------------------------------------------------------------------- /doc/source/ReleaseNotes.rst: -------------------------------------------------------------------------------- 1 | Pyshepseg Release Notes 2 | ======================= 3 | 4 | Version 2.0.3 (2024-12-06) 5 | -------------------------- 6 | 7 | New Features: 8 | * Support numpy 2.0 9 | * Use multiple RIOS readworkers while calculating statistics if recent RIOS available. 10 | 11 | Bug fixes: 12 | * Allow GDAL dataset to be returned from doTiledShepherdSegmentation and histogram to be written to file by default. 13 | * More functions now support receiving a GDAL dataset as input 14 | * Add --noremove option to tiling jobs 15 | 16 | Version 2.0.2 (2024-06-12) 17 | -------------------------- 18 | 19 | New Features: 20 | * Setup is now fully controlled by pyproject.toml, with no setup.py 21 | * Add support for spatial stats within AWS Batch 22 | * Add --tileprefix in AWS Batch so all temporary S3 files are unique. 23 | This allows multiple concurrent runs. 24 | 25 | Bug fixes: 26 | * Add guard against subsampling >100% of the data 27 | * Fix console_scripts syntax 28 | 29 | Version 2.0.1 (2024-05-21) 30 | -------------------------- 31 | 32 | New Features: 33 | * Many fixes and improvements to the AWS Batch support. Most notably, 34 | statistics can now be calculated before the "stitch" job finishes. 35 | 36 | Bug Fixes: 37 | * Fix tiling code with recent scipy (>1.9.0). 38 | 39 | Version 2.0.0 (2023-01-04) 40 | -------------------------- 41 | 42 | New Features: 43 | * A test script (pyshepseg_runtests) that can be run to confirm 44 | the install is working as intended. 45 | * Split up the parts of doTiledSegmentation() so they can be run 46 | in parallel. 47 | * Check syntax with flake8 and run test script on new PRs in github. 48 | * Use entry points for the command line scripts rather than creating 49 | our own. Should make running on Windows easier. 50 | * Added ability to calculate "spatial" statistics on the segments. 51 | * Use numpydoc for creating Sphinx documentation. 
52 | * Subset functionality is now in a separate "subset" module. 53 | * Statistics functionality now in a new "tilingstats" module. 54 | 55 | Version 1.1.0 (2021-12-24) 56 | -------------------------- 57 | 58 | Bug Fixes: 59 | * Guard against Float images being used for calculating 60 | statistics as the results were undefined. 61 | * Added other checks to ensure that the image having statistics 62 | calculated matches spatially with the segmented image. 63 | * Add the ability to add GDAL driver creation options for the 64 | output image of a segmentation. 65 | * Create the histogram column as a Real to match common GDAL 66 | usage. 67 | * Add checks to ensure histogram and colour columns aren't 68 | created if they already exist. 69 | * Ensure the first segment of each RAT Page isn't initially set 70 | to 'complete' before use. 71 | * Raise an error if any incomplete RAT Pages are found during processing 72 | as this indicates the RAT contains more entries than unique values 73 | in the image. 74 | * When calculating statistics, open the file that the stats are 75 | calculated on in read only mode so /vsi filesystems can be used. 76 | * Increase default overlap for tiled segmentation as the old value 77 | could result in inconsistencies and segments that were missing from 78 | the image (but in the RAT). 79 | * Remove dependency on distutils which is now deprecated. 80 | 81 | New Features: 82 | * New Sphinx documentation located at https://www.pyshepseg.org. 83 | * Added a new subsetImage() function to the tiling module that subsets 84 | an already segmented image and collapses the RAT so there are no 85 | redundant entries. Also added a test_pyshepseg_subset.py command line 86 | program to test this functionality. 87 | * Exclude any nodata pixel values during statistics calculation. 88 | 89 | Version 1.0.0 (2021-04-08) 90 | -------------------------- 91 | 92 | New Features: 93 | * Added pyshepseg.tiling module to allow processing of large rasters 94 | in a memory-efficient manner. 95 | * Added pyshepseg.tiling.calcPerSegmentStatsTiled() function to 96 | enable calculation of per-segment statistics in a fast and 97 | memory-efficient manner. 98 | * Added pyshepseg.utils.writeColorTableFromRatColumns() function, to 99 | add a colour table calculated from per-segment statistics 100 | 101 | Version 0.1 102 | ----------- 103 | 104 | Initial implementation of segmentation algorithm. Other facilities 105 | will be added as we get to them. 106 | -------------------------------------------------------------------------------- /doc/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('../..')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'pyshepseg' 21 | copyright = '2021, Neil Flood & Sam Gillingham' 22 | author = 'Neil Flood & Sam Gillingham' 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 30 | extensions = [ 31 | 'sphinx.ext.autodoc', 32 | 'numpydoc' 33 | ] 34 | numpydoc_show_class_members = False 35 | 36 | # Add any paths that contain templates here, relative to this directory. 37 | templates_path = ['_templates'] 38 | 39 | # List of patterns, relative to source directory, that match files and 40 | # directories to ignore when looking for source files. 41 | # This pattern also affects html_static_path and html_extra_path. 42 | exclude_patterns = [] 43 | 44 | 45 | # -- Options for HTML output ------------------------------------------------- 46 | 47 | # The theme to use for HTML and HTML Help pages. See the documentation for 48 | # a list of builtin themes. 49 | # 50 | html_theme = 'classic' 51 | html_theme_options = { 52 | "sidebarwidth": "20%", 53 | "body_min_width": "90%", 54 | "stickysidebar": True 55 | } 56 | 57 | # Add any paths that contain custom static files (such as style sheets) here, 58 | # relative to this directory. They are copied after the builtin static files, 59 | # so a file named "default.css" will overwrite the builtin "default.css". 60 | # html_static_path = ['_static'] 61 | 62 | # Set up list of things to mock, if they are not actually present. In other 63 | # words, don't mock them when they are present. This is mainly to avoid 64 | # a list of warning messages coming from Sphinx while testing, but 65 | # makes no real difference when running on ReadTheDocs (when everything 66 | # will be mocked anyway). 67 | # I am very unsure about this. I would much prefer to remove the warnings 68 | # for real, but don't yet know how. 69 | possibleMockList = ['numpy', 'numba', 'osgeo', 'scipy', 'sklearn'] 70 | autodoc_mock_imports = [] 71 | for module in possibleMockList: 72 | try: 73 | exec('import {}'.format(module)) 74 | except ImportError: 75 | autodoc_mock_imports.append(module) 76 | -------------------------------------------------------------------------------- /doc/source/index.rst: -------------------------------------------------------------------------------- 1 | .. pyshepseg documentation master file, created by 2 | sphinx-quickstart on Mon Dec 6 11:34:41 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | pyshepseg 7 | ========= 8 | 9 | Introduction 10 | ------------ 11 | Python implementation of image segmentation algorithm of 12 | Shepherd et al (2019). *Operational Large-Scale Segmentation of Imagery Based on Iterative 13 | Elimination*. `Remote Sensing 11(6) <https://www.mdpi.com/2072-4292/11/6/658>`_. 14 | 15 | This package is a tool for Python programmers to implement the segmentation algorithm. 16 | It is not a stand-alone solution for people with no Python experience. 17 | 18 | We thank the authors of the paper for their algorithm. This implementation was created 19 | independently of them, and they are in no way to blame for any mistakes we have made. 20 | 21 | Downloads 22 | --------- 23 | 24 | From `GitHub <https://github.com/ubarsc/pyshepseg/releases>`_. 
25 | Release notes by version can be viewed in :doc:`ReleaseNotes`. 26 | 27 | Dependencies 28 | ------------ 29 | The package requires the `scikit-learn <https://scikit-learn.org/>`_ package, 30 | and the `numba <https://numba.pydata.org/>`_ package. These need to be installed 31 | before this package will run. These are installed automatically when using the conda-forge 32 | ``pyshepseg`` package (see below), but will need to be available when building from source. 33 | 34 | Also recommended is the `GDAL <https://gdal.org/>`_ package for reading and 35 | writing raster file formats. It is not required by the core segmentation 36 | algorithm, but is highly recommended as a portable way to interface 37 | to a large range of raster formats. It is required by the ``tiling`` module 38 | to support segmentation of large rasters. This is installed when using 39 | the conda-forge ``pyshepseg`` package. 40 | 41 | Installation 42 | ------------ 43 | 44 | This package can be installed from conda-forge and is the recommended approach. 45 | Once you have installed `Conda <https://docs.conda.io>`_ run 46 | the following commands to install ``pyshepseg`` into a new environment: 47 | 48 | :: 49 | 50 | conda config --add channels conda-forge 51 | conda config --set channel_priority strict 52 | conda create -n mysegenv pyshepseg 53 | conda activate mysegenv 54 | 55 | Alternatively, this package can be installed directly from the source, using the 56 | setup.py script (see required dependencies above). 57 | 58 | 1. The source code is available from `<https://github.com/ubarsc/pyshepseg>`_. 59 | Either unpack the latest release bundle from 60 | `<https://github.com/ubarsc/pyshepseg/releases>`_, or clone the 61 | repository. 62 | 2. Run the setup.py script. This is best done by using pip as a wrapper 63 | around it, as follows. Note that pip has a ``--prefix`` option to allow 64 | installation in non-standard locations. 65 | 66 | :: 67 | 68 | pip install . 69 | 70 | 71 | Usage 72 | ----- 73 | 74 | :: 75 | 76 | from pyshepseg import shepseg 77 | 78 | # Read in a multi-band image as a single array, img, 79 | # of shape (nBands, nRows, nCols). 80 | # Ensure that any null pixels are all set to a known 81 | # null value in all bands. Failure to correctly identify 82 | # null pixels can result in a poorer quality segmentation. 83 | 84 | segRes = shepseg.doShepherdSegmentation(img, imgNullVal=nullVal) 85 | 86 | 87 | The segimg attribute of the segRes object is an array 88 | of segment ID numbers, of shape (nRows, nCols). 89 | 90 | See the help in the ``pyshepseg.shepseg`` module and :func:`pyshepseg.shepseg.doShepherdSegmentation` 91 | function for further details and tips. 92 | 93 | Large Rasters 94 | ------------- 95 | The basic usage outlined above operates entirely in-memory. For 96 | very large rasters, this can be infeasible. A tiled implementation 97 | is provided in the ``pyshepseg.tiling`` module, which divides a large 98 | raster into fixed-size tiles, segments each tile in-memory, and 99 | stitches the results together to create a single segment image. The 100 | tiles are stitched such that segments are matched and merged across 101 | tile boundaries, so the result is seamless. 102 | 103 | This technique should be used with caution. See the docstring for 104 | the ``pyshepseg.tiling`` module and the :func:`pyshepseg.tiling.doTiledShepherdSegmentation` 105 | function for further discussion of usage and caveats. 106 | 107 | Once a segmentation has been completed, statistics can be gathered per segment on 
108 | large rasters using the functions defined in the ``pyshepseg.tilingstats`` 109 | module. 
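A minimal sketch of tiled usage follows. This is a sketch only: the tile size
and overlap shown are the defaults used by the command line wrappers, the
function accepts further options, and its exact signature should be confirmed
from its help before use::

    from pyshepseg import tiling

    # Segment a large raster from disk, tile by tile, writing the
    # stitched segment image to outfile.
    tiling.doTiledShepherdSegmentation(infile, outfile,
        tileSize=4096, overlapSize=1024)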
110 | 111 | Command Line Scripts 112 | -------------------- 113 | A few basic command line scripts are also provided as entry points. 114 | Their main purpose is as test scripts during development, but they also serve 115 | as examples of how to write scripts which use the package. In addition, 116 | they can also be used directly for simple segmentation tasks. 117 | 118 | The ``pyshepseg_run_seg`` entry point is a wrapper around the basic in-memory usage. 119 | 120 | The ``pyshepseg_tiling`` entry point is a wrapper around the tiled 121 | segmentation for large rasters. 122 | 123 | The ``pyshepseg_subset`` entry point uses the :func:`pyshepseg.subset.subsetImage` 124 | function to subset a segmentation image, re-labelling the segments 125 | to contiguous segment ID numbers. 126 | 127 | The ``pyshepseg_variograms`` entry point uses the 128 | :func:`pyshepseg.tilingstats.calcPerSegmentSpatialStatsTiled` function to calculate the 129 | given number of variograms. 130 | 131 | The ``pyshepseg_runtests`` entry point runs some tests on package data and 132 | can be used to confirm that the behaviour of this package is as expected. 133 | 134 | Use the ``--help`` option on each script for usage details. 135 | 136 | Per-segment Statistics 137 | ---------------------- 138 | It can be useful to calculate statistics of the pixels from 139 | the original input imagery on a per-segment basis. For example, for 140 | all the pixels in a single segment, one might calculate the mean value 141 | of one or more of the bands from the original imagery. 142 | 143 | A routine is provided to do this in a memory-efficient way, given the 144 | original image and the completed segmentation image. A standard set of 145 | statistics are available, including mean, standard deviation, and 146 | arbitrary percentile values, amongst others. The selected per-segment 147 | statistics are written to the segment image file as columns of a raster 148 | attribute table (RAT). 149 | 150 | For details, see the help on the :func:`pyshepseg.tilingstats.calcPerSegmentStatsTiled` 151 | and :func:`pyshepseg.tilingstats.calcPerSegmentSpatialStatsTiled` functions. A brief 
sketch is given below, following the Subsetting section. 152 | 153 | Segment Colour Tables 154 | --------------------- 155 | The segment image contains a large number of individual segment values, and 156 | can be difficult to view in simple greyscale colouring. To improve this, two 157 | routines are provided in the ``pyshepseg.utils`` module which will attach a colour table. 158 | 159 | The simplest routine is :func:`pyshepseg.utils.writeRandomColourTable`, which attaches a 160 | randomly-generated colour table, so that each segment is assigned a randomly 161 | chosen colour, which merely serves to distinguish it from the surrounding segments. 162 | See its help for details. 163 | 164 | More sophisticated and more useful is the function :func:`pyshepseg.utils.writeColorTableFromRatColumns`, 165 | which uses previously calculated columns in the raster attribute table (RAT) to 166 | create a colour table which approximates the original imagery. See its help for 167 | details, and the preceding section on how to create suitable RAT columns. 168 | 169 | Subsetting 170 | ---------- 171 | For large segmentations it is sometimes necessary to subset the result into a smaller 172 | image that is easier to work with, while keeping contiguous segment IDs and a link back to the 173 | original segments. For doing this, see the ``pyshepseg.subset`` module and the 174 | :func:`pyshepseg.subset.subsetImage` function. 
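The sketch below ties together the per-segment statistics and colour table
functions described above. It is indicative only: the column names are
invented for the example, and the exact form of the statistics selection
list is documented in the help for
:func:`pyshepseg.tilingstats.calcPerSegmentStatsTiled`::

    from pyshepseg import tilingstats, utils

    # Write the per-segment mean of bands 1-3 of the original imagery
    # into RAT columns on the segmentation image file.
    for band in (1, 2, 3):
        tilingstats.calcPerSegmentStatsTiled(imgfile, band, segfile,
            [('Band{}_mean'.format(band), 'mean')])

    # Colour the segments to approximate the original imagery, using
    # the three mean columns just created.
    utils.writeColorTableFromRatColumns(segfile, 'Band1_mean',
        'Band2_mean', 'Band3_mean')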
175 | 176 | 177 | Modules in this Package 178 | ======================= 179 | 180 | .. toctree:: 181 | :maxdepth: 1 182 | 183 | pyshepseg_shepseg 184 | pyshepseg_tiling 185 | pyshepseg_tilingstats 186 | pyshepseg_utils 187 | pyshepseg_subset 188 | 189 | 190 | Indices and tables 191 | ================== 192 | 193 | * :ref:`genindex` 194 | * :ref:`modindex` 195 | * :ref:`search` 196 | -------------------------------------------------------------------------------- /doc/source/pyshepseg_shepseg.rst: -------------------------------------------------------------------------------- 1 | :tocdepth: 2 2 | 3 | shepseg 4 | ======= 5 | .. automodule:: pyshepseg.shepseg 6 | :members: 7 | :undoc-members: 8 | 9 | * :ref:`genindex` 10 | * :ref:`modindex` 11 | * :ref:`search` 12 | -------------------------------------------------------------------------------- /doc/source/pyshepseg_subset.rst: -------------------------------------------------------------------------------- 1 | :tocdepth: 2 2 | 3 | subset 4 | ====== 5 | .. automodule:: pyshepseg.subset 6 | :members: 7 | :undoc-members: 8 | 9 | * :ref:`genindex` 10 | * :ref:`modindex` 11 | * :ref:`search` 12 | -------------------------------------------------------------------------------- /doc/source/pyshepseg_tiling.rst: -------------------------------------------------------------------------------- 1 | :tocdepth: 2 2 | 3 | tiling 4 | ======= 5 | .. automodule:: pyshepseg.tiling 6 | :members: 7 | :undoc-members: 8 | 9 | * :ref:`genindex` 10 | * :ref:`modindex` 11 | * :ref:`search` 12 | -------------------------------------------------------------------------------- /doc/source/pyshepseg_tilingstats.rst: -------------------------------------------------------------------------------- 1 | :tocdepth: 2 2 | 3 | tilingstats 4 | =========== 5 | .. automodule:: pyshepseg.tilingstats 6 | :members: 7 | :undoc-members: 8 | 9 | * :ref:`genindex` 10 | * :ref:`modindex` 11 | * :ref:`search` 12 | -------------------------------------------------------------------------------- /doc/source/pyshepseg_utils.rst: -------------------------------------------------------------------------------- 1 | :tocdepth: 2 2 | 3 | utils 4 | ======= 5 | .. automodule:: pyshepseg.utils 6 | :members: 7 | :undoc-members: 8 | 9 | * :ref:`genindex` 10 | * :ref:`modindex` 11 | * :ref:`search` 12 | -------------------------------------------------------------------------------- /parallel_examples/README.md: -------------------------------------------------------------------------------- 1 | # Parallel Examples 2 | 3 | Under this directory are examples of running the tiled segmentation in 4 | parallel. 
5 | -------------------------------------------------------------------------------- /parallel_examples/awsbatch/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM public.ecr.aws/ubuntu/ubuntu:22.04 2 | 3 | # Needed in case tzdata gets upgraded 4 | ENV TZ=Australia/Brisbane 5 | ARG DEBIAN_FRONTEND=noninteractive 6 | 7 | ARG AWS_REGION=us-west-2 8 | ENV AWS_REGION_ENV=$AWS_REGION 9 | 10 | ARG PYSHEPSEG_VER 11 | 12 | # use local mirror 13 | RUN sed -i "s/http:\/\/ports./http:\/\/${AWS_REGION_ENV}.ec2.ports./g" /etc/apt/sources.list 14 | RUN sed -i "s/http:\/\/archive./http:\/\/${AWS_REGION_ENV}.ec2.archive./g" /etc/apt/sources.list 15 | 16 | # Update Ubuntu software stack and install base GDAL stack 17 | RUN apt-get update 18 | RUN apt-get upgrade -y 19 | RUN apt-get install -y python3-gdal python3-boto3 python3-sklearn \ 20 | python3-numba wget g++ cmake libhdf5-dev libgdal-dev python3-pip 21 | 22 | ENV SW_VOLUME=/ubarscsw 23 | RUN mkdir $SW_VOLUME 24 | 25 | ENV SERVICEUSER=ubarscuser 26 | RUN useradd --create-home --shell /bin/bash ${SERVICEUSER} 27 | 28 | ENV KEALIB_VERSION=1.5.3 29 | RUN cd /tmp \ 30 | && wget -q https://github.com/ubarsc/kealib/releases/download/kealib-${KEALIB_VERSION}/kealib-${KEALIB_VERSION}.tar.gz \ 31 | && tar xf kealib-${KEALIB_VERSION}.tar.gz \ 32 | && cd kealib-${KEALIB_VERSION} \ 33 | && mkdir build \ 34 | && cd build \ 35 | && cmake -D CMAKE_INSTALL_PREFIX=${SW_VOLUME} -D LIBKEA_WITH_GDAL=ON .. \ 36 | && make \ 37 | && make install \ 38 | && cd ../.. \ 39 | && rm -rf kealib-${KEALIB_VERSION} kealib-${KEALIB_VERSION}.tar.gz 40 | 41 | ENV RIOS_VERSION=2.0.3 42 | RUN cd /tmp \ 43 | && wget -q https://github.com/ubarsc/rios/releases/download/rios-${RIOS_VERSION}/rios-${RIOS_VERSION}.tar.gz \ 44 | && tar xf rios-${RIOS_VERSION}.tar.gz \ 45 | && cd rios-${RIOS_VERSION} \ 46 | && DEB_PYTHON_INSTALL_LAYOUT=deb_system pip install . \ 47 | && cd .. \ 48 | && rm -rf rios-${RIOS_VERSION} rios-${RIOS_VERSION}.tar.gz 49 | 50 | COPY pyshepseg-$PYSHEPSEG_VER.tar.gz /tmp 51 | # install pyshegseg 52 | RUN cd /tmp && tar xf pyshepseg-$PYSHEPSEG_VER.tar.gz \ 53 | && cd pyshepseg-$PYSHEPSEG_VER \ 54 | && DEB_PYTHON_INSTALL_LAYOUT=deb_system pip install . \ 55 | && cd .. 
&& rm -rf pyshepseg-$PYSHEPSEG_VER pyshepseg-$PYSHEPSEG_VER.tar.gz 56 | 57 | 58 | ENV LD_LIBRARY_PATH=${SW_VOLUME}/lib 59 | ENV GDAL_DRIVER_PATH=${SW_VOLUME}/lib/gdalplugins 60 | 61 | ENV PYTHONUNBUFFERED=1 62 | 63 | ENV GDAL_PAM_ENABLED=NO 64 | ENV GDAL_CACHEMAX=1024000000 65 | ENV GDAL_DISABLE_READDIR_ON_OPEN=EMPTY_DIR 66 | ENV GDAL_HTTP_MERGE_CONSECUTIVE_RANGES=YES 67 | ENV GDAL_HTTP_MULTIPLEX=YES 68 | ENV CPL_VSIL_CURL_ALLOWED_EXTENSIONS=".tif,.TIF,.tiff,.vrt,.zip" 69 | ENV VSI_CACHE=True 70 | ENV VSI_CACHE_SIZE=1024000000 71 | ENV GDAL_HTTP_MAX_RETRY=10 72 | ENV GDAL_HTTP_RETRY_DELAY=3 73 | ENV CPL_ZIP_ENCODING=UTF-8 74 | 75 | COPY do_prepare.py $SW_VOLUME/bin 76 | COPY do_tile.py $SW_VOLUME/bin 77 | COPY do_stitch.py $SW_VOLUME/bin 78 | 79 | RUN apt-get remove -y wget g++ cmake 80 | RUN apt-get autoremove -y && apt-get clean && rm -rf /var/lib/apt/lists/* 81 | 82 | USER $SERVICEUSER 83 | 84 | # a few quick tests 85 | RUN python3 -c 'from osgeo import gdal;assert(gdal.GetDriverByName("KEA") is not None)' 86 | RUN python3 -c 'from pyshepseg import tiling' 87 | RUN python3 -c 'from rios import applier' 88 | 89 | # export the volume 90 | VOLUME $SW_VOLUME 91 | 92 | # set the workdir to the home directory for our user (not sure if right thing to do) 93 | WORKDIR /home/${SERVICEUSER} 94 | -------------------------------------------------------------------------------- /parallel_examples/awsbatch/Makefile: -------------------------------------------------------------------------------- 1 | # A makefile to create and push a docker image with RIOS and any other required packages 2 | # to ECR. 3 | # set the AWS_REGION environment variable to the name of the AWS region you wish to use 4 | 5 | ifndef AWS_REGION 6 | $(error AWS_REGION is not set) 7 | endif 8 | 9 | ACCOUNT_ID := $(shell aws sts get-caller-identity --query "Account" --output text) 10 | PYSHEPSEG_VER := $(shell python3 -c 'import pyshepseg;print(pyshepseg.__version__)') 11 | ECR_URL=${ACCOUNT_ID}.dkr.ecr.${AWS_REGION}.amazonaws.com 12 | DOCKER_TAG=pyshepseg 13 | 14 | REPO=${ECR_URL}/${DOCKER_TAG}:latest 15 | 16 | default: all 17 | 18 | # grab the current pyshepseg source tree and make it available to the 19 | # docker COPY command 20 | dist: 21 | cd ../../;python3 -m build . 22 | cp ../../dist/pyshepseg-$(PYSHEPSEG_VER).tar.gz . 23 | 24 | # Login to ECR, build package and push to ECR 25 | all: dist 26 | aws ecr get-login-password --region ${AWS_REGION} | docker login --username AWS --password-stdin $(ECR_URL) 27 | docker build --build-arg AWS_REGION=${AWS_REGION} --build-arg PYSHEPSEG_VER=$(PYSHEPSEG_VER) -t $(DOCKER_TAG) . 28 | docker tag $(DOCKER_TAG) $(REPO) 29 | docker push $(REPO) 30 | -------------------------------------------------------------------------------- /parallel_examples/awsbatch/README.md: -------------------------------------------------------------------------------- 1 | # AWS Batch 2 | 3 | The files in this folder contain a working demonstration of how 4 | to run the tiled segmentation in parallel on AWS Batch. 5 | 6 | ## Contents 7 | 8 | `submit-pyshepseg-job.py` 9 | 10 | Can be run to submit a tiled segmentation to AWS Batch. The CloudFormation 11 | stack and Docker image (see below) must have been created beforehand. 12 | 13 | See the output of `submit-pyshepseg-job.py --help` for more information. 
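As an illustration, a minimal submission might look something like the
following (the bucket and file names are placeholders, and the input must be
in a format GDAL can read through its /vsis3 filesystem, e.g. GTiff):

```
python3 submit-pyshepseg-job.py --bucket mybucket --infile in.tif \
    --outfile out.kea --region us-west-2
```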
14 | 15 | ### CloudFormation 16 | 17 | `template/template.yaml` 18 | 19 | Contains the CloudFormation template to create the AWS Batch environment. 20 | 21 | `create-stack.sh` 22 | 23 | Invokes CloudFormation with `template/template.yaml` to create the AWS Batch 24 | environment as a CloudFormation stack called `pyshepseg-parallel`. 25 | 26 | `delete-stack.sh` 27 | 28 | Deletes the AWS Batch environment. 29 | 30 | `modify-stack.sh` 31 | 32 | Attempts to modify the AWS Batch environment by applying any changes to `template/template.yaml`. 33 | 34 | ### Docker 35 | 36 | AWS Batch requires a Docker image to be pushed to AWS ECR. 37 | 38 | `Dockerfile` 39 | 40 | Contains instructions for creating a Docker Image with the required software 41 | to perform the tiled segmentation. 42 | 43 | `Makefile` 44 | 45 | Builds the Docker Image from Dockerfile and pushes the Docker Image to a 46 | repository on AWS ECR called "pyshepseg". Note this 47 | repository is NOT created by the CloudFormation script above. 48 | 49 | ### Supporting Scripts 50 | 51 | These are copied into the Docker Image by the `Dockerfile`. 52 | 53 | `do_prepare.py` 54 | 55 | Runs `tiling.doTiledShepherdSegmentation_prepare()` and saves the resulting data as a pickle file 56 | in the specified S3 bucket, to be picked up by the following steps. 57 | 58 | This is the first time we know how many tiles there are, so this script also kicks off 59 | the appropriate number of array jobs (each running `do_tile.py` - see below). It also submits a final 60 | job (dependent on all the `do_tile.py` jobs completing) that runs `do_stitch.py` (see below). 61 | 62 | `do_tile.py` 63 | 64 | Runs `tiling.doTiledShepherdSegmentation_doOne()`. It loads the required data from the saved pickle, 65 | processes one tile, and saves the result to the specified S3 bucket to be picked up by `do_stitch.py`. 66 | Which tile is being processed is specified by the `AWS_BATCH_JOB_ARRAY_INDEX` environment variable - 67 | refer to the AWS Batch documentation on array jobs for more information. 68 | 69 | `do_stitch.py` 70 | 71 | Runs `tiling.doTiledShepherdSegmentation_finalize()`. It loads the required data from the saved pickle 72 | and determines the names of the individual tiles that have been processed by `do_tile.py`, then 73 | generates the output (stitched) file and copies it to S3. It also deletes all temporary files from S3. 74 | 75 | 76 | -------------------------------------------------------------------------------- /parallel_examples/awsbatch/create-stack.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ -z "${AWS_REGION}" ]]; then 4 | echo "Must set AWS_REGION first" 5 | exit 1 6 | fi 7 | 8 | aws cloudformation create-stack --stack-name pyshepseg-parallel \ 9 | --template-body file://template/template.yaml \ 10 | --capabilities CAPABILITY_NAMED_IAM --region $AWS_REGION \ 11 | --tags Key=PyShepSeg,Value=1 12 | -------------------------------------------------------------------------------- /parallel_examples/awsbatch/delete-stack.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ -z "${AWS_REGION}" ]]; then 4 | echo "Must set AWS_REGION first" 5 | exit 1 6 | fi 7 | 8 | aws cloudformation delete-stack --stack-name pyshepseg-parallel --region $AWS_REGION 9 | echo 'Stack Deletion in progress... 
Wait a few minutes' 10 | -------------------------------------------------------------------------------- /parallel_examples/awsbatch/do_prepare.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | This is the first script that runs for a job submitted by submit-pyshepseg-job.py. 5 | 6 | Runs tiling.doTiledShepherdSegmentation_prepare() then submits the 7 | other jobs required to do the tiled segmentation. 8 | 9 | 10 | """ 11 | 12 | import io 13 | import pickle 14 | import resource 15 | import argparse 16 | import boto3 17 | from pyshepseg import tiling 18 | 19 | 20 | def getCmdargs(): 21 | """ 22 | Process the command line arguments. 23 | """ 24 | p = argparse.ArgumentParser() 25 | p.add_argument("--bucket", required=True, 26 | help="S3 Bucket to use") 27 | p.add_argument("--infile", required=True, 28 | help="Path in --bucket to use as input file") 29 | p.add_argument("--outfile", required=True, 30 | help="Path in --bucket to use as output file (.kea)") 31 | p.add_argument("-b", "--bands", 32 | help="Comma separated list of bands to use. 1-based. Uses all bands by default.") 33 | p.add_argument("--tilesize", required=True, type=int, 34 | help="Tile Size to use.") 35 | p.add_argument("--overlapsize", required=True, type=int, 36 | help="Tile Overlap to use.") 37 | p.add_argument("--tileprefix", default='tile', 38 | help="Unique prefix to save the output tiles with. (default=%(default)s)") 39 | p.add_argument("--pickle", required=True, 40 | help="name of pickle to save result of preparation into") 41 | p.add_argument("--region", default="us-west-2", 42 | help="Region to run the jobs in. (default=%(default)s)") 43 | p.add_argument("--jobqueue", default="PyShepSegBatchProcessingJobQueue", 44 | help="Name of Job Queue to use. (default=%(default)s)") 45 | p.add_argument("--jobdefntile", default="PyShepSegBatchJobDefinitionTile", 46 | help="Name of Job Definition to use for tile jobs. (default=%(default)s)") 47 | p.add_argument("--jobdefnstitch", default="PyShepSegBatchJobDefinitionStitch", 48 | help="Name of Job Definition to use for the stitch job. (default=%(default)s)") 49 | p.add_argument("--stats", help="path to json file specifying stats in format:" + 50 | "bucket:path/in/bucket.json. Contents must be a list of [img, band, " + 51 | "statsSelection] tuples.") 52 | p.add_argument("--spatialstats", help="path to json file specifying spatial " + 53 | "stats in format: bucket:path/in/bucket.json. Contents must be a list of " + 54 | "[img, band, [list of (colName, colType) tuples], name-of-userfunc, param]" + 55 | " tuples.") 56 | p.add_argument("--nogdalstats", action="store_true", default=False, 57 | help="don't calculate GDAL's statistics or write a colour table. " + 58 | "Can't be used with --stats or --spatialstats.") 59 | p.add_argument("--minSegmentSize", type=int, default=50, required=False, 60 | help="Minimum segment size for segmentation (default=%(default)s)") 61 | p.add_argument("--numClusters", type=int, default=60, required=False, 62 | help="Number of clusters for segmentation (default=%(default)s)") 63 | p.add_argument("--maxSpectDiff", required=False, default='auto', 64 | help="Maximum spectral difference for segmentation (default=%(default)s)") 65 | p.add_argument("--spectDistPcntile", type=int, default=50, required=False, 66 | help="Spectral Distance Percentile for segmentation (default=%(default)s)") 67 | p.add_argument("--noremove", action="store_true", default=False, 68 | help="don't remove files from S3 (for debugging)") 69 | p.add_argument("--statsreadworkers", type=int, default=0, 70 | help="Number of RIOS readworkers to use while calculating stats. " + 71 | "(default=%(default)s)") 72 | p.add_argument("--readworkerstimeouts", type=int, 73 | help="If statsreadworkers specified, this value is used for readBufferPopTimeout, " + 74 | "readBufferInsertTimeout, computeBufferInsertTimeout, computeBufferPopTimeout " + 75 | "in the RIOS ConcurrencyStyle object") 76 | p.add_argument("--kmeans", 77 | help="If specified, this should be a path to a pickled kmeans object to do " + 78 | "the segmentation with (in format:bucket:path/in/bucket.pkl). " + 79 | "--numClusters will be ignored in this case. ") 80 | 81 | cmdargs = p.parse_args() 82 | if cmdargs.bands is not None: 83 | # turn string of bands into list of ints 84 | cmdargs.bands = [int(x) for x in cmdargs.bands.split(',')] 85 | 86 | return cmdargs 87 | 88 | 89 | def main(): 90 | """ 91 | Main routine 92 | """ 93 | cmdargs = getCmdargs() 94 | 95 | # connect to Batch for submitting other jobs 96 | batch = boto3.client('batch', region_name=cmdargs.region) 97 | # connect to S3 for saving the pickled data file 98 | s3 = boto3.client('s3') 99 | 100 | # work out the path that will work for GDAL. 101 | # Note: input file is assumed to be a format that works with /vsi filesystems 102 | # ie: GTiff. 103 | inPath = '/vsis3/' + cmdargs.bucket + '/' + cmdargs.infile 104 | 105 | # did they supply a kmeans path? If so, it should be a pickled (scikit-learn) KMeans object in S3, which is downloaded and un-pickled here so that every tile is segmented with the same clustering. 106 | kmeansObj = None 107 | if cmdargs.kmeans is not None: 108 | bucket, kmeansKey = cmdargs.kmeans.split(':') 109 | with io.BytesIO() as fileobj: 110 | s3.download_fileobj(bucket, kmeansKey, fileobj) 111 | fileobj.seek(0) 112 | kmeansObj = pickle.load(fileobj) 113 | 114 | # run the initial part of the tiled segmentation 115 | inDs, bandNumbers, kmeansObj, subsamplePcnt, imgNullVal, tileInfo = ( 116 | tiling.doTiledShepherdSegmentation_prepare(inPath, 117 | bandNumbers=cmdargs.bands, tileSize=cmdargs.tilesize, 118 | overlapSize=cmdargs.overlapsize, 119 | numClusters=cmdargs.numClusters, 120 | kmeansObj=kmeansObj)) 121 | 122 | # pickle the required input data that each of the tiles will need 123 | colRowList = sorted(tileInfo.tiles.keys(), key=lambda x: (x[1], x[0])) 124 | dataToPickle = {'tileInfo': tileInfo, 'colRowList': colRowList, 125 | 'bandNumbers': bandNumbers, 'imgNullVal': imgNullVal, 126 | 'kmeansObj': kmeansObj} 127 | # pickle and upload to S3 128 | with io.BytesIO() as fileobj: 129 | pickle.dump(dataToPickle, fileobj) 130 | fileobj.seek(0) 131 | s3.upload_fileobj(fileobj, cmdargs.bucket, cmdargs.pickle) 132 | 133 | # now submit an array job with all the tiles 134 | # (can't do this before now because we don't know how many tiles) 135 | containerOverrides = { 136 | "command": ['/usr/bin/python3', '/ubarscsw/bin/do_tile.py', 137 | '--bucket', cmdargs.bucket, '--pickle', cmdargs.pickle, 138 | '--infile', cmdargs.infile, '--tileprefix', cmdargs.tileprefix, 139 | '--minSegmentSize', str(cmdargs.minSegmentSize), 140 | '--maxSpectDiff', cmdargs.maxSpectDiff, 141 | '--spectDistPcntile', str(cmdargs.spectDistPcntile)]} 142 | 143 | arrayProperties = {} 144 | if len(colRowList) > 1: 145 | # AWS Batch throws an error if the array size is 1... 146 | arrayProperties['size'] = len(colRowList) 147 | else: 148 | # must fake AWS_BATCH_JOB_ARRAY_INDEX 149 | # can't set this as an env var as Batch overrides it 150 | containerOverrides['command'].extend(['--arrayindex', '0']) 151 | 152 | response = batch.submit_job(jobName="pyshepseg_tiles", 153 | jobQueue=cmdargs.jobqueue, 154 | jobDefinition=cmdargs.jobdefntile, 155 | arrayProperties=arrayProperties, 156 | containerOverrides=containerOverrides) 157 | tilesJobId = response['jobId'] 158 | print('Tiles Job Id', tilesJobId) 159 | 160 | # now submit a dependent job with the stitching 161 | # this one only runs when the array jobs are all done 162 | cmd = ['/usr/bin/python3', '/ubarscsw/bin/do_stitch.py', 163 | '--bucket', cmdargs.bucket, '--outfile', cmdargs.outfile, 164 | '--tileprefix', cmdargs.tileprefix, 165 | '--infile', cmdargs.infile, '--pickle', cmdargs.pickle, 166 | '--overlapsize', str(cmdargs.overlapsize), 167 | '--statsreadworkers', str(cmdargs.statsreadworkers)] 168 | if cmdargs.stats is not None: 169 | cmd.extend(['--stats', cmdargs.stats]) 170 | if cmdargs.spatialstats is not None: 171 | cmd.extend(['--spatialstats', cmdargs.spatialstats]) 172 | if cmdargs.nogdalstats: 173 | cmd.append('--nogdalstats') 174 | if cmdargs.noremove: 175 | cmd.append('--noremove') 176 | if cmdargs.readworkerstimeouts is not None: 177 | cmd.extend(['--readworkerstimeouts', str(cmdargs.readworkerstimeouts)]) 178 | 179 | response = batch.submit_job(jobName="pyshepseg_stitch", 180 | jobQueue=cmdargs.jobqueue, 181 | jobDefinition=cmdargs.jobdefnstitch, 182 | dependsOn=[{'jobId': tilesJobId}], 183 | containerOverrides={ 184 | "command": cmd}) 185 | print('Stitching Job Id', response['jobId']) 186 | maxMem = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss 187 | print('Max Mem Usage', maxMem) 188 | 189 | 190 | if __name__ == '__main__': 191 | main() 192 | -------------------------------------------------------------------------------- /parallel_examples/awsbatch/do_stitch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Script that stitches all the tiles together by calling 5 | tiling.doTiledShepherdSegmentation_finalize(). 6 | 7 | Uploads the resulting segmentation to S3. 8 | """ 9 | 10 | import io 11 | import os 12 | import json 13 | import pickle 14 | import resource 15 | import argparse 16 | import tempfile 17 | import shutil 18 | import importlib 19 | import boto3 20 | from pyshepseg import tiling, tilingstats, utils 21 | from osgeo import gdal 22 | from rios import applier 23 | 24 | gdal.UseExceptions() 25 | 26 | 27 | def getCmdargs(): 28 | """ 29 | Process the command line arguments. 30 | """ 31 | p = argparse.ArgumentParser() 32 | p.add_argument("--bucket", required=True, 33 | help="S3 Bucket to use") 34 | p.add_argument("--infile", required=True, 35 | help="Path in --bucket to use as input file") 36 | p.add_argument("--outfile", required=True, 37 | help="Path in --bucket to use as output file (.kea)") 38 | p.add_argument("--tileprefix", required=True, 39 | help="Unique prefix to save the output tiles with.") 40 | p.add_argument("--pickle", required=True, 41 | help="name of pickle with the result of the preparation") 42 | p.add_argument("--overlapsize", required=True, type=int, 43 | help="Tile Overlap to use.") 44 | p.add_argument("--stats", help="path to json file specifying stats in format:" + 45 | "bucket:path/in/bucket.json. Contents must be a list of [img, band, " + 46 | "statsSelection] tuples.") 47 | p.add_argument("--spatialstats", help="path to json file specifying spatial " + 48 | "stats in format: bucket:path/in/bucket.json. Contents must be a list of " + 49 | "[img, band, [list of (colName, colType) tuples], name-of-userfunc, param]" + 50 | " tuples.") 51 | p.add_argument("--nogdalstats", action="store_true", default=False, 52 | help="don't calculate GDAL's statistics or write a colour table. " + 53 | "Can't be used with --stats.") 54 | p.add_argument("--noremove", action="store_true", default=False, 55 | help="don't remove files from S3 (for debugging)") 56 | p.add_argument("--statsreadworkers", type=int, default=0, 57 | help="Number of RIOS readworkers to use while calculating stats. " + 58 | "(default=%(default)s)") 59 | p.add_argument("--readworkerstimeouts", type=int, 60 | help="If statsreadworkers specified, this value is used for readBufferPopTimeout, " + 61 | "readBufferInsertTimeout, computeBufferInsertTimeout, computeBufferPopTimeout " + 62 | "in the RIOS ConcurrencyStyle object") 63 | 64 | cmdargs = p.parse_args() 65 | 66 | return cmdargs 67 | 68 | 69 | def main(): 70 | """ 71 | Main routine 72 | """ 73 | cmdargs = getCmdargs() 74 | 75 | # download the pickled data and unpickle. 76 | s3 = boto3.client('s3') 77 | with io.BytesIO() as fileobj: 78 | s3.download_fileobj(cmdargs.bucket, cmdargs.pickle, fileobj) 79 | fileobj.seek(0) 80 | 81 | dataFromPickle = pickle.load(fileobj) 82 | 83 | # work out GDAL path to input file and open it 84 | inPath = '/vsis3/' + cmdargs.bucket + '/' + cmdargs.infile 85 | inDs = gdal.Open(inPath) 86 | 87 | tempDir = tempfile.mkdtemp() 88 | 89 | # work out what the tiles would have been named 90 | # Note: this needs to match do_tile.py. 
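# (do_tile.py uploads each tile to the top level of the bucket, named
# <tileprefix>_<col>_<row>.tif, so the same pattern is rebuilt here.)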
91 | tileFilenames = {} 92 | for col, row in dataFromPickle['colRowList']: 93 | filename = '/vsis3/' + cmdargs.bucket + '/' + '{}_{}_{}.{}'.format( 94 | cmdargs.tileprefix, col, row, 'tif') 95 | tileFilenames[(col, row)] = filename 96 | 97 | # save the KEA file to the local path first 98 | localOutfile = os.path.join(tempDir, os.path.basename(cmdargs.outfile)) 99 | 100 | # do the stitching. Note maxSegId and hasEmptySegments not used here 101 | # but ideally they would be saved somewhere also. 102 | # Ensure histogram written to local file so we can do the statistics 103 | (maxSegId, hasEmptySegments, localDs) = tiling.doTiledShepherdSegmentation_finalize( 104 | inDs, localOutfile, tileFilenames, dataFromPickle['tileInfo'], 105 | cmdargs.overlapsize, tempDir, writeHistogram=True) 106 | 107 | # clean up files to release space 108 | if not cmdargs.noremove: 109 | objs = [] 110 | for col, row in tileFilenames: 111 | filename = '{}_{}_{}.{}'.format(cmdargs.tileprefix, col, row, 'tif') 112 | objs.append({'Key': filename}) 113 | 114 | # workaround 1000 at a time limit 115 | while len(objs) > 0: 116 | s3.delete_objects(Bucket=cmdargs.bucket, Delete={'Objects': objs[0:1000]}) 117 | del objs[0:1000] 118 | 119 | if not cmdargs.nogdalstats: 120 | band = localDs.GetRasterBand(1) 121 | # Histogram should be already written by doTiledShepherdSegmentation_finalize 122 | # above 123 | rat = band.GetDefaultRAT() 124 | histIdx = rat.GetColOfUsage(gdal.GFU_PixelCount) 125 | hist = rat.ReadAsArray(histIdx) 126 | 127 | utils.estimateStatsFromHisto(band, hist) 128 | utils.writeRandomColourTable(band, maxSegId + 1) 129 | utils.addOverviews(localDs) 130 | 131 | # ensure dataset is closed so we can open it again in RIOS 132 | del localDs 133 | 134 | if cmdargs.readworkerstimeouts is not None: 135 | concurrencyStyle = applier.ConcurrencyStyle( 136 | numReadWorkers=cmdargs.statsreadworkers, 137 | readBufferPopTimeout=cmdargs.readworkerstimeouts, 138 | readBufferInsertTimeout=cmdargs.readworkerstimeouts, 139 | computeBufferInsertTimeout=cmdargs.readworkerstimeouts, 140 | computeBufferPopTimeout=cmdargs.readworkerstimeouts) 141 | else: 142 | concurrencyStyle = applier.ConcurrencyStyle( 143 | numReadWorkers=cmdargs.statsreadworkers) 144 | 145 | # now do any stats the user has asked for 146 | if cmdargs.stats is not None: 147 | 148 | bucket, statsKey = cmdargs.stats.split(':') 149 | with io.BytesIO() as fileobj: 150 | s3.download_fileobj(bucket, statsKey, fileobj) 151 | fileobj.seek(0) 152 | 153 | dataForStats = json.load(fileobj) 154 | for img, bandnum, selection in dataForStats: 155 | print(img, bandnum, selection) 156 | tilingstats.calcPerSegmentStatsRIOS(img, bandnum, 157 | localOutfile, selection, concurrencyStyle) 158 | 159 | if cmdargs.spatialstats is not None: 160 | bucket, spatialstatsKey = cmdargs.spatialstats.split(':') 161 | with io.BytesIO() as fileobj: 162 | s3.download_fileobj(bucket, spatialstatsKey, fileobj) 163 | fileobj.seek(0) 164 | 165 | dataForStats = json.load(fileobj) 166 | for img, bandnum, colInfo, userFuncName, param in dataForStats: 167 | print(img, bandnum, colInfo, userFuncName, param) 168 | userFuncArr = userFuncName.split('.') 169 | if len(userFuncArr) < 2: 170 | raise ValueError("'userFunc' must be a fully qualified function " + 171 | "name. ie. modulename.function_name. " + 172 | "eg. 
pyshepseg.tilingstats.userFuncVariogram") 173 | 174 | moduleName = '.'.join(userFuncArr[:-1]) 175 | funcName = userFuncArr[-1] 176 | mod = importlib.import_module(moduleName) 177 | if not hasattr(mod, funcName): 178 | raise ValueError(f"Cannot find function {funcName} " + 179 | f"in module {moduleName}") 180 | 181 | userFunc = getattr(mod, funcName) 182 | 183 | tilingstats.calcPerSegmentSpatialStatsRIOS(img, bandnum, 184 | localOutfile, colInfo, userFunc, param, concurrencyStyle) 185 | 186 | # upload the KEA file 187 | s3.upload_file(localOutfile, cmdargs.bucket, cmdargs.outfile) 188 | 189 | # cleanup temp files from S3 190 | if not cmdargs.noremove: 191 | objs = [{'Key': cmdargs.pickle}] 192 | if cmdargs.stats is not None: 193 | objs.append({'Key': statsKey}) 194 | if cmdargs.spatialstats is not None: 195 | objs.append({'Key': spatialstatsKey}) 196 | 197 | s3.delete_objects(Bucket=cmdargs.bucket, Delete={'Objects': objs}) 198 | 199 | # cleanup 200 | shutil.rmtree(tempDir) 201 | maxMem = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss 202 | print('Max Mem Usage', maxMem) 203 | 204 | 205 | if __name__ == '__main__': 206 | main() 207 | -------------------------------------------------------------------------------- /parallel_examples/awsbatch/do_tile.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Process an individual tile as part of a tiled segmentation. Indexes 5 | into the pickled colRowList with the AWS_BATCH_JOB_ARRAY_INDEX env var 6 | (set by AWS Batch for array jobs). 7 | 8 | """ 9 | 10 | import io 11 | import os 12 | import pickle 13 | import argparse 14 | import tempfile 15 | import resource 16 | import shutil 17 | import boto3 18 | from pyshepseg import tiling 19 | 20 | from osgeo import gdal 21 | 22 | gdal.UseExceptions() 23 | 24 | 25 | def getCmdargs(): 26 | """ 27 | Process the command line arguments. 
28 | """ 29 | p = argparse.ArgumentParser() 30 | p.add_argument("--bucket", required=True, 31 | help="S3 Bucket to use") 32 | p.add_argument("--infile", required=True, 33 | help="Path in --bucket to use as input file") 34 | p.add_argument("--tileprefix", required=True, 35 | help="Unique prefix to save the output tiles with.") 36 | p.add_argument("--pickle", required=True, 37 | help="name of pickle with the result of the preparation") 38 | p.add_argument("--minSegmentSize", type=int, default=50, required=False, 39 | help="Segment size for segmentation (default=%(default)s)") 40 | p.add_argument("--maxSpectDiff", required=False, default='auto', 41 | help="Maximum spectral difference for segmentation (default=%(default)s)") 42 | p.add_argument("--spectDistPcntile", type=int, default=50, required=False, 43 | help="Spectral Distance Percentile for segmentation (default=%(default)s)") 44 | p.add_argument("--arrayindex", type=int, 45 | help="Override AWS_BATCH_JOB_ARRAY_INDEX env var") 46 | 47 | cmdargs = p.parse_args() 48 | 49 | if cmdargs.arrayindex is None: 50 | arrayindex = os.getenv('AWS_BATCH_JOB_ARRAY_INDEX') 51 | if arrayindex is None: 52 | raise SystemExit('Must set AWS_BATCH_JOB_ARRAY_INDEX env var or ' + 53 | 'specify --arrayindex') 54 | else: 55 | cmdargs.arrayindex = int(arrayindex) 56 | 57 | return cmdargs 58 | 59 | 60 | def main(): 61 | """ 62 | Main routine 63 | """ 64 | cmdargs = getCmdargs() 65 | 66 | # download pickle file and un-pickle it 67 | s3 = boto3.client('s3') 68 | with io.BytesIO() as fileobj: 69 | s3.download_fileobj(cmdargs.bucket, cmdargs.pickle, fileobj) 70 | fileobj.seek(0) 71 | 72 | dataFromPickle = pickle.load(fileobj) 73 | 74 | # work out GDAL path to input file and open it 75 | inPath = '/vsis3/' + cmdargs.bucket + '/' + cmdargs.infile 76 | inDs = gdal.Open(inPath) 77 | 78 | tempDir = tempfile.mkdtemp() 79 | 80 | # work out which tile we are processing 81 | col, row = dataFromPickle['colRowList'][cmdargs.arrayindex] 82 | 83 | # work out a filename to save with the output of this tile 84 | # Note: this filename format is repeated in do_stitch.py 85 | # - they must match. Potentially a database or similar 86 | # could have been used to notify of the names of tiles 87 | # but this would add more complexity. 88 | filename = '{}_{}_{}.{}'.format(cmdargs.tileprefix, 89 | col, row, 'tif') 90 | filename = os.path.join(tempDir, filename) 91 | 92 | # test if int 93 | maxSpectDiff = cmdargs.maxSpectDiff 94 | if maxSpectDiff != 'auto': 95 | maxSpectDiff = int(maxSpectDiff) 96 | 97 | # run the segmentation on this tile. 98 | # save the result as a GTiff so do_stitch.py can open this tile 99 | # directly from S3. 100 | # TODO: create COG instead 101 | tiling.doTiledShepherdSegmentation_doOne(inDs, filename, 102 | dataFromPickle['tileInfo'], col, row, dataFromPickle['bandNumbers'], 103 | dataFromPickle['imgNullVal'], dataFromPickle['kmeansObj'], 104 | minSegmentSize=cmdargs.minSegmentSize, 105 | spectDistPcntile=cmdargs.spectDistPcntile, maxSpectralDiff=maxSpectDiff, 106 | tempfilesDriver='GTiff', tempfilesCreationOptions=['COMPRESS=DEFLATE', 107 | 'ZLEVEL=1', 'PREDICTOR=2', 'TILED=YES', 'INTERLEAVE=BAND', 108 | 'BIGTIFF=NO', 'BLOCKXSIZE=512', 'BLOCKYSIZE=512']) 109 | 110 | # upload the tile to S3. 
111 | s3.upload_file(filename, cmdargs.bucket, os.path.basename(filename)) 112 | 113 | # cleanup 114 | shutil.rmtree(tempDir) 115 | maxMem = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss 116 | print('Max Mem Usage', maxMem) 117 | 118 | 119 | if __name__ == '__main__': 120 | main() 121 | -------------------------------------------------------------------------------- /parallel_examples/awsbatch/modify-stack.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ -z "${AWS_REGION}" ]]; then 4 | echo "Must set AWS_REGION first" 5 | exit 1 6 | fi 7 | 8 | aws cloudformation update-stack --stack-name pyshepseg-parallel \ 9 | --template-body file://template/template.yaml \ 10 | --capabilities CAPABILITY_NAMED_IAM --region $AWS_REGION \ 11 | --tags Key=PyShepSeg,Value=1 12 | -------------------------------------------------------------------------------- /parallel_examples/awsbatch/submit-pyshepseg-job.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Script that starts a tiled segmentation using AWS Batch. 5 | 6 | Submits a job that runs the do_prepare.py 7 | script with the given arguments. 8 | 9 | do_prepare.py then submits an array job to run do_tile.py 10 | - one job per tile. It also submits a do_stitch.py job that 11 | is dependent on all the do_tile.py jobs finishing. 12 | """ 13 | 14 | import argparse 15 | import boto3 16 | 17 | 18 | def getCmdargs(): 19 | p = argparse.ArgumentParser() 20 | p.add_argument("--bucket", required=True, 21 | help="S3 Bucket to use") 22 | p.add_argument("--infile", required=True, 23 | help="Path in --bucket to use as input file") 24 | p.add_argument("--outfile", required=True, 25 | help="Path in --bucket to use as output file (.kea)") 26 | p.add_argument("--tileprefix", 27 | help="Unique prefix to save the output tiles with.") 28 | p.add_argument("-b", "--bands", 29 | help="Comma separated list of bands to use. 1-based. Uses all bands by default.") 30 | p.add_argument("--jobqueue", default="PyShepSegBatchProcessingJobQueue", 31 | help="Name of Job Queue to use. (default=%(default)s)") 32 | p.add_argument("--jobdefnprepare", default="PyShepSegBatchJobDefinitionTile", 33 | help="Name of Job Definition to use for the preparation job. (default=%(default)s)") 34 | p.add_argument("--jobdefntile", default="PyShepSegBatchJobDefinitionTile", 35 | help="Name of Job Definition to use for tile jobs. (default=%(default)s)") 36 | p.add_argument("--jobdefnstitch", default="PyShepSegBatchJobDefinitionStitch", 37 | help="Name of Job Definition to use for the stitch job. (default=%(default)s)") 38 | p.add_argument("--region", default="us-west-2", 39 | help="Region to run the jobs in. (default=%(default)s)") 40 | p.add_argument("--tilesize", default=4096, type=int, 41 | help="Tile Size to use. (default=%(default)s)") 42 | p.add_argument("--overlapsize", default=1024, type=int, 43 | help="Tile Overlap to use. (default=%(default)s)") 44 | p.add_argument("--stats", help="path to json file specifying stats in format: " + 45 | "bucket:path/in/bucket.json. Contents must be a list of [img, band, " + 46 | "statsSelection] tuples.") 47 | p.add_argument("--spatialstats", help="path to json file specifying spatial " + 48 | "stats in format: bucket:path/in/bucket.json. 
Contents must be a list of " + 49 | "[img, band, [list of (colName, colType) tuples], name-of-userfunc, param]" + 50 | " tuples.") 51 | p.add_argument("--nogdalstats", action="store_true", default=False, 52 | help="don't calculate GDAL's statistics or write a colour table. " + 53 | "Can't be used with --stats or --spatialstats.") 54 | p.add_argument("--minSegmentSize", type=int, default=50, required=False, 55 | help="Minimum segment size for segmentation (default=%(default)s)") 56 | p.add_argument("--numClusters", type=int, default=60, required=False, 57 | help="Number of clusters for segmentation (default=%(default)s)") 58 | p.add_argument("--maxSpectDiff", required=False, default='auto', 59 | help="Maximum spectral difference for segmentation (default=%(default)s)") 60 | p.add_argument("--spectDistPcntile", type=int, default=50, required=False, 61 | help="Spectral Distance Percentile for segmentation (default=%(default)s)") 62 | p.add_argument("--noremove", action="store_true", default=False, 63 | help="don't remove files from S3 (for debugging)") 64 | p.add_argument("--statsreadworkers", type=int, default=0, 65 | help="Number of RIOS readworkers to use while calculating stats. " + 66 | "(default=%(default)s)") 67 | 68 | cmdargs = p.parse_args() 69 | 70 | return cmdargs 71 | 72 | 73 | def main(): 74 | cmdargs = getCmdargs() 75 | 76 | batch = boto3.client('batch', region_name=cmdargs.region) 77 | 78 | pickleName = 'pyshepseg_tiling.pkl' 79 | # make the pickle name unique too, if the tiles are being given a unique prefix 80 | if cmdargs.tileprefix is not None: 81 | pickleName = '{}_pyshepseg_tiling.pkl'.format(cmdargs.tileprefix) 82 | 83 | cmd = ['/usr/bin/python3', '/ubarscsw/bin/do_prepare.py', 84 | '--region', cmdargs.region, 85 | '--bucket', cmdargs.bucket, '--pickle', pickleName, 86 | '--infile', cmdargs.infile, '--outfile', cmdargs.outfile, 87 | '--tilesize', str(cmdargs.tilesize), 88 | '--overlapsize', str(cmdargs.overlapsize), 89 | '--jobqueue', cmdargs.jobqueue, 90 | '--jobdefntile', cmdargs.jobdefntile, 91 | '--jobdefnstitch', cmdargs.jobdefnstitch, 92 | '--minSegmentSize', str(cmdargs.minSegmentSize), 93 | '--numClusters', str(cmdargs.numClusters), 94 | '--maxSpectDiff', cmdargs.maxSpectDiff, 95 | '--spectDistPcntile', str(cmdargs.spectDistPcntile), 96 | '--statsreadworkers', str(cmdargs.statsreadworkers)] 97 | if cmdargs.bands is not None: 98 | cmd.extend(['--bands', cmdargs.bands]) 99 | if cmdargs.stats is not None: 100 | cmd.extend(['--stats', cmdargs.stats]) 101 | if cmdargs.spatialstats is not None: 102 | cmd.extend(['--spatialstats', cmdargs.spatialstats]) 103 | if cmdargs.nogdalstats: 104 | cmd.append('--nogdalstats') 105 | if cmdargs.tileprefix is not None: 106 | cmd.extend(['--tileprefix', cmdargs.tileprefix]) 107 | if cmdargs.noremove: 108 | cmd.append('--noremove') 109 | 110 | # submit the prepare job, using the job definition intended for it 111 | response = batch.submit_job(jobName="pyshepseg_prepare", 112 | jobQueue=cmdargs.jobqueue, 113 | jobDefinition=cmdargs.jobdefnprepare, 114 | containerOverrides={ 115 | "command": cmd}) 116 | prepareId = response['jobId'] 117 | print('Prepare Job Id', prepareId) 118 | 119 | 120 | if __name__ == '__main__': 121 | main() 122 | -------------------------------------------------------------------------------- /parallel_examples/awsbatch/template/template.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | AWSTemplateFormatVersion: '2010-09-09' 3 | Description: 'UBARSC AWS Batch Tiled Segmentation using CloudFormation' 4 | Resources: 5 | VPC: 6 | Type: AWS::EC2::VPC 7 | Properties: 8 
| CidrBlock: 10.0.0.0/16 9 | InternetGateway: 10 | Type: AWS::EC2::InternetGateway 11 | RouteTable: 12 | Type: AWS::EC2::RouteTable 13 | Properties: 14 | VpcId: 15 | Ref: VPC 16 | VPCGatewayAttachment: 17 | Type: AWS::EC2::VPCGatewayAttachment 18 | Properties: 19 | VpcId: 20 | Ref: VPC 21 | InternetGatewayId: 22 | Ref: InternetGateway 23 | SecurityGroup: 24 | Type: AWS::EC2::SecurityGroup 25 | Properties: 26 | GroupDescription: EC2 Security Group for instances launched in the VPC by Batch 27 | VpcId: 28 | Ref: VPC 29 | Subnet1: 30 | Type: AWS::EC2::Subnet 31 | Properties: 32 | CidrBlock: 10.0.0.0/24 33 | VpcId: 34 | Ref: VPC 35 | # yes we do need public ips or NAT 36 | # See https://repost.aws/knowledge-center/batch-job-stuck-runnable-status 37 | MapPublicIpOnLaunch: 'True' 38 | AvailabilityZone: !Select 39 | - 0 40 | - Fn::GetAZs: !Ref 'AWS::Region' 41 | Subnet2: 42 | Type: AWS::EC2::Subnet 43 | Properties: 44 | CidrBlock: 10.0.1.0/24 45 | VpcId: 46 | Ref: VPC 47 | MapPublicIpOnLaunch: 'True' 48 | AvailabilityZone: !Select 49 | - 1 50 | - Fn::GetAZs: !Ref 'AWS::Region' 51 | Subnet3: 52 | Type: AWS::EC2::Subnet 53 | Properties: 54 | CidrBlock: 10.0.2.0/24 55 | VpcId: 56 | Ref: VPC 57 | MapPublicIpOnLaunch: 'True' 58 | AvailabilityZone: !Select 59 | - 2 60 | - Fn::GetAZs: !Ref 'AWS::Region' 61 | Route: 62 | Type: AWS::EC2::Route 63 | Properties: 64 | RouteTableId: 65 | Ref: RouteTable 66 | DestinationCidrBlock: 0.0.0.0/0 67 | GatewayId: 68 | Ref: InternetGateway 69 | # Allow S3 traffic to go through an internet gateway 70 | S3GatewayEndpoint: 71 | Type: AWS::EC2::VPCEndpoint 72 | Properties: 73 | VpcEndpointType: 'Gateway' 74 | VpcId: !Ref VPC 75 | ServiceName: !Sub 'com.amazonaws.${AWS::Region}.s3' 76 | PolicyDocument: 77 | Version: 2012-10-17 78 | Statement: 79 | - Effect: Allow 80 | Principal: '*' 81 | Action: 82 | - 's3:*Object' 83 | - 's3:ListBucket' 84 | Resource: 85 | - 'arn:aws:s3:::*/*' 86 | - 'arn:aws:s3:::*' 87 | RouteTableIds: 88 | - !Ref RouteTable 89 | SubnetRouteTableAssociation1: 90 | Type: AWS::EC2::SubnetRouteTableAssociation 91 | Properties: 92 | RouteTableId: 93 | Ref: RouteTable 94 | SubnetId: 95 | Ref: Subnet1 96 | SubnetRouteTableAssociation2: 97 | Type: AWS::EC2::SubnetRouteTableAssociation 98 | Properties: 99 | RouteTableId: 100 | Ref: RouteTable 101 | SubnetId: 102 | Ref: Subnet2 103 | SubnetRouteTableAssociation3: 104 | Type: AWS::EC2::SubnetRouteTableAssociation 105 | Properties: 106 | RouteTableId: 107 | Ref: RouteTable 108 | SubnetId: 109 | Ref: Subnet3 110 | IamInstanceProfile: 111 | Type: AWS::IAM::InstanceProfile 112 | Properties: 113 | Roles: 114 | - Ref: EcsInstanceRole 115 | SubmitJobsManagedPolicy: 116 | Type: AWS::IAM::ManagedPolicy 117 | Properties: 118 | Description: Policy for allowing web site to submit Batch jobs 119 | Path: / 120 | PolicyDocument: 121 | Version: '2012-10-17' 122 | Statement: 123 | - Effect: Allow 124 | Action: 'batch:SubmitJob' 125 | Resource: 126 | # can't use resources directly as causes a circular dependency 127 | # but we know what the names will be 128 | #- !Ref BatchProcessingJobDefinitionTile 129 | #- !Ref BatchProcessingJobDefinitionStitch 130 | #- !Ref BatchProcessingJobQueue 131 | - !Sub 'arn:aws:batch:${AWS::Region}:${AWS::AccountId}:job-definition/PyShepSegBatchJobDefinitionTile:*' 132 | - !Sub 'arn:aws:batch:${AWS::Region}:${AWS::AccountId}:job-definition/PyShepSegBatchJobDefinitionStitch:*' 133 | - !Sub 'arn:aws:batch:${AWS::Region}:${AWS::AccountId}:job-queue/PyShepSegBatchProcessingJobQueue' 134 | EcsInstanceRole: 
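    # Role assumed by the EC2 container instances. The S3 policy below is
    # needed so that the jobs can read the input image and write tiles/results.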
135 | Type: AWS::IAM::Role 136 | Properties: 137 | AssumeRolePolicyDocument: 138 | Version: '2008-10-17' 139 | Statement: 140 | - Sid: '' 141 | Effect: Allow 142 | Principal: 143 | Service: ec2.amazonaws.com 144 | Action: sts:AssumeRole 145 | ManagedPolicyArns: 146 | - arn:aws:iam::aws:policy/service-role/AmazonEC2ContainerServiceforEC2Role 147 | - arn:aws:iam::aws:policy/AmazonS3FullAccess 148 | - arn:aws:iam::aws:policy/service-role/AWSBatchServiceEventTargetRole 149 | - !Ref SubmitJobsManagedPolicy 150 | BatchRepository: 151 | Type: AWS::ECR::Repository 152 | Properties: 153 | RepositoryName: pyshepseg 154 | LifecyclePolicy: 155 | LifecyclePolicyText: | 156 | { 157 | "rules": [ 158 | { 159 | "rulePriority": 1, 160 | "description": "Expire images older than 1 day", 161 | "selection": { 162 | "tagStatus": "untagged", 163 | "countType": "sinceImagePushed", 164 | "countUnit": "days", 165 | "countNumber": 1 166 | }, 167 | "action": { 168 | "type": "expire" 169 | } 170 | } 171 | ] 172 | } 173 | # Job for doing individual tiles 174 | BatchProcessingJobDefinitionTile: 175 | Type: AWS::Batch::JobDefinition 176 | Properties: 177 | Type: container 178 | JobDefinitionName: PyShepSegBatchJobDefinitionTile 179 | ContainerProperties: 180 | Image: !Join ['', [!GetAtt BatchRepository.RepositoryUri, ":latest"]] 181 | Vcpus: 1 182 | Memory: 16000 183 | RetryStrategy: 184 | Attempts: 1 185 | # Job for stitching tiles together 186 | BatchProcessingJobDefinitionStitch: 187 | Type: AWS::Batch::JobDefinition 188 | Properties: 189 | Type: container 190 | JobDefinitionName: PyShepSegBatchJobDefinitionStitch 191 | ContainerProperties: 192 | Image: !Join ['', [!GetAtt BatchRepository.RepositoryUri, ":latest"]] 193 | Vcpus: 4 194 | Memory: 12000 195 | RetryStrategy: 196 | Attempts: 1 197 | BatchProcessingJobQueue: 198 | Type: AWS::Batch::JobQueue 199 | Properties: 200 | JobQueueName: PyShepSegBatchProcessingJobQueue 201 | Priority: 1 202 | ComputeEnvironmentOrder: 203 | - Order: 1 204 | ComputeEnvironment: 205 | Ref: ComputeEnvironment 206 | ComputeEnvironment: 207 | Type: AWS::Batch::ComputeEnvironment 208 | Properties: 209 | Type: MANAGED 210 | ComputeResources: 211 | Type: EC2 212 | MinvCpus: 0 213 | DesiredvCpus: 0 214 | MaxvCpus: 1024 215 | InstanceTypes: 216 | #- a1.medium 217 | - optimal 218 | Subnets: 219 | - Ref: Subnet1 220 | - Ref: Subnet2 221 | - Ref: Subnet3 222 | SecurityGroupIds: 223 | - Ref: SecurityGroup 224 | InstanceRole: 225 | Ref: IamInstanceProfile 226 | LaunchTemplate: 227 | LaunchTemplateId: !Ref LaunchTemplate 228 | Version: !GetAtt LaunchTemplate.LatestVersionNumber 229 | # Launch template - increase default storage available 230 | # https://repost.aws/knowledge-center/batch-job-failure-disk-space 231 | # https://docs.aws.amazon.com/batch/latest/userguide/launch-templates.html 232 | # Probably don't need this for all job types, but that would mean different queues 233 | LaunchTemplate: 234 | Type: AWS::EC2::LaunchTemplate 235 | Properties: 236 | LaunchTemplateData: 237 | BlockDeviceMappings: 238 | - DeviceName: /dev/xvda 239 | Ebs: 240 | VolumeType: gp2 241 | VolumeSize: 1024 242 | DeleteOnTermination: true 243 | 244 | Outputs: 245 | ComputeEnvironmentArn: 246 | Value: 247 | Ref: ComputeEnvironment 248 | BatchProcessingJobQueueArn: 249 | Value: 250 | Ref: BatchProcessingJobQueue 251 | BatchProcessingJobDefinitionTileArn: 252 | Value: 253 | Ref: BatchProcessingJobDefinitionTile 254 | BatchProcessingJobDefinitionStitchArn: 255 | Value: 256 | Ref: BatchProcessingJobDefinitionStitch 257 | 
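# A hypothetical check of these outputs once the stack is up:
#   aws cloudformation describe-stacks --stack-name pyshepseg-parallel \
#       --query 'Stacks[0].Outputs' --region $AWS_REGION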
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # The installation requires pip>=23.0. If not, it will fail in rather 2 | # cryptic ways (depending exactly what options are used). 3 | # 4 | 5 | # We have chosen not to explicitly list the dependency on GDAL. This is 6 | # because GDAL itself cannot be installed with pip, and so must already be 7 | # installed on the system by some other means. 8 | 9 | [build-system] 10 | requires = ["setuptools>=61.0", "wheel"] 11 | build-backend = "setuptools.build_meta" 12 | 13 | [project] 14 | name = "pyshepseg" 15 | dynamic = ["version"] 16 | authors = [ 17 | {name = "Sam Gillingham"}, 18 | {name = "Neil Flood"} 19 | ] 20 | description = "Python implementation of the image segmentation algorithm described by Shepherd et al" 21 | readme = "README.md" 22 | license = {file = "LICENSE.txt"} 23 | 24 | dependencies = [ 25 | "numba", 26 | "scikit-learn" 27 | ] 28 | 29 | [project.scripts] 30 | pyshepseg_run_seg = "pyshepseg.cmdline.run_seg:main" 31 | pyshepseg_tiling = "pyshepseg.cmdline.tiling:main" 32 | pyshepseg_subset = "pyshepseg.cmdline.subset:main" 33 | pyshepseg_runtests = "pyshepseg.cmdline.runtests:main" 34 | pyshepseg_variograms = "pyshepseg.cmdline.variograms:main" 35 | pyshepseg_segmentationworkercmd = "pyshepseg.cmdline.pyshepseg_segmentationworkercmd:mainCmd" 36 | 37 | [tool.setuptools] 38 | packages = ["pyshepseg", "pyshepseg.cmdline"] 39 | 40 | [tool.setuptools.dynamic] 41 | version = {attr = "pyshepseg.__version__"} 42 | 43 | [project.urls] 44 | Repository = "https://github.com/ubarsc/pyshepseg.git" 45 | Homepage = "https://www.pyshepseg.org" 46 | -------------------------------------------------------------------------------- /pyshepseg/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Neil Flood and Sam Gillingham. All rights reserved. 2 | # 3 | # Permission is hereby granted, free of charge, to any person 4 | # obtaining a copy of this software and associated documentation 5 | # files (the "Software"), to deal in the Software without restriction, 6 | # including without limitation the rights to use, copy, modify, 7 | # merge, publish, distribute, sublicense, and/or sell copies of the 8 | # Software, and to permit persons to whom the Software is furnished 9 | # to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be 12 | # included in all copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 16 | # OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 17 | # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR 18 | # ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF 19 | # CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | 22 | """ 23 | Python implementation of image segmentation algorithm by 24 | Shepherd et al (2019). 25 | 26 | The main algorithm is implemented in the pyshepseg.shepseg module. 27 | For memory-efficient tiled segmentation of larger rasters, 28 | see the pyshepseg.tiling module. 
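A minimal sketch of the in-memory entry point (assuming img is a
(nBands, nRows, nCols) numpy array and nullVal is its null value)::

    from pyshepseg import shepseg

    segRes = shepseg.doShepherdSegmentation(img, imgNullVal=nullVal)
    segimg = segRes.segimg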
29 | 30 | """ 31 | 32 | SHEPSEG_VERSION = '2.0.3' 33 | __version__ = SHEPSEG_VERSION 34 | 35 | -------------------------------------------------------------------------------- /pyshepseg/cmdline/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sub package with modules that are intended for calling from the command line 3 | """ 4 | -------------------------------------------------------------------------------- /pyshepseg/cmdline/pyshepseg_segmentationworkercmd.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Main script for a segmentation worker running in a separate process. 4 | """ 5 | import argparse 6 | import queue 7 | 8 | import numpy 9 | from osgeo import gdal 10 | 11 | from pyshepseg import shepseg 12 | from pyshepseg.tiling import NetworkDataChannel 13 | from pyshepseg.utils import WorkerErrorRecord 14 | from pyshepseg.timinghooks import Timers 15 | 16 | 17 | # Compute workers in separate processes should always use GDAL exceptions, 18 | # regardless of whether the main script is doing so. 19 | gdal.UseExceptions() 20 | 21 | 22 | def getCmdargs(): 23 | """ 24 | Get command line arguments 25 | """ 26 | p = argparse.ArgumentParser(description=("Main script run by each " + 27 | "segmentation worker")) 28 | p.add_argument("-i", "--idnum", type=int, help="Worker ID number") 29 | p.add_argument("--channaddrfile", help="File with data channel address") 30 | p.add_argument("--channaddr", help=("Directly specified data channel " + 31 | "address, as 'hostname,portnum,authkey'. This is less secure, and " + 32 | "should only be used if the preferred option --channaddrfile " + 33 | "cannot be used")) 34 | 35 | cmdargs = p.parse_args() 36 | return cmdargs 37 | 38 | 39 | def mainCmd(): 40 | """ 41 | Main entry point for command script. This is referenced by the install 42 | configuration to generate the actual command line main script. 43 | """ 44 | cmdargs = getCmdargs() 45 | 46 | if cmdargs.channaddrfile is not None: 47 | addrStr = open(cmdargs.channaddrfile).readline().strip() 48 | else: 49 | addrStr = cmdargs.channaddr 50 | 51 | (host, port, authkey) = tuple(addrStr.split(',')) 52 | port = int(port) 53 | authkey = bytes(authkey, 'utf-8') 54 | 55 | pyshepsegRemoteSegmentationWorker(cmdargs.idnum, host, port, authkey) 56 | 57 | 58 | def pyshepsegRemoteSegmentationWorker(workerID, host, port, authkey): 59 | """ 60 | The main routine to run a segmentation worker on a remote host. 
61 | 62 | """ 63 | dataChan = NetworkDataChannel(hostname=host, portnum=port, authkey=authkey) 64 | 65 | try: 66 | infile = dataChan.segDataDict.get('infile') 67 | tileInfo = dataChan.segDataDict.get('tileInfo') 68 | minSegmentSize = dataChan.segDataDict.get('minSegmentSize') 69 | maxSpectralDiff = dataChan.segDataDict.get('maxSpectralDiff') 70 | imgNullVal = dataChan.segDataDict.get('imgNullVal') 71 | fourConnected = dataChan.segDataDict.get('fourConnected') 72 | kmeansObj = dataChan.segDataDict.get('kmeansObj') 73 | verbose = dataChan.segDataDict.get('verbose') 74 | spectDistPcntile = dataChan.segDataDict.get('spectDistPcntile') 75 | bandNumbers = dataChan.segDataDict.get('bandNumbers') 76 | 77 | barrierTimeout = dataChan.segDataDict.get('barrierTimeout') 78 | workerBarrier = dataChan.workerBarrier 79 | if hasattr(workerBarrier, 'wait'): 80 | workerBarrier.wait(timeout=barrierTimeout) 81 | 82 | # Use our own local timings object, because the proxy one does not support 83 | # the context manager protocol 84 | timings = Timers() 85 | 86 | inDs = gdal.Open(infile) 87 | 88 | colRow = popFromQue(dataChan.inQue) 89 | while colRow is not None: 90 | (col, row) = colRow 91 | 92 | xpos, ypos, xsize, ysize = tileInfo.getTile(col, row) 93 | 94 | with timings.interval('reading'): 95 | lyrDataList = [] 96 | for bandNum in bandNumbers: 97 | # Note that the proxy semaphore object does not support 98 | # context manager protocol, so we use acquire/release 99 | dataChan.readSemaphore.acquire() 100 | lyr = inDs.GetRasterBand(bandNum) 101 | lyrData = lyr.ReadAsArray(xpos, ypos, xsize, ysize) 102 | lyrDataList.append(lyrData) 103 | dataChan.readSemaphore.release() 104 | 105 | img = numpy.array(lyrDataList) 106 | 107 | with timings.interval('segmentation'): 108 | segResult = shepseg.doShepherdSegmentation(img, 109 | minSegmentSize=minSegmentSize, 110 | maxSpectralDiff=maxSpectralDiff, 111 | imgNullVal=imgNullVal, 112 | fourConnected=fourConnected, 113 | kmeansObj=kmeansObj, 114 | verbose=verbose, 115 | spectDistPcntile=spectDistPcntile) 116 | 117 | dataChan.segResultCache.addResult(col, row, segResult) 118 | colRow = popFromQue(dataChan.inQue) 119 | 120 | # Merge the local timings object with the central one. 121 | dataChan.timings.merge(timings) 122 | except Exception as e: 123 | # Send a printable version of the exception back to main thread 124 | workerErr = WorkerErrorRecord(e, 'compute') 125 | dataChan.exceptionQue.put(workerErr) 126 | 127 | 128 | def popFromQue(que): 129 | """ 130 | Pop out the next item from the given Queue, returning None if 131 | the queue is empty. 132 | 133 | WARNING: don't use this if the queued items can be None 134 | """ 135 | try: 136 | item = que.get(block=False) 137 | except queue.Empty: 138 | item = None 139 | return item 140 | 141 | 142 | if __name__ == "__main__": 143 | mainCmd() 144 | -------------------------------------------------------------------------------- /pyshepseg/cmdline/run_seg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Testing harness for pyshepseg. Handy for running a basic segmentation 5 | but it is suggested that users call the module directly from a Python 6 | script and handle things like scaling the data in an appripriate 7 | manner for their application. 8 | 9 | """ 10 | # Copyright 2021 Neil Flood and Sam Gillingham. All rights reserved. 
11 | # 12 | # Permission is hereby granted, free of charge, to any person 13 | # obtaining a copy of this software and associated documentation 14 | # files (the "Software"), to deal in the Software without restriction, 15 | # including without limitation the rights to use, copy, modify, 16 | # merge, publish, distribute, sublicense, and/or sell copies of the 17 | # Software, and to permit persons to whom the Software is furnished 18 | # to do so, subject to the following conditions: 19 | # 20 | # The above copyright notice and this permission notice shall be 21 | # included in all copies or substantial portions of the Software. 22 | # 23 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 24 | # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 25 | # OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 26 | # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR 27 | # ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF 28 | # CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 29 | # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 30 | 31 | from __future__ import print_function, division 32 | 33 | import os 34 | import sys 35 | import json 36 | import argparse 37 | import time 38 | 39 | import numpy 40 | from osgeo import gdal 41 | 42 | from pyshepseg import shepseg 43 | from pyshepseg import utils 44 | 45 | DFLT_OUTPUT_DRIVER = 'KEA' 46 | GDAL_DRIVER_CREATION_OPTIONS = {'KEA': [], 'HFA': ['COMPRESS=YES']} 47 | 48 | DFLT_MAX_SPECTRAL_DIFF = 'auto' 49 | 50 | CLUSTER_CNTRS_METADATA_NAME = 'pyshepseg_cluster_cntrs' 51 | 52 | 53 | def getCmdargs(): 54 | """ 55 | Get the command line arguments. 56 | """ 57 | p = argparse.ArgumentParser() 58 | p.add_argument("-i", "--infile", help="Input Raster file") 59 | p.add_argument("-o", "--outfile", help="Output segmented raster file") 60 | p.add_argument("-n", "--nclusters", default=60, type=int, 61 | help="Number of clusters (default=%(default)s)") 62 | p.add_argument("--subsamplepcnt", type=int, default=1, 63 | help="Percentage to subsample for fitting (default=%(default)s)") 64 | p.add_argument("--eightway", default=False, action="store_true", 65 | help="Use 8-way connectivity instead of 4-way") 66 | p.add_argument("-f", "--format", default=DFLT_OUTPUT_DRIVER, 67 | choices=[DFLT_OUTPUT_DRIVER, "HFA"], 68 | help="Name of output GDAL format that supports RATs (default=%(default)s)") 69 | p.add_argument("-m", "--maxspectraldiff", default=DFLT_MAX_SPECTRAL_DIFF, 70 | help=("Maximum Spectral Difference to use when merging " + 71 | "segments. Either 'auto', 'none' or a value to use " + 72 | "(default=%(default)s)")) 73 | p.add_argument("-s", "--minsegmentsize", default=100, type=int, 74 | help="Minimum segment size in pixels (default=%(default)s)") 75 | p.add_argument("-c", "--clustersubsamplepercent", default=0.5, type=float, 76 | help="Percent of data to subsample for clustering (default=%(default)s)") 77 | p.add_argument("-b", "--bands", default="3,4,5", 78 | help="Comma separated list of bands to use. 1-based. (default=%(default)s)") 79 | p.add_argument("--fixedkmeansinit", default=False, action="store_true", 80 | help=("Use a fixed algorithm to select guesses at cluster centres. "+ 81 | "Default will allow the KMeans routine to select these with a "+ 82 | "random element, which can make the final results slightly "+ 83 | "non-deterministic. 
Use this if you want a completely "+ 84 | "deterministic, reproducible result")) 85 | 86 | cmdargs = p.parse_args() 87 | 88 | if cmdargs.infile is None: 89 | print('Must supply input file name') 90 | p.print_help() 91 | sys.exit() 92 | 93 | if cmdargs.outfile is None: 94 | print('Must supply output file name') 95 | p.print_help() 96 | sys.exit() 97 | 98 | try: 99 | cmdargs.maxspectraldiff = float(cmdargs.maxspectraldiff) 100 | except ValueError: 101 | # check for 'auto' or 'none' 102 | if cmdargs.maxspectraldiff not in ('auto', 'none'): 103 | print("Only 'auto', 'none' or a value supported for --maxspectraldiff") 104 | p.print_help() 105 | sys.exit() 106 | 107 | # code expects 'none' -> None 108 | if cmdargs.maxspectraldiff == 'none': 109 | cmdargs.maxspectraldiff = None 110 | 111 | # turn string of bands into list of ints 112 | cmdargs.bands = [int(x) for x in cmdargs.bands.split(',')] 113 | 114 | return cmdargs 115 | 116 | 117 | def main(): 118 | cmdargs = getCmdargs() 119 | 120 | t0 = time.time() 121 | print("Reading ... ", end='') 122 | (img, refNull) = readImageBands(cmdargs) 123 | print(round(time.time() - t0, 1), "seconds") 124 | 125 | # Do the segmentation 126 | segResult = shepseg.doShepherdSegmentation(img, 127 | numClusters=cmdargs.nclusters, 128 | clusterSubsamplePcnt=cmdargs.clustersubsamplepercent, 129 | minSegmentSize=cmdargs.minsegmentsize, 130 | maxSpectralDiff=cmdargs.maxspectraldiff, 131 | imgNullVal=refNull, fourConnected=not cmdargs.eightway, 132 | fixedKMeansInit=cmdargs.fixedkmeansinit, verbose=True) 133 | 134 | # The segmentation image, and a few related quantities 135 | seg = segResult.segimg 136 | segSize = shepseg.makeSegSize(seg) 137 | maxSegId = seg.max() 138 | spectSum = shepseg.buildSegmentSpectra(seg, img, maxSegId) 139 | kmeansObj = segResult.kmeans 140 | 141 | writeOutput(cmdargs, seg, segSize, spectSum, kmeansObj) 142 | 143 | 144 | def writeOutput(cmdargs, seg, segSize, spectSum, kmeansObj): 145 | """ 146 | Write the segmentation to an output image file. Includes a 147 | colour table 148 | """ 149 | # Write output 150 | outType = gdal.GDT_UInt32 151 | 152 | (nRows, nCols) = seg.shape 153 | outDrvr = gdal.GetDriverByName(cmdargs.format) 154 | if outDrvr is None: 155 | msg = 'This GDAL does not support driver {}'.format(cmdargs.format) 156 | raise SystemExit(msg) 157 | 158 | if os.path.exists(cmdargs.outfile): 159 | outDrvr.Delete(cmdargs.outfile) 160 | 161 | creationOptions = GDAL_DRIVER_CREATION_OPTIONS[cmdargs.format] 162 | 163 | inDs = gdal.Open(cmdargs.infile) 164 | 165 | outDs = outDrvr.Create(cmdargs.outfile, nCols, nRows, 1, outType, 166 | options=creationOptions) 167 | outDs.SetProjection(inDs.GetProjection()) 168 | outDs.SetGeoTransform(inDs.GetGeoTransform()) 169 | b = outDs.GetRasterBand(1) 170 | b.WriteArray(seg) 171 | b.SetMetadataItem('LAYER_TYPE', 'thematic') 172 | b.SetNoDataValue(shepseg.SEGNULLVAL) 173 | 174 | # since we have the histogram we can do the stats 175 | utils.estimateStatsFromHisto(b, segSize) 176 | 177 | # overviews 178 | utils.addOverviews(outDs) 179 | 180 | # save the cluster centres 181 | writeClusterCentresToMetadata(b, kmeansObj) 182 | 183 | del outDs 184 | 185 | 186 | def readImageBands(cmdargs): 187 | """ 188 | Read in the requested bands of the given image. Return 189 | a tuple of the image array and the null value. 
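    The image array is stacked band-first, i.e. it has shape
    (nBands, nRows, nCols).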
190 | """ 191 | ds = gdal.Open(cmdargs.infile) 192 | bandList = [] 193 | for bn in cmdargs.bands: 194 | b = ds.GetRasterBand(bn) 195 | refNull = b.GetNoDataValue() 196 | a = b.ReadAsArray() 197 | bandList.append(a) 198 | img = numpy.array(bandList) 199 | 200 | return (img, refNull) 201 | 202 | 203 | def writeClusterCentresToMetadata(bandObj, km): 204 | """ 205 | Pulls out the cluster centres from the kmeans object 206 | and writes them to the metadata (under CLUSTER_CNTRS_METADATA_NAME) 207 | for the given band object. 208 | """ 209 | # convert to list so we can json them 210 | ctrsList = [list(r) for r in km.cluster_centers_] 211 | ctrsString = json.dumps(ctrsList) 212 | 213 | bandObj.SetMetadataItem(CLUSTER_CNTRS_METADATA_NAME, ctrsString) 214 | 215 | 216 | if __name__ == "__main__": 217 | main() 218 | -------------------------------------------------------------------------------- /pyshepseg/cmdline/runtests.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Run tests of the pyshepseg package. Use a generated dataset, sufficient 4 | to allow a meaningful segmentation. Note that the generated dataset is not 5 | sufficiently complex to be a strong test of the Shepherd segmentation 6 | algorithm, but merely whether this implementation of the algorithm is 7 | coded to behave sensibly. 8 | 9 | """ 10 | import sys 11 | import os 12 | import argparse 13 | 14 | import numpy 15 | 16 | from osgeo import gdal 17 | 18 | from pyshepseg import shepseg, tiling, tilingstats, utils, subset 19 | 20 | # This is a list of (x, y) coordinate pairs, representing centres of 21 | # some test segments. These were initially generated randomly, but 22 | # are saved here so we can use them reproduceably for our testing. 23 | # Note that (x, y) will also be used as (column, row) in the 24 | # image array, because there is no point in creating a separate 25 | # world coordinate system (it would make no difference to the things 26 | # being tested). 27 | initialCentres = [(116, 3495), (142, 3100), (236, 6033), (290, 796), (297, 6152), (310, 5318), (409, 5867), (410, 2125), (442, 2913), 28 | (472, 1135), (486, 5296), (628, 667), (655, 2677), (672, 4001), (677, 5513), (736, 3720), (913, 3552), (1056, 347), (1085, 3391), 29 | (1121, 6623), (1150, 1906), (1196, 5663), (1694, 3244), (1761, 2172), (1761, 7460), (1882, 6151), (1893, 626), (2014, 433), (2065, 3157), 30 | (2132, 378), (2161, 2352), (2200, 7485), (2393, 5191), (2489, 2519), (2508, 1575), (2509, 7089), (2599, 3151), (2645, 2672), (2782, 3380), 31 | (2906, 3676), (3072, 2934), (3133, 3418), (3188, 1653), (3624, 7812), (3661, 3603), (3694, 2929), (3759, 3418), (4155, 630), (4233, 4753), 32 | (4423, 1377), (4427, 6635), (4462, 7392), (4715, 6908), (4856, 2559), (4898, 3371), (5051, 2268), (5064, 5969), (5071, 2019), (5107, 3533), 33 | (5172, 5478), (5294, 4210), (5305, 1512), (5310, 2846), (5365, 3715), (5447, 6215), (5513, 5017), (5549, 297), (5579, 4076), (5623, 5044), 34 | (5688, 3614), (5728, 1802), (5747, 7801), (5758, 4377), (5779, 4148), (5784, 3239), (5812, 5091), (5862, 4664), (5897, 4963), (6299, 4702), 35 | (6320, 6936), (6462, 2844), (6615, 4979), (6726, 5970), (6754, 7652), (6765, 714), (6826, 3162), (6827, 3770), (6844, 1170), (6884, 226), 36 | (7023, 213), (7094, 6472), (7157, 647), (7196, 7710), (7293, 7588), (7495, 5912), (7693, 3966), (7718, 7759), (7737, 6002), (7745, 1347), 37 | (7889, 2850)] 38 | 39 | # Shape of the image we will be working with. 
40 | (nRows, nCols) = (8000, 8000) 41 | # Number of bands in constructed multispectral image 42 | NBANDS = 3 43 | 44 | 45 | def getCmdargs(): 46 | """ 47 | Get command line arguments 48 | """ 49 | p = argparse.ArgumentParser(description=""" 50 | Run tests of the software, using internally generated data. 51 | This is mainly intended for checking whether code changes 52 | have broken anything. It is not a rigorous test of the 53 | Shepherd algorithm. 54 | """) 55 | p.add_argument("--keep", default=False, action="store_true", 56 | help="Keep test data files (default will delete)") 57 | p.add_argument("--knownseg", help=("Use this file as the initial "+ 58 | "known segmentation, instead of generating one. Helpful "+ 59 | "with --keep, to save time in repeated tests")) 60 | return p.parse_args() 61 | 62 | 63 | def main(): 64 | """ 65 | Main routine 66 | """ 67 | cmdargs = getCmdargs() 68 | 69 | errorStatus = 0 70 | 71 | truesegfile = "tmp_trueseg.kea" 72 | if cmdargs.knownseg is not None: 73 | truesegfile = cmdargs.knownseg 74 | imagefile = "tmp_image.kea" 75 | outsegfile = "tmp_seg.kea" 76 | subset_segfile = "tmp_seg_subset.kea" 77 | tmpdatafiles = [imagefile, outsegfile, subset_segfile] 78 | if cmdargs.knownseg is None: 79 | tmpdatafiles.append(truesegfile) 80 | 81 | if cmdargs.knownseg is None: 82 | print("Generating known segmentation {}".format(truesegfile)) 83 | generateTrueSegments(truesegfile) 84 | print("Creating known multi-spectral image {}".format(imagefile)) 85 | createMultispectral(truesegfile, imagefile) 86 | 87 | # Ensure enough clusters that we have a different cluster for each 88 | # segment in the generated image. We have not guarded against neighbours 89 | # being similar, so this will prevent neighbouring segments being 90 | # merged too easily. 91 | numClusters = len(initialCentres) 92 | 93 | print("Doing tiled segmentation, creating {}".format(outsegfile)) 94 | # Note that we use fourConnected=False, to avoid disconnecting the 95 | # pointy ends of long thin slivers, which can arise due to how we 96 | # generated the original segments. 
97 | segResults = tiling.doTiledShepherdSegmentation(imagefile, outsegfile, 98 | numClusters=numClusters, fixedKMeansInit=True, fourConnected=False) 99 | 100 | # some columns that test the stats 101 | (meanColNames, stdColNames) = makeRATcolumns(segResults, outsegfile, imagefile) 102 | 103 | # some columns that test the spatial stats 104 | (eastingCol, northingCol) = makeSpatialRATColumns(outsegfile, imagefile) 105 | 106 | # check the segmentation via the non-spatial stats 107 | pcntMatch = checkSegmentation(imagefile, outsegfile, meanColNames, 108 | stdColNames) 109 | 110 | print("Segment match on {}% of pixels".format(pcntMatch)) 111 | if pcntMatch < 100: 112 | print("Matching on less than 100% suggests that the segmentation went wrong") 113 | errorStatus = 1 114 | 115 | # check the spatial cols 116 | if not checkSpatialColumns(outsegfile, eastingCol, northingCol): 117 | print('Mean coordinates of segments differ') 118 | errorStatus = 1 119 | 120 | print("Checking subset functionality") 121 | if not checkSubset(outsegfile, subset_segfile): 122 | print('Unable to match new values from subset') 123 | errorStatus = 1 124 | 125 | print("Adding colour table to {}".format(outsegfile)) 126 | utils.writeColorTableFromRatColumns(outsegfile, meanColNames[0], 127 | meanColNames[1], meanColNames[2]) 128 | 129 | if not cmdargs.keep: 130 | print("Removing generated data") 131 | for fn in tmpdatafiles: 132 | drvr = gdal.IdentifyDriver(fn) 133 | drvr.Delete(fn) 134 | 135 | # Exit with an explicit status code, so that Github workflow 136 | # can recognise if something went wrong. 137 | sys.exit(errorStatus) 138 | 139 | 140 | # The basis of the test data will be a set of "true" segments. From 141 | # these we will generate multi-spectral data which represents these 142 | # segments, and the tests will then use the pyshepseg package to 143 | # re-create the original segments from the multi-spectral data. 144 | 145 | def generateTrueSegments(outfile): 146 | """ 147 | This routine generates the true segments from the initial segment 148 | centres hardwired in the initialCentres variable. 149 | 150 | Each pixel is in the segment for its closest centre coordinate. 151 | 152 | Saves the generated segment layer into the given raster filename, 153 | with format KEA. 
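    (Equivalently, the segments form a Voronoi tessellation of the image
    grid around the hardwired centres.)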
154 | 155 | """ 156 | segArray = numpy.zeros((nRows, nCols), dtype=shepseg.SegIdType) 157 | segArray.fill(shepseg.SEGNULLVAL) 158 | 159 | minDist = numpy.zeros((nRows, nCols), dtype=numpy.float32) 160 | # Initial distance much bigger than whole array, so actual centres 161 | # will all be closer 162 | minDist.fill(10 * nCols) 163 | 164 | # For each pixel, its (x, y) position, to use in calculating distances 165 | (xGrid, yGrid) = numpy.mgrid[:nRows, :nCols] 166 | 167 | numCentres = len(initialCentres) 168 | for i in range(numCentres): 169 | (x, y) = initialCentres[i] 170 | dist = numpy.sqrt((xGrid - x)**2 + (yGrid - y)**2) 171 | minNdx = (dist < minDist) 172 | 173 | segId = i + 1 174 | segArray[minNdx] = segId 175 | # For each pixel, update the distance to the closest centre 176 | minDist[minNdx] = dist[minNdx] 177 | 178 | # Put in a margin of nulls all round, so that we can also test 179 | # that null handling is working properly 180 | m = 10 181 | segArray[:m, :] = shepseg.SEGNULLVAL 182 | segArray[-m:, :] = shepseg.SEGNULLVAL 183 | segArray[:, :m] = shepseg.SEGNULLVAL 184 | segArray[:, -m:] = shepseg.SEGNULLVAL 185 | 186 | # Save to a KEA file 187 | drvr = gdal.GetDriverByName('KEA') 188 | if os.path.exists(outfile): 189 | drvr.Delete(outfile) 190 | ds = drvr.Create(outfile, nCols, nRows, bands=1, eType=gdal.GDT_UInt32) 191 | ds.SetGeoTransform((0, 1, 0, 0, 0, -1)) 192 | band = ds.GetRasterBand(1) 193 | band.SetNoDataValue(shepseg.SEGNULLVAL) 194 | band.WriteArray(segArray) 195 | del ds 196 | 197 | 198 | def createPallete(numSeg): 199 | """ 200 | Return a "pallete" of 3-band colours, one for each segment. 201 | 202 | The colours are just made up, and have no particular meaning. The 203 | main criterion is that they be distinct, sufficiently so that 204 | two adjacent segments should always come out in different colours. 205 | 206 | Return value is an array of shape (numSeg, 3). Individual colour 207 | values are in the range [0, 10000], and so the array has type uint16. 208 | 209 | Note that the index into this array would be (segmentID - 1). 210 | """ 211 | MINVAL = 0 212 | MAXVAL = 10000 213 | step = (MAXVAL - MINVAL) / (numSeg - 1) 214 | mid = numSeg / 2 215 | 216 | c = numpy.zeros((numSeg, NBANDS), dtype=numpy.uint16) 217 | 218 | for i in range(numSeg): 219 | c[i, 0] = round(MINVAL + i * step) 220 | c[i, 1] = round(MAXVAL - i * step) 221 | if i < mid: 222 | c[i, 2] = round(MINVAL + i * 2 * step) 223 | else: 224 | c[i, 2] = round(MAXVAL - (i - mid) * 2 * step) 225 | 226 | return c 227 | 228 | 229 | def createMultispectral(truesegfile, outfile): 230 | """ 231 | Reads the given true segment file, and generates a multi-spectral 232 | image which should segment in a similar way. 233 | """ 234 | trueseg = readSeg(truesegfile) 235 | numSeg = trueseg.max() 236 | outNull = 2**16 - 1 237 | 238 | pallete = createPallete(numSeg) 239 | 240 | (nRows, nCols) = trueseg.shape 241 | outBand = numpy.zeros(trueseg.shape, dtype=numpy.uint16) 242 | nullNdx = (trueseg == shepseg.SEGNULLVAL) 243 | 244 | segSize = shepseg.makeSegSize(trueseg) 245 | segLoc = shepseg.makeSegmentLocations(trueseg, segSize) 246 | 247 | # Open output file 248 | drvr = gdal.GetDriverByName('KEA') 249 | if os.path.exists(outfile): 250 | drvr.Delete(outfile) 251 | ds = drvr.Create(outfile, nCols, nRows, bands=NBANDS, eType=gdal.GDT_UInt16) 252 | 253 | # Generate each output band, writing as we go. 
254 | for i in range(NBANDS): 255 | for segId in segLoc: 256 | segNdx = segLoc[shepseg.SegIdType(segId)].getSegmentIndices() 257 | outBand[segNdx] = pallete[segId - 1][i] 258 | outBand[nullNdx] = outNull 259 | 260 | b = ds.GetRasterBand(i + 1) 261 | b.SetNoDataValue(outNull) 262 | b.WriteArray(outBand) 263 | b.FlushCache() 264 | ds.FlushCache() 265 | del ds 266 | 267 | 268 | def readSeg(segfile, xoff=0, yoff=0, win_xsize=None, win_ysize=None): 269 | """ 270 | Open and read the given segfile. Return an image array of the 271 | segment ID values 272 | """ 273 | ds = gdal.Open(segfile) 274 | band = ds.GetRasterBand(1) 275 | seg = band.ReadAsArray(xoff, yoff, win_xsize, win_ysize).astype(shepseg.SegIdType) 276 | return seg 277 | 278 | 279 | def makeRATcolumns(segResults, outsegfile, imagefile): 280 | """ 281 | Add some columns to the RAT, with useful per-segment statistics 282 | """ 283 | # Calculate per-segment mean and stddev for all bands, and store in the RAT 284 | meanColNames = [] 285 | stdColNames = [] 286 | for i in range(NBANDS): 287 | meanCol = "Band_{}_mean".format(i + 1) 288 | stdCol = "Band_{}_stddev".format(i + 1) 289 | meanColNames.append(meanCol) 290 | stdColNames.append(stdCol) 291 | statsSelection = [(meanCol, "mean"), (stdCol, "stddev")] 292 | tilingstats.calcPerSegmentStatsTiled(imagefile, (i + 1), outsegfile, 293 | statsSelection) 294 | 295 | return (meanColNames, stdColNames) 296 | 297 | 298 | def makeSpatialRATColumns(segfile, imagefile): 299 | """ 300 | Create some RAT columns for checking the spatial stats 301 | functionality. 302 | 303 | Here we use the tilingstats.userFuncMeanCoord function passed to 304 | calcPerSegmentSpatialStatsTiled so that the mean eastings and northings 305 | are calculated. These can be easily checked later on. 306 | 307 | """ 308 | # a fake transform array. Created so that the eastings and northings 309 | # are equivalent to cols/rows for simplicity. 310 | transform = numpy.array([0, 1, 0, 0, 0, 1]) 311 | # just do the first band as they should all be the same... 312 | eastingCol = "Band_1_easting" 313 | northingCol = "Band_1_northing" 314 | colNamesAndTypes = [(eastingCol, gdal.GFT_Real), 315 | (northingCol, gdal.GFT_Real)] 316 | # call calcPerSegmentSpatialStatsTiled to do the stats 317 | tilingstats.calcPerSegmentSpatialStatsTiled(imagefile, 1, segfile, 318 | colNamesAndTypes, tilingstats.userFuncMeanCoord, transform) 319 | 320 | # return the names of the columns 321 | return (eastingCol, northingCol) 322 | 323 | 324 | def checkSegmentation(imgfile, segfile, meanColNames, stdColNames): 325 | """ 326 | Check whether the given segmentation of the given image file 327 | is "correct", by some measure(s). 328 | 329 | """ 330 | seg = readSeg(segfile) 331 | nonNull = (seg != shepseg.SEGNULLVAL) 332 | 333 | # The tolerance to use for testing equality. The spectral differences 334 | # between segments are much larger than 1, but we are comparing 335 | # single pixel spectra with the segment means. If a single pixel 336 | # is incorrectly placed, then the segment mean will be only slightly 337 | # affected, but the pixel spectra will disagree to a much greater 338 | # amount. So, a tolerance of almost 1 will still detect the single 339 | # pixels which are incorrectly placed. 
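    # (Rough numbers, for the data generated above: createPallete spaces
    # adjacent segment colours of order 100 apart, while a single misplaced
    # pixel in a segment of hundreds of thousands of pixels shifts that
    # segment's mean by far less than this tolerance.)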
340 | TOL = 0.5 341 | 342 | colourMatch = None 343 | ds = gdal.Open(imgfile) 344 | for i in range(NBANDS): 345 | bandobj = ds.GetRasterBand(i + 1) 346 | img = bandobj.ReadAsArray() 347 | 348 | segmeans = readColumn(segfile, meanColNames[i]) 349 | 350 | # An img of the segmean for this band, for each pixel. 351 | segColour = segmeans[seg] 352 | 353 | diff = numpy.absolute(img - segColour) 354 | diff[~nonNull] = 0 # Null pixels are forced to count as matching here; they are checked separately below 355 | # Per-pixel, True when segment mean matches image colour for this band 356 | colourMatch_band = (diff < TOL) 357 | 358 | # Accumulate matches across bands. Ultimately, it is 359 | # a match if all bands match. 360 | if colourMatch is None: 361 | colourMatch = colourMatch_band 362 | else: 363 | colourMatch = (colourMatch & colourMatch_band) 364 | 365 | numColourMatch = numpy.count_nonzero(colourMatch) 366 | 367 | # Rough check that nulls are in the right places 368 | imgnullval = bandobj.GetNoDataValue() 369 | nullMatch = (img[~nonNull] == imgnullval) 370 | numNullMatch = numpy.count_nonzero(nullMatch) 371 | 372 | # Percentage of pixels which match, either full colour match, or null 373 | numPix = len(colourMatch.flatten()) + len(nullMatch) 374 | pcntMatch = 100 * (numColourMatch + numNullMatch) / numPix 375 | 376 | return pcntMatch 377 | 378 | 379 | def checkSpatialColumns(segfile, eastingCol, northingCol): 380 | """ 381 | Do a quick check of eastingCol and northingCol which were 382 | calculated using calcPerSegmentSpatialStatsTiled(). 383 | 384 | Returns True if the calculated coordinates are the same 385 | as the actual coordinates (within a tolerance). 386 | 387 | """ 388 | # read in the data to check 389 | eastingData = readColumn(segfile, eastingCol) 390 | northingData = readColumn(segfile, northingCol) 391 | 392 | # read in the segfile 393 | seg = readSeg(segfile) 394 | # work out the sizes of the segments (needed for makeSegmentLocations) 395 | segSize = shepseg.makeSegSize(seg) 396 | TOL = 0.0003 397 | 398 | ok = True 399 | # get the locations of all the segments 400 | segLoc = shepseg.makeSegmentLocations(seg, segSize) 401 | for segId in segLoc: 402 | # for each segment, get the mean coordinates 403 | norths, easts = segLoc[shepseg.SegIdType(segId)].getSegmentIndices() 404 | # find the difference from calculated. 405 | xdiff = abs(easts.mean() - eastingData[segId]) 406 | ydiff = abs(norths.mean() - northingData[segId]) 407 | if xdiff > TOL or ydiff > TOL: 408 | ok = False 409 | break 410 | 411 | return ok 412 | 413 | 414 | def checkSubset(outsegfile, subset_segfile): 415 | """ 416 | Check the behaviour of subset.subsetImage() 417 | by doing a subset and checking that the new values 418 | can successfully be translated to the old using the 419 | origSegIdColName parameter. 420 | """ 421 | subset.subsetImage(outsegfile, subset_segfile, 4000, 4000, 1000, 1000, 'KEA', 422 | origSegIdColName='orig_val') 423 | lookupcol = readColumn(subset_segfile, 'orig_val') 424 | oldvals = readSeg(outsegfile, 4000, 4000, 1000, 1000) 425 | newvals = readSeg(subset_segfile) 426 | if newvals.min() != 1: 427 | # should have restarted the count 428 | return False 429 | 430 | new2oldvals = lookupcol[newvals] 431 | return (new2oldvals == oldvals).all() 432 | 433 | 434 | def readColumn(segfile, colName): 435 | """ 436 | Read the given column from the given segmentation image file. 437 | Return an array of the column values. 
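    The returned array can be indexed directly by segment ID, since RAT
    rows correspond to segment values (as done in checkSubset above).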
438 | """ 439 | ds = gdal.Open(segfile) 440 | band = ds.GetRasterBand(1) 441 | attrTbl = band.GetDefaultRAT() 442 | numCols = attrTbl.GetColumnCount() 443 | colNameList = [attrTbl.GetNameOfCol(i) for i in range(numCols)] 444 | colNdx = colNameList.index(colName) 445 | col = attrTbl.ReadAsArray(colNdx) 446 | 447 | return col 448 | 449 | 450 | if __name__ == "__main__": 451 | main() 452 | -------------------------------------------------------------------------------- /pyshepseg/cmdline/subset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Test harness for subsetting a segmented image 5 | """ 6 | 7 | # Copyright 2021 Neil Flood and Sam Gillingham. All rights reserved. 8 | # 9 | # Permission is hereby granted, free of charge, to any person 10 | # obtaining a copy of this software and associated documentation 11 | # files (the "Software"), to deal in the Software without restriction, 12 | # including without limitation the rights to use, copy, modify, 13 | # merge, publish, distribute, sublicense, and/or sell copies of the 14 | # Software, and to permit persons to whom the Software is furnished 15 | # to do so, subject to the following conditions: 16 | # 17 | # The above copyright notice and this permission notice shall be 18 | # included in all copies or substantial portions of the Software. 19 | # 20 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 21 | # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 22 | # OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 23 | # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR 24 | # ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF 25 | # CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 26 | # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 27 | 28 | from __future__ import print_function, division 29 | 30 | import math 31 | import argparse 32 | from osgeo import gdal 33 | from pyshepseg import subset, tilingstats 34 | gdal.UseExceptions() 35 | 36 | DFLT_OUTPUT_DRIVER = 'KEA' 37 | GDAL_DRIVER_CREATION_OPTIONS = {'KEA': [], 'HFA': ['COMPRESS=YES']} 38 | 39 | 40 | def getCmdargs(): 41 | """ 42 | Get the command line arguments. 43 | """ 44 | p = argparse.ArgumentParser() 45 | p.add_argument("-i", "--infile", required=True, 46 | help="Input file") 47 | p.add_argument("-o", "--outfile", required=True, 48 | help="Output file") 49 | group = p.add_mutually_exclusive_group(required=True) 50 | group.add_argument("--srcwin", type=int, nargs=4, 51 | metavar=('xoff', 'yoff', 'xsize', 'ysize'), 52 | help="Top left pixel coordinates and subset size (in pixels) to extract") 53 | group.add_argument("--projwin", type=float, nargs=4, 54 | metavar=('ulx', 'uly', 'lrx', 'lry'), 55 | help="Projected coordinates to extract subset from the input file") 56 | group.add_argument("--mask", help="Use extent of specified raster as subset " + 57 | "area. 
Also only use pixels that are != 0 in this image") 58 | p.add_argument("--origsegidcol", help="Name of column to write the original" + 59 | " segment ids") 60 | p.add_argument("-f", "--format", default=DFLT_OUTPUT_DRIVER, 61 | choices=[DFLT_OUTPUT_DRIVER, "HFA"], 62 | help="Name of output GDAL format that supports RATs (default=%(default)s)") 63 | cmdargs = p.parse_args() 64 | return cmdargs 65 | 66 | 67 | def getPixelCoords(fname, coords): 68 | """ 69 | Open the supplied file and work out what coords (ulx, uly, lrx, lry) 70 | are in pixel coordinates and return (tlx, tly, xsize, ysize) 71 | """ 72 | ulx, uly, lrx, lry = coords 73 | ds = gdal.Open(fname) 74 | transform = ds.GetGeoTransform() 75 | invTransform = gdal.InvGeoTransform(transform) 76 | 77 | pix_tlx, pix_tly = gdal.ApplyGeoTransform(invTransform, ulx, uly) 78 | pix_brx, pix_bry = gdal.ApplyGeoTransform(invTransform, lrx, lry) 79 | pix_tlx = int(pix_tlx) 80 | pix_tly = int(pix_tly) 81 | pix_brx = int(math.ceil(pix_brx)) 82 | pix_bry = int(math.ceil(pix_bry)) 83 | 84 | if (pix_tlx < 0 or pix_tly < 0 or pix_brx >= ds.RasterXSize or 85 | pix_bry >= ds.RasterYSize): 86 | msg = 'Specified coordinates not completely within image' 87 | raise ValueError(msg) 88 | 89 | xsize = pix_brx - pix_tlx 90 | ysize = pix_bry - pix_tly 91 | return pix_tlx, pix_tly, xsize, ysize 92 | 93 | 94 | def getExtentOfMaskForInfile(infile, maskfile): 95 | """ 96 | Get the extent of maskfile in the pixel coordinates of infile. 97 | Returns (tlx, tly, xsize, ysize) 98 | """ 99 | inds = gdal.Open(infile) 100 | in_transform = inds.GetGeoTransform() 101 | 102 | maskds = gdal.Open(maskfile) 103 | mask_transform = maskds.GetGeoTransform() 104 | 105 | if not tilingstats.equalProjection(inds.GetProjection(), 106 | maskds.GetProjection()): 107 | msg = "Mask and infile don't have same projection" 108 | raise ValueError(msg) 109 | if (in_transform[1] != mask_transform[1] or 110 | in_transform[5] != mask_transform[5]): 111 | msg = "Mask and infile don't have same pixel size" 112 | raise ValueError(msg) 113 | 114 | if ((in_transform[0] - mask_transform[0]) % in_transform[1]) != 0: 115 | msg = "Mask and infile not on same grid" 116 | raise ValueError(msg) 117 | 118 | if ((in_transform[3] - mask_transform[3]) % in_transform[5]) != 0: 119 | msg = "Mask and infile not on same grid" 120 | raise ValueError(msg) 121 | 122 | mask_tlx, mask_tly = gdal.ApplyGeoTransform(mask_transform, 0, 0) 123 | mask_brx, mask_bry = gdal.ApplyGeoTransform(mask_transform, 124 | maskds.RasterXSize, maskds.RasterYSize) 125 | 126 | inv_transform = gdal.InvGeoTransform(in_transform) 127 | tlx, tly = gdal.ApplyGeoTransform(inv_transform, mask_tlx, mask_tly) 128 | brx, bry = gdal.ApplyGeoTransform(inv_transform, mask_brx, mask_bry) 129 | tlx = int(tlx) 130 | tly = int(tly) 131 | brx = int(brx) 132 | bry = int(bry) 133 | xsize = brx - tlx 134 | ysize = bry - tly 135 | # note - the check that coords are within infile is made in subsetImage() 136 | return tlx, tly, xsize, ysize 137 | 138 | 139 | def main(): 140 | cmdargs = getCmdargs() 141 | 142 | tlx = None 143 | tly = None 144 | xsize = None 145 | ysize = None 146 | if cmdargs.srcwin is not None: 147 | tlx, tly, xsize, ysize = cmdargs.srcwin 148 | elif cmdargs.projwin is not None: 149 | tlx, tly, xsize, ysize = getPixelCoords(cmdargs.infile, 150 | cmdargs.projwin) 151 | else: 152 | tlx, tly, xsize, ysize = getExtentOfMaskForInfile(cmdargs.infile, 153 | cmdargs.mask) 154 | 155 | creationOptions = [] 156 | if cmdargs.format in GDAL_DRIVER_CREATION_OPTIONS: 157 | creationOptions = 
GDAL_DRIVER_CREATION_OPTIONS[cmdargs.format] 158 | 159 | subset.subsetImage(cmdargs.infile, cmdargs.outfile, tlx, tly, 160 | xsize, ysize, cmdargs.format, creationOptions=creationOptions, 161 | origSegIdColName=cmdargs.origsegidcol, maskImage=cmdargs.mask) 162 | 163 | 164 | if __name__ == "__main__": 165 | main() 166 | -------------------------------------------------------------------------------- /pyshepseg/cmdline/tiling.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Testing harness for running pyshepseg in the tiling mode. 5 | Handy for running a basic segmentation, 6 | but it is suggested that users call the module directly from a Python 7 | script and handle things like scaling the data in an appropriate 8 | manner for their application. 9 | 10 | """ 11 | # Copyright 2021 Neil Flood and Sam Gillingham. All rights reserved. 12 | # 13 | # Permission is hereby granted, free of charge, to any person 14 | # obtaining a copy of this software and associated documentation 15 | # files (the "Software"), to deal in the Software without restriction, 16 | # including without limitation the rights to use, copy, modify, 17 | # merge, publish, distribute, sublicense, and/or sell copies of the 18 | # Software, and to permit persons to whom the Software is furnished 19 | # to do so, subject to the following conditions: 20 | # 21 | # The above copyright notice and this permission notice shall be 22 | # included in all copies or substantial portions of the Software. 23 | # 24 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 25 | # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 26 | # OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 27 | # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR 28 | # ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF 29 | # CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 30 | # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 31 | 32 | from __future__ import print_function, division 33 | 34 | import sys 35 | import time 36 | import argparse 37 | import json 38 | 39 | from osgeo import gdal 40 | 41 | from pyshepseg import tiling 42 | from pyshepseg import tilingstats 43 | from pyshepseg import utils 44 | 45 | DFLT_OUTPUT_DRIVER = 'KEA' 46 | GDAL_DRIVER_CREATION_OPTIONS = {'KEA': [], 'HFA': ['COMPRESS=YES']} 47 | 48 | DFLT_MAX_SPECTRAL_DIFF = 'auto' 49 | 50 | CLUSTER_CNTRS_METADATA_NAME = 'pyshepseg_cluster_cntrs' 51 | 52 | 53 | def getCmdargs(): 54 | """ 55 | Get the command line arguments. 56 | """ 57 | p = argparse.ArgumentParser() 58 | p.add_argument("-i", "--infile", help="Input Raster file") 59 | p.add_argument("-o", "--outfile", help="Output segmented raster file") 60 | p.add_argument("--verbose", default=False, action="store_true", 61 | help="Turn on verbose output.") 62 | p.add_argument("--nullvalue", default=None, type=int, 63 | help="Null value for input image. 
If not given, the value set in the "+ 64 | "image is used.") 65 | p.add_argument("-f", "--format", default=DFLT_OUTPUT_DRIVER, 66 | choices=[DFLT_OUTPUT_DRIVER, "HFA"], 67 | help="Name of output GDAL format that supports RATs (default=%(default)s)") 68 | 69 | segGroup = p.add_argument_group("Segmentation Parameters") 70 | tileGroup = p.add_argument_group("Tiling Parameters") 71 | statsGroup = p.add_argument_group("Per-segment Statistics") 72 | concGroup = p.add_argument_group("Concurrency") 73 | 74 | segGroup.add_argument("-n", "--nclusters", default=60, type=int, 75 | help="Number of clusters (default=%(default)s)") 76 | segGroup.add_argument("--eightway", default=False, action="store_true", 77 | help="Use 8-way instead of 4-way") 78 | segGroup.add_argument("-m", "--maxspectraldiff", default=DFLT_MAX_SPECTRAL_DIFF, 79 | help=("Maximum Spectral Difference to use when merging " + 80 | "segments. Either 'auto', 'none' or a value to use " + 81 | "(default=%(default)s)")) 82 | segGroup.add_argument("-s", "--minsegmentsize", default=100, type=int, 83 | help="Minimum segment size in pixels (default=%(default)s)") 84 | segGroup.add_argument("-b", "--bands", default="3,4,5", 85 | help="Comma-separated list of bands to use. 1-based. (default=%(default)s)") 86 | segGroup.add_argument("--fixedkmeansinit", default=False, action="store_true", 87 | help=("Use a fixed algorithm to select guesses at cluster centres. "+ 88 | "Default will allow the KMeans routine to select these with a "+ 89 | "random element, which can make the final results slightly "+ 90 | "non-deterministic. Use this if you want a completely "+ 91 | "deterministic, reproducible result")) 92 | 93 | tileGroup.add_argument("-t", "--tilesize", default=tiling.DFLT_TILESIZE, 94 | help="Size (in pixels) of tiles to chop input image into for processing."+ 95 | " (default=%(default)s)", type=int) 96 | tileGroup.add_argument("-l", "--overlapsize", default=tiling.DFLT_OVERLAPSIZE, 97 | help="Size (in pixels) of the overlap between tiles."+ 98 | " (default=%(default)s)", type=int) 99 | tileGroup.add_argument("-c", "--clustersubsamplepercent", default=None, type=float, 100 | help=("Percent of data to subsample for clustering (i.e. across all "+ 101 | "tiles). If not given, 1 million pixels are used.")) 102 | tileGroup.add_argument("--simplerecode", default=False, action="store_true", 103 | help=("Use a simple recode method when merging tiles, rather "+ 104 | "than merge segments across the tile boundary. This is mainly "+ 105 | "for use when testing the default merge/recode. ")) 106 | 107 | statsGroup.add_argument("--statsbands", help=("Comma-separated list of "+ 108 | "bands in the input raster file for which to calculate per-segment "+ 109 | "statistics, as columns in a raster attribute table (RAT). "+ 110 | "Default will not calculate any per-segment statistics. ")) 111 | statsGroup.add_argument("--statspec", default=[], action="append", 112 | help=("Specification of a statistic to be included in the "+ 113 | "per-segment statistics in the raster attribute table (RAT). "+ 114 | "This can be given more than once, and the nominated statistic "+ 115 | "will be calculated for all bands given in --statsbands. "+ 116 | "Options are 'mean', 'stddev', 'min', 'max', 'median', 'mode' or "+ 117 | "'percentile,p' (where 'p' is a percentile (0-100) to calculate). ")) 118 | statsGroup.add_argument("--colortablebands", help=("Comma-separated list "+ 119 | "of 3 band numbers to use for colouring of segments. 
Assumes that "+ 120 | "the per-segment mean has been calculated for each band, and this "+ 121 | "is used to derive a colour. Band numbers are used in the order "+ 122 | "red,green,blue")) 123 | 124 | concGroup.add_argument("--concurrencytype", default=tiling.CONC_NONE, 125 | choices=[tiling.CONC_NONE, tiling.CONC_THREADS, tiling.CONC_FARGATE, 126 | tiling.CONC_SUBPROC], 127 | help="Type of concurrency to use in tiled segmentation (default=%(default)s)") 128 | concGroup.add_argument("--numworkers", default=0, type=int, 129 | help="Number of workers for concurrent segmentation (default=%(default)s)") 130 | concGroup.add_argument("--fargatecfg", help=("JSON file of keyword " + 131 | "arguments dictionary for FargateConfig constructor " + 132 | "(for use with CONC_FARGATE)")) 133 | concGroup.add_argument("--tilecompletiontimeout", type=int, default=60, 134 | help=("Timeout (seconds) to wait for completion of each tile " + 135 | "(default=%(default)s)")) 136 | 137 | cmdargs = p.parse_args() 138 | 139 | if cmdargs.infile is None: 140 | print('Must supply input file name') 141 | p.print_help() 142 | sys.exit() 143 | 144 | if cmdargs.outfile is None: 145 | print('Must supply output file name') 146 | p.print_help() 147 | sys.exit() 148 | 149 | try: 150 | cmdargs.maxspectraldiff = float(cmdargs.maxspectraldiff) 151 | except ValueError: 152 | # check for 'auto' or 'none' 153 | if cmdargs.maxspectraldiff not in ('auto', 'none'): 154 | print("Only 'auto', 'none' or a value supported for --maxspectraldiff") 155 | p.print_help() 156 | sys.exit() 157 | 158 | # code expects 'none' -> None 159 | if cmdargs.maxspectraldiff == 'none': 160 | cmdargs.maxspectraldiff = None 161 | 162 | # turn string of bands into list of ints 163 | cmdargs.bands = [int(x) for x in cmdargs.bands.split(',')] 164 | if cmdargs.statsbands is not None: 165 | cmdargs.statsbands = [int(x) for x in cmdargs.statsbands.split(',')] 166 | else: 167 | cmdargs.statsbands = [] 168 | # Check that requested color table bands can be used for this. 
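# (These checks, below, require that the per-segment mean will be
# calculated ('mean' given via --statspec), and that every band in
# --colortablebands also appears in --statsbands, because the colour
# table is later derived from the per-segment mean columns.)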
169 | if cmdargs.colortablebands is not None: 170 | cmdargs.colortablebands = [int(x) for x in 171 | cmdargs.colortablebands.split(',')] 172 | if cmdargs.statspec is None or 'mean' not in cmdargs.statspec: 173 | print('Using --colortablebands requires "--statspec mean"') 174 | sys.exit() 175 | for i in cmdargs.colortablebands: 176 | if i not in cmdargs.statsbands: 177 | print("Bands given in --colortablebands must also be in --statsbands") 178 | sys.exit() 179 | 180 | return cmdargs 181 | 182 | 183 | def main(): 184 | cmdargs = getCmdargs() 185 | 186 | creationOptions = [] 187 | if cmdargs.format in GDAL_DRIVER_CREATION_OPTIONS: 188 | creationOptions = GDAL_DRIVER_CREATION_OPTIONS[cmdargs.format] 189 | 190 | fargateCfg = None 191 | if cmdargs.fargatecfg is not None: 192 | fargateCfg_kwArgs = json.load(open(cmdargs.fargatecfg)) 193 | fargateCfg = tiling.FargateConfig(**fargateCfg_kwArgs) 194 | concurrencyCfg = tiling.SegmentationConcurrencyConfig( 195 | concurrencyType=cmdargs.concurrencytype, 196 | numWorkers=cmdargs.numworkers, 197 | fargateCfg=fargateCfg, 198 | tileCompletionTimeout=cmdargs.tilecompletiontimeout) 199 | 200 | tiledSegResult = tiling.doTiledShepherdSegmentation(cmdargs.infile, cmdargs.outfile, 201 | tileSize=cmdargs.tilesize, overlapSize=cmdargs.overlapsize, 202 | minSegmentSize=cmdargs.minsegmentsize, numClusters=cmdargs.nclusters, 203 | bandNumbers=cmdargs.bands, subsamplePcnt=cmdargs.clustersubsamplepercent, 204 | maxSpectralDiff=cmdargs.maxspectraldiff, imgNullVal=cmdargs.nullvalue, 205 | fixedKMeansInit=cmdargs.fixedkmeansinit, 206 | fourConnected=not cmdargs.eightway, verbose=cmdargs.verbose, 207 | simpleTileRecode=cmdargs.simplerecode, outputDriver=cmdargs.format, 208 | creationOptions=creationOptions, concurrencyCfg=concurrencyCfg) 209 | # Print timings 210 | if cmdargs.verbose and tiledSegResult.timings is not None: 211 | summaryDict = tiledSegResult.timings.makeSummaryDict() 212 | print('\n' + utils.formatTimingRpt(summaryDict) + '\n') 213 | 214 | # Do a colour table on final output file. 
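# (A random colour table is written here only when --colortablebands was
# not given. Otherwise, a colour table derived from the per-segment band
# means is written later, once the statistics have been calculated.)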
215 | outDs = gdal.Open(cmdargs.outfile, gdal.GA_Update) 216 | band = outDs.GetRasterBand(1) 217 | 218 | if cmdargs.colortablebands is None: 219 | utils.writeRandomColourTable(band, tiledSegResult.maxSegId + 1) 220 | 221 | del outDs 222 | 223 | t0 = time.time() 224 | doPerSegmentStats(cmdargs) 225 | if cmdargs.verbose: 226 | print('Done per-segment statistics: {:.2f} seconds'.format(time.time() - t0)) 227 | 228 | if cmdargs.colortablebands is not None: 229 | colorTableNames = ['Band_{}_mean'.format(i) for i in cmdargs.colortablebands] 230 | utils.writeColorTableFromRatColumns(cmdargs.outfile, 231 | colorTableNames[0], colorTableNames[1], colorTableNames[2]) 232 | 233 | 234 | def doPerSegmentStats(cmdargs): 235 | """ 236 | If requested, calculate RAT columns of per-segment statistics 237 | """ 238 | for statsBand in cmdargs.statsbands: 239 | statsSelection = [] 240 | for statsSpec in cmdargs.statspec: 241 | if statsSpec.startswith('percentile,'): 242 | param = int(statsSpec.split(',')[1]) 243 | name = "Band_{}_pcnt{}".format(statsBand, param) 244 | selection = (name, 'percentile', param) 245 | else: 246 | name = "Band_{}_{}".format(statsBand, statsSpec) 247 | selection = (name, statsSpec) 248 | statsSelection.append(selection) 249 | 250 | rtn = tilingstats.calcPerSegmentStatsTiled(cmdargs.infile, statsBand, 251 | cmdargs.outfile, statsSelection) 252 | 253 | if cmdargs.verbose: 254 | timingsSummary = rtn.timings.makeSummaryDict() 255 | print(utils.formatTimingRpt(timingsSummary) + '\n') 256 | 257 | 258 | if __name__ == "__main__": 259 | main() 260 | -------------------------------------------------------------------------------- /pyshepseg/cmdline/variograms.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Test harness for tilingstats.calcPerSegmentSpatialStatsTiled(). 5 | Calculates the given number of variograms and saves them to the 6 | segmented file's RAT. 7 | """ 8 | 9 | # Copyright 2021 Neil Flood and Sam Gillingham. All rights reserved. 10 | # 11 | # Permission is hereby granted, free of charge, to any person 12 | # obtaining a copy of this software and associated documentation 13 | # files (the "Software"), to deal in the Software without restriction, 14 | # including without limitation the rights to use, copy, modify, 15 | # merge, publish, distribute, sublicense, and/or sell copies of the 16 | # Software, and to permit persons to whom the Software is furnished 17 | # to do so, subject to the following conditions: 18 | # 19 | # The above copyright notice and this permission notice shall be 20 | # included in all copies or substantial portions of the Software. 21 | # 22 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 23 | # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 24 | # OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 25 | # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR 26 | # ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF 27 | # CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 28 | # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 29 | 30 | import argparse 31 | 32 | from osgeo import gdal 33 | 34 | from pyshepseg import tilingstats 35 | 36 | 37 | def getCmdargs(): 38 | """ 39 | Get the command line arguments. 
40 | """ 41 | p = argparse.ArgumentParser() 42 | p.add_argument("-i", "--infile", required=True, 43 | help="Input file to collect stats from") 44 | p.add_argument("-s", "--segfile", required=True, 45 | help="File from segmentation. Note: stats is written into the RAT in this file") 46 | p.add_argument("-n", "--numvariograms", required=True, 47 | choices=[x for x in range(1, 10)], type=int, 48 | help="Number of variograms to calculate") 49 | cmdargs = p.parse_args() 50 | return cmdargs 51 | 52 | 53 | def main(): 54 | cmdargs = getCmdargs() 55 | cols = [] 56 | for n in range(cmdargs.numvariograms): 57 | cols.append(("variogram{}".format(n + 1), gdal.GFT_Real)) 58 | 59 | tilingstats.calcPerSegmentSpatialStatsTiled(cmdargs.infile, 1, 60 | cmdargs.segfile, cols, tilingstats.userFuncVariogram, 61 | cmdargs.numvariograms) 62 | 63 | 64 | if __name__ == '__main__': 65 | main() 66 | -------------------------------------------------------------------------------- /pyshepseg/guardeddecorators.py: -------------------------------------------------------------------------------- 1 | """ 2 | A dreadful hack to get around the fact that the numpydoc Sphinx extension 3 | does not play well with numba's jitclass decorator. 4 | 5 | If we are running with Sphinx, then fake the jitclass decorator so that it 6 | just returns the class it is decorating. 7 | 8 | """ 9 | import sys 10 | 11 | if 'sphinx' not in sys.modules: 12 | from numba.experimental import jitclass 13 | else: 14 | def jitclass(cls_or_spec=None, spec=None): 15 | """ 16 | Our fake jitclass decorator. Hacked from the real one 17 | in numba. 18 | 19 | Returns 20 | ------- 21 | If used as a decorator, returns a callable that takes a class 22 | object and returns the same class. In short, this decorator does 23 | nothing at all. 24 | 25 | """ 26 | 27 | if (cls_or_spec is not None and 28 | spec is None and 29 | not isinstance(cls_or_spec, type)): 30 | # Used like 31 | # @jitclass([("x", intp)]) 32 | # class Foo: 33 | # ... 34 | spec = cls_or_spec 35 | cls_or_spec = None 36 | 37 | def wrap(cls): 38 | return cls 39 | 40 | if cls_or_spec is None: 41 | return wrap 42 | else: 43 | return wrap(cls_or_spec) 44 | -------------------------------------------------------------------------------- /pyshepseg/shepseg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Python implementation of the image segmentation algorithm described 3 | by Shepherd et al [1]_ 4 | 5 | Implemented using scikit-learn's K-Means algorithm [2]_, and using 6 | numba [3]_ compiled code for the other main components. 7 | 8 | Main entry point is the doShepherdSegmentation() function. 9 | 10 | Examples 11 | -------- 12 | 13 | Read in a multi-band image as a single array, img, 14 | of shape (nBands, nRows, nCols). 15 | Ensure that any null pixels are all set to a known 16 | null value in all bands. Failure to correctly identify 17 | null pixels can result in a poorer quality segmentation. 18 | 19 | >>> from pyshepseg import shepseg 20 | >>> segRes = shepseg.doShepherdSegmentation(img, imgNullVal=nullVal) 21 | 22 | The segimg attribute of the segRes object is an array 23 | of segment ID numbers, of shape (nRows, nCols). 24 | 25 | Resulting segment ID numbers start from 1, and null pixels 26 | are set to zero. 27 | 28 | **Efficient segment location** 29 | 30 | After segmentation, the location of individual segments can be 31 | found efficiently using the object returned by makeSegmentLocations(). 
32 | 33 | >>> segSize = shepseg.makeSegSize(segimg) 34 | >>> segLoc = shepseg.makeSegmentLocations(segimg, segSize) 35 | 36 | This segLoc object is indexed by segment ID number (must be 37 | of type shepseg.SegIdType), and each element contains information 38 | about the pixels which are in that segment. This information 39 | can be returned as a slicing object suitable to index the image array 40 | 41 | >>> segNdx = segLoc[segId].getSegmentIndices() 42 | >>> vals = img[0][segNdx] 43 | 44 | This example would give an array of the pixel values from the first 45 | band of the original image, for the given segment ID. 46 | 47 | This can be a very efficient way to calculate per-segment 48 | quantities. It can be used in pure Python code, or it can be used 49 | inside numba jit functions, for even greater efficiency. 50 | 51 | References 52 | ---------- 53 | .. [1] Shepherd, J., Bunting, P. and Dymond, J. (2019). 54 | Operational Large-Scale Segmentation of Imagery Based on 55 | Iterative Elimination. Remote Sensing 11(6). 56 | https://www.mdpi.com/2072-4292/11/6/658 57 | .. [2] https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html 58 | .. [3] https://numba.pydata.org/ 59 | 60 | 61 | """ 62 | # Copyright 2021 Neil Flood and Sam Gillingham. All rights reserved. 63 | # 64 | # Permission is hereby granted, free of charge, to any person 65 | # obtaining a copy of this software and associated documentation 66 | # files (the "Software"), to deal in the Software without restriction, 67 | # including without limitation the rights to use, copy, modify, 68 | # merge, publish, distribute, sublicense, and/or sell copies of the 69 | # Software, and to permit persons to whom the Software is furnished 70 | # to do so, subject to the following conditions: 71 | # 72 | # The above copyright notice and this permission notice shall be 73 | # included in all copies or substantial portions of the Software. 74 | # 75 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 76 | # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 77 | # OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 78 | # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR 79 | # ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF 80 | # CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 81 | # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
82 | 83 | # Just in case anyone is trying to use this with Python-2 84 | from __future__ import print_function, division 85 | 86 | import sys 87 | import time 88 | 89 | import numpy 90 | from sklearn.cluster import KMeans 91 | from numba import njit 92 | from .guardeddecorators import jitclass 93 | from numba.core import types 94 | from numba.typed import Dict 95 | 96 | # A symbol for the data type used as a segment ID number 97 | SegIdType = numpy.uint32 98 | 99 | # This value is used for null in both cluster ID and segment ID images 100 | SEGNULLVAL = 0 101 | MINSEGID = SEGNULLVAL + 1 102 | 103 | 104 | class SegmentationResult(object): 105 | """ 106 | Results of the segmentation process 107 | 108 | Attributes 109 | ---------- 110 | segimg : numpy array (nRows, nCols) 111 | Elements are segment ID numbers (starting from 1) 112 | kmeans : sklearn.cluster.KMeans 113 | Fitted KMeans object 114 | maxSpectralDiff : float 115 | The value used to limit segment merging 116 | singlePixelsEliminated : int 117 | Number of single pixels merged to adjacent segments 118 | smallSegmentsEliminated : int 119 | Number of small segments merged into adjacent segments 120 | 121 | """ 122 | def __init__(self): 123 | self.segimg = None 124 | self.kmeans = None 125 | self.maxSpectralDiff = None 126 | self.singlePixelsEliminated = None 127 | self.smallSegmentsEliminated = None 128 | 129 | 130 | def doShepherdSegmentation(img, numClusters=60, clusterSubsamplePcnt=1, 131 | minSegmentSize=50, maxSpectralDiff='auto', imgNullVal=None, 132 | fourConnected=True, verbose=False, fixedKMeansInit=False, 133 | kmeansObj=None, spectDistPcntile=50): 134 | """ 135 | Perform Shepherd segmentation in memory, on the given 136 | multi-band img array. 137 | 138 | Parameters 139 | ---------- 140 | img : integer ndarray of shape (nBands, nRows, nCols) 141 | numClusters : int 142 | Number of clusters to create with k-means clustering 143 | clusterSubsamplePcnt : int 144 | Passed to fitSpectralClusters(). See there for details 145 | minSegmentSize : int 146 | The minimum segment size (in pixels) which will be left 147 | after eliminating small segments (except for segments which 148 | cannot be eliminated). 149 | maxSpectralDiff : str or float 150 | Sets a limit on how different segments can be and still be merged. 151 | It is given in the units of the spectral space of img. If 152 | maxSpectralDiff is 'auto', a default value will be calculated 153 | from the spectral distances between cluster centres, as a 154 | percentile of the distribution of these (spectDistPcntile). 155 | The value of spectDistPcntile should be lowered when segmenting 156 | an image with a larger range of spectral distances. 157 | spectDistPcntile : int 158 | See maxSpectralDiff 159 | fourConnected : bool 160 | If True, use 4-way connectedness when clumping, otherwise use 161 | 8-way 162 | imgNullVal : int or None 163 | If not None, then pixels with this value in any band are set to zero 164 | (SEGNULLVAL) in the output segmentation. If there are null values 165 | in the image array, it is important to give this null value, as it can 166 | strongly affect the initial spectral clustering, which in turn 167 | strongly affects the final segmentation. 168 | fixedKMeansInit : bool 169 | If fixedKMeansInit is True, then choose a fixed set of 170 | cluster centres to initialize the KMeans algorithm. This can 171 | be useful to provide strict determinacy of the results by 172 | avoiding sklearn's multiple random initial guesses. The default
The default 173 | is to allow sklearn to guess, which is good for avoiding 174 | local minima. 175 | kmeansObj : sklearn.cluster.KMeans object 176 | By default, the spectral clustering step will be fitted using 177 | the given img. However, if kmeansObj is not None, it is taken 178 | to be a fitted instance of sklearn.cluster.KMeans, and will 179 | be used instead. This is useful when enforcing a consistent 180 | clustering across multiple tiles (see the pyshepseg.tiling 181 | module for details). 182 | 183 | Returns 184 | ------- 185 | segResult : SegmentationResult object 186 | 187 | Notes 188 | ----- 189 | Default values are mostly as suggested by Shepherd et al. 190 | 191 | Segment ID numbers start from 1. Zero is a NULL segment ID. 192 | 193 | The return value is an instance of SegmentationResult class. 194 | 195 | See Also 196 | -------- 197 | pyshepseg.tiling, fitSpectralClusters 198 | 199 | """ 200 | t0 = time.time() 201 | if kmeansObj is not None: 202 | km = kmeansObj 203 | else: 204 | km = fitSpectralClusters(img, numClusters, 205 | clusterSubsamplePcnt, imgNullVal, fixedKMeansInit) 206 | clusters = applySpectralClusters(km, img, imgNullVal) 207 | if verbose: 208 | print("Kmeans, in", round(time.time() - t0, 1), "seconds") 209 | 210 | # Do clump 211 | t0 = time.time() 212 | (seg, maxSegId) = clump(clusters, SEGNULLVAL, fourConnected=fourConnected, 213 | clumpId=MINSEGID) 214 | maxSegId = SegIdType(maxSegId - 1) 215 | if verbose: 216 | print("Found", maxSegId, "clumps, in", round(time.time() - t0, 1), "seconds") 217 | 218 | # Make segment size array 219 | segSize = makeSegSize(seg) 220 | 221 | # Eliminate small segments. Start with James' 222 | # memory-efficient method for single pixel clumps. 223 | t0 = time.time() 224 | oldMaxSegId = maxSegId 225 | eliminateSinglePixels(img, seg, segSize, MINSEGID, maxSegId, fourConnected) 226 | maxSegId = seg.max() 227 | numElimSinglepix = oldMaxSegId - maxSegId 228 | if verbose: 229 | print("Eliminated", numElimSinglepix, "single pixels, in", 230 | round(time.time() - t0, 1), "seconds") 231 | 232 | maxSpectralDiff = autoMaxSpectralDiff(km, maxSpectralDiff, spectDistPcntile) 233 | 234 | t0 = time.time() 235 | numElimSmall = eliminateSmallSegments(seg, img, maxSegId, minSegmentSize, maxSpectralDiff, 236 | fourConnected, MINSEGID) 237 | if verbose: 238 | print("Eliminated", numElimSmall, "segments, in", round(time.time() - t0, 1), "seconds") 239 | 240 | if verbose: 241 | print("Final result has", seg.max(), "segments") 242 | 243 | segResult = SegmentationResult() 244 | segResult.segimg = seg 245 | segResult.kmeans = km 246 | segResult.maxSpectralDiff = maxSpectralDiff 247 | segResult.singlePixelsEliminated = numElimSinglepix 248 | segResult.smallSegmentsEliminated = numElimSmall 249 | return segResult 250 | 251 | 252 | def fitSpectralClusters(img, numClusters, subsamplePcnt, imgNullVal, 253 | fixedKMeansInit): 254 | """ 255 | First step of Shepherd segmentation. Use K-means clustering 256 | to create a set of "seed" segments, labelled only with 257 | their spectral cluster number. 258 | 259 | Parameters 260 | ---------- 261 | img : int ndarray (nBands, nRows, nCols). 262 | numClusters : int 263 | The number of clusters for the KMeans algorithm to find 264 | (i.e. it is 'k') 265 | subsamplePcnt : int 266 | The percentage of the pixels to actually use for KMeans clustering. 267 | Shepherd et al find that only a very small percentage is required. 
268 | imgNullVal : int or None 269 | If imgNullVal is not None, then pixels in img with this value in 270 | any band are set to segNullVal in the output. 271 | fixedKMeansInit : bool 272 | If True, then use a simple algorithm to determine the fixed 273 | set of initial cluster centres. Otherwise allow the sklearn 274 | routine to choose its own initial guesses. 275 | 276 | Returns 277 | ------- 278 | kmeansObj : sklearn.cluster.KMeans 279 | A fitted object of class sklearn.cluster.KMeans. This 280 | is suitable to use with the applySpectralClusters() function. 281 | 282 | """ 283 | (nBands, nRows, nCols) = img.shape 284 | 285 | # Re-organise the image data so it matches what sklearn 286 | # expects. 287 | xFull = numpy.transpose(img, axes=(1, 2, 0)) 288 | xFull = xFull.reshape((nRows * nCols, nBands)) 289 | 290 | if imgNullVal is not None: 291 | # Only use non-null values for fitting 292 | nonNull = (xFull != imgNullVal).all(axis=1) 293 | xNonNull = xFull[nonNull] 294 | del nonNull 295 | else: 296 | xNonNull = xFull 297 | skip = int(round(100. / subsamplePcnt)) 298 | xSample = xNonNull[::skip] 299 | del xFull, xNonNull 300 | 301 | # Note that we limit the number of trials that KMeans does, using 302 | # the n_init argument. Multiple trials are used to avoid getting 303 | # stuck in local minima, but 5 seems plenty, and this is the 304 | # slowest part, so let's not get carried away. 305 | numKmeansTrials = 5 306 | 307 | init = 'k-means++' # This is sklearn's default 308 | if fixedKMeansInit: 309 | init = diagonalClusterCentres(xSample, numClusters) 310 | numKmeansTrials = 1 311 | km = KMeans(n_clusters=numClusters, n_init=numKmeansTrials, init=init) 312 | km.fit(xSample) 313 | 314 | return km 315 | 316 | 317 | def applySpectralClusters(kmeansObj, img, imgNullVal): 318 | """ 319 | Use the given KMeans object to predict spectral clusters on 320 | a whole image array. 321 | 322 | Parameters 323 | ---------- 324 | kmeansObj : sklearn.cluster.KMeans 325 | A fitted instance, as returned by fitSpectralClusters(). 326 | img : int ndarray (nBands, nRows, nCols) 327 | The image to predict on 328 | imgNullVal : int 329 | Any pixels in img which have value imgNullVal will be set to 330 | SEGNULLVAL (i.e. zero) in the output cluster image. 331 | 332 | Returns 333 | ------- 334 | segimg : int ndarray (nRows, nCols) 335 | The initial segment image, each element being the segment 336 | ID value for that pixel 337 | 338 | """ 339 | 340 | # Predict on the whole image. In principle we could omit the nulls, 341 | # but it makes little difference to run time, and just adds complexity. 342 | 343 | (nBands, nRows, nCols) = img.shape 344 | 345 | # Re-organise the image data so it matches what sklearn 346 | # expects. 347 | xFull = numpy.transpose(img, axes=(1, 2, 0)) 348 | xFull = xFull.reshape((nRows * nCols, nBands)) 349 | 350 | clustersFull = kmeansObj.predict(xFull) 351 | del xFull 352 | clustersImg = clustersFull.reshape((nRows, nCols)) 353 | 354 | # Make the cluster ID numbers start from 1, and use SEGNULLVAL 355 | # (i.e. zero) in null pixels 356 | clustersImg += 1 357 | if imgNullVal is not None: 358 | nullmask = (img == imgNullVal).any(axis=0) 359 | clustersImg[nullmask] = SEGNULLVAL 360 | 361 | return clustersImg 362 | 363 | 364 | def diagonalClusterCentres(xSample, numClusters): 365 | """ 366 | Calculate an array of initial guesses at cluster centres. 367 | This will be given to the KMeans constructor as the init 368 | parameter. 
369 | 370 | The centres are evenly spaced along the diagonal of 371 | the bounding box of the data. The end points are placed 372 | 1 step in from the corners. For example, with band values spanning 0-100 and numClusters=4, the step is 100 / 5 = 20, giving centres at 20, 40, 60 and 80 along the diagonal. 373 | 374 | Parameters 375 | ---------- 376 | xSample : int ndarray (numPoints, numBands) 377 | A sample of data to be used for fitting 378 | numClusters : int 379 | Number of cluster centres to be calculated 380 | 381 | Returns 382 | ------- 383 | centres : int ndarray (numClusters, numBands) 384 | Initial cluster centres in spectral space 385 | 386 | """ 387 | (numPoints, numBands) = xSample.shape 388 | bandMin = xSample.min(axis=0) 389 | bandMax = xSample.max(axis=0) 390 | 391 | centres = numpy.empty((numClusters, numBands), dtype=xSample.dtype) 392 | 393 | step = (bandMax - bandMin) / (numClusters + 1) 394 | for i in range(numClusters): 395 | centres[i] = bandMin + (i + 1) * step 396 | 397 | return centres 398 | 399 | 400 | def autoMaxSpectralDiff(km, maxSpectralDiff, distPcntile): 401 | """ 402 | Work out what to use as the maxSpectralDiff. 403 | 404 | If the current value is 'auto', then return a percentile (distPcntile) 405 | of the spectral distances between cluster centres from the KMeans 406 | clustering object km (e.g. the median, when distPcntile is 50). 407 | 408 | If the current value is None, return 10 times the largest distance 409 | between cluster centres (i.e. too large to ever make a difference) 410 | 411 | Otherwise, return the given current value. 412 | 413 | Parameters 414 | ---------- 415 | km : sklearn.cluster.KMeans 416 | KMeans clustering object 417 | maxSpectralDiff : str or float 418 | It is given in the units of the spectral space of img. If 419 | maxSpectralDiff is 'auto', a default value will be calculated 420 | from the spectral distances between cluster centres, as a 421 | percentile of the distribution of these (distPcntile). 422 | The value of distPcntile should be lowered when segmenting 423 | an image with a larger range of spectral distances. 424 | distPcntile : int 425 | See maxSpectralDiff 426 | 427 | Returns 428 | ------- 429 | maxSpectralDiff : float 430 | The value to use as maxSpectralDiff. 431 | 432 | """ 433 | # Calculate distances between pairs of cluster centres 434 | centres = km.cluster_centers_ 435 | numClusters = centres.shape[0] 436 | numPairs = numClusters * (numClusters - 1) // 2 437 | clusterDist = numpy.full(numPairs, -1, dtype=numpy.float32) 438 | k = 0 439 | for i in range(numClusters - 1): 440 | for j in range(i + 1, numClusters): 441 | clusterDist[k] = numpy.sqrt(((centres[i] - centres[j])**2).sum()) 442 | k += 1 443 | 444 | if maxSpectralDiff == 'auto': 445 | maxSpectralDiff = numpy.percentile(clusterDist, distPcntile) 446 | elif maxSpectralDiff is None: 447 | maxSpectralDiff = 10 * clusterDist.max() 448 | 449 | return maxSpectralDiff 450 | 451 | 452 | @njit 453 | def clump(img, ignoreVal, fourConnected=True, clumpId=1): 454 | """ 455 | Implementation of clumping using Numba. 456 | 457 | Parameters 458 | ---------- 459 | img : int ndarray (nRows, nCols) 460 | Image array containing the data to be clumped. 461 | ignoreVal : int 462 | The "no data" value for the input 463 | fourConnected : bool 464 | If True, use 4-way connected, otherwise 8-way 465 | clumpId : int 466 | The start clump id to use 467 | 468 | Returns 469 | ------- 470 | clumpimg : SegIdType ndarray (nRows, nCols) 471 | Image array containing the clump IDs for each pixel 472 | clumpId : int 473 | The highest clumpid used + 1 474 | 475 | """ 476 | 477 | # Prevent really large clumps, as they create a 478 | # serious performance hit later. 
In initial testing without 479 | # this limit, final segmentation had >99.9% of segments 480 | # smaller than this, so this seems like a good size to stop. 481 | MAX_CLUMP_SIZE = 10000 482 | 483 | ysize, xsize = img.shape 484 | output = numpy.zeros((ysize, xsize), dtype=SegIdType) 485 | search_list = numpy.empty((xsize * ysize, 2), dtype=numpy.uint32) 486 | 487 | searchIdx = 0 488 | 489 | # run through the image 490 | for y in range(ysize): 491 | for x in range(xsize): 492 | # check if we have visited this one before 493 | if img[y, x] != ignoreVal and output[y, x] == 0: 494 | val = img[y, x] 495 | clumpSize = 0 496 | searchIdx = 0 497 | search_list[searchIdx, 0] = y 498 | search_list[searchIdx, 1] = x 499 | searchIdx += 1 500 | output[y, x] = clumpId # marked as visited 501 | 502 | while searchIdx > 0 and (clumpSize < MAX_CLUMP_SIZE): 503 | # search the last one 504 | searchIdx -= 1 505 | sy = search_list[searchIdx, 0] 506 | sx = search_list[searchIdx, 1] 507 | 508 | # work out the 3x3 window to visit 509 | tlx = sx - 1 510 | if tlx < 0: 511 | tlx = 0 512 | tly = sy - 1 513 | if tly < 0: 514 | tly = 0 515 | brx = sx + 1 516 | if brx > xsize - 1: 517 | brx = xsize - 1 518 | bry = sy + 1 519 | if bry > ysize - 1: 520 | bry = ysize - 1 521 | 522 | # search the neighbouring cells (4- or 8-connected) 523 | for cx in range(tlx, brx + 1): 524 | for cy in range(tly, bry + 1): 525 | connected = not fourConnected or (cy == sy or cx == sx) 526 | # don't have to check that we are the middle 527 | # cell, since output will be != 0 there - 528 | # we set it before adding a cell to search_list 529 | if connected and (img[cy, cx] != ignoreVal and 530 | output[cy, cx] == 0 and 531 | img[cy, cx] == val): 532 | output[cy, cx] = clumpId # mark as visited 533 | clumpSize += 1 534 | # add this one to the ones to search the neighbours 535 | search_list[searchIdx, 0] = cy 536 | search_list[searchIdx, 1] = cx 537 | searchIdx += 1 538 | 539 | clumpId += 1 540 | 541 | return (output, clumpId) 542 | 543 | 544 | @njit 545 | def makeSegSize(seg): 546 | """ 547 | Return an array of segment sizes, essentially a histogram for 548 | the segment ID values. 549 | 550 | Parameters 551 | ---------- 552 | seg : SegIdType ndarray (nRows, nCols) 553 | Image array of segment ID values 554 | 555 | Returns 556 | ------- 557 | segSize : int ndarray (numSegments+1, ) 558 | Array is indexed by segment ID. Each element is the 559 | number of pixels in that segment. 560 | 561 | """ 562 | maxSegId = seg.max() 563 | segSize = numpy.zeros(maxSegId + 1, dtype=numpy.uint32) 564 | (nRows, nCols) = seg.shape 565 | for i in range(nRows): 566 | for j in range(nCols): 567 | segSize[seg[i, j]] += 1 568 | 569 | return segSize 570 | 571 | 572 | def eliminateSinglePixels(img, seg, segSize, minSegId, maxSegId, fourConnected): 573 | """ 574 | Approximate elimination of single pixels, as suggested 575 | by Shepherd et al (section 2.3, page 6). This step is suggested as 576 | an efficient way of removing a large number of segments which 577 | are single pixels, by approximating the spectrally-nearest 578 | neighbouring segment with the spectrally-nearest neighbouring 579 | pixel. 
580 | 581 | Parameters 582 | ---------- 583 | img : int ndarray (nBands, nRows, nCols) 584 | The original spectral image 585 | seg : SegIdType ndarray (nRows, nCols) 586 | The image of segment IDs 587 | segSize : int array (numSeg+1, ) 588 | Array of pixel counts for every segment 589 | minSegId : SegIdType 590 | Smallest segment ID 591 | maxSegId : SegIdType 592 | Largest segment ID 593 | fourConnected : bool 594 | If True use 4-way connectedness, otherwise 8-way 595 | 596 | Notes 597 | ----- 598 | 599 | Segment ID numbers start at 1 (i.e. 0 is not valid) 600 | 601 | Modifies seg array in place. 602 | 603 | """ 604 | # Array to store info on pixels to be eliminated. 605 | # Store (row, col, newSegId). 606 | segToElim = numpy.zeros((3, maxSegId), dtype=seg.dtype) 607 | 608 | totalNumElim = 0 609 | numElim = mergeSinglePixels(img, seg, segSize, segToElim, fourConnected) 610 | while numElim > 0: 611 | totalNumElim += numElim 612 | numElim = mergeSinglePixels(img, seg, segSize, segToElim, fourConnected) 613 | 614 | # Now do a relabel..... 615 | relabelSegments(seg, segSize, minSegId) 616 | 617 | 618 | @njit 619 | def mergeSinglePixels(img, seg, segSize, segToElim, fourConnected): 620 | """ 621 | Search for single-pixel segments, and decide which neighbouring 622 | segment they should be merged with. Finds all to eliminate, 623 | then performs merge on all selected. Modifies seg and 624 | segSize arrays in place, and returns the number of segments 625 | eliminated. 626 | 627 | Parameters 628 | ---------- 629 | img : int ndarray (nBands, nRows, nCols) 630 | the original spectral image 631 | seg : int ndarray (nRows, nCols) 632 | the image of segments 633 | segSize : int array (numSeg+1, ) 634 | Array of pixel counts for every segment 635 | segToElim : int ndarray (3, maxSegId) 636 | Temporary storage for segments to be eliminated 637 | fourConnected : bool 638 | If True use 4-way connectedness, otherwise 8-way 639 | 640 | Returns 641 | ------- 642 | numEliminated : int 643 | Number of segments eliminated 644 | 645 | """ 646 | (nRows, nCols) = seg.shape 647 | numEliminated = 0 648 | 649 | for i in range(nRows): 650 | for j in range(nCols): 651 | segid = seg[i, j] 652 | if segSize[segid] == 1: 653 | (ii, jj) = findNearestNeighbourPixel(img, seg, i, j, 654 | segSize, fourConnected) 655 | # Record the new segment ID for the current pixel 656 | if (ii >= 0 and jj >= 0): 657 | segToElim[0, numEliminated] = i 658 | segToElim[1, numEliminated] = j 659 | segToElim[2, numEliminated] = seg[ii, jj] 660 | numEliminated += 1 661 | 662 | # Now do eliminations, updating the seg array and the 663 | # segSize array in place. 664 | for k in range(numEliminated): 665 | r = segToElim[0, k] 666 | c = segToElim[1, k] 667 | newSeg = segToElim[2, k] 668 | oldSeg = seg[r, c] 669 | 670 | seg[r, c] = newSeg 671 | segSize[oldSeg] = 0 672 | segSize[newSeg] += 1 673 | 674 | return numEliminated 675 | 676 | 677 | @njit 678 | def findNearestNeighbourPixel(img, seg, i, j, segSize, fourConnected): 679 | """ 680 | For the (i, j) pixel, choose which of the neighbouring 681 | pixels is the most similar, spectrally. 682 | 683 | Returns the row and column of the most 684 | spectrally similar neighbour, which is also in a 685 | clump of size > 1. 
If none is found, return (-1, -1) 686 | 687 | Parameters 688 | ---------- 689 | img : int ndarray (nBands, nRows, nCols) 690 | Input multi-band image 691 | seg : SegIdType ndarray (nRows, nCols) 692 | Partially completed segmentation image (values are segment 693 | ID numbers) 694 | i : int 695 | Row number of target pixel 696 | j : int 697 | Column number of target pixel 698 | segSize : int ndarray (numSegments+1, ) 699 | Pixel counts, indexed by segment ID number (i.e. a histogram of 700 | the seg array) 701 | fourConnected : bool 702 | If True, use four-way connectedness to judge neighbours, otherwise 703 | use eight-way. 704 | 705 | Returns 706 | ------- 707 | ii : int 708 | Row number of the selected neighbouring pixel (-1 if not found) 709 | jj : int 710 | Column number of the selected neighbouring pixel (-1 if not found) 711 | 712 | """ 713 | (nBands, nRows, nCols) = img.shape 714 | 715 | minDsqr = -1 716 | ii = jj = -1 717 | # Cope with image edges 718 | (iiiStrt, iiiEnd) = (max(i - 1, 0), min(i + 1, nRows - 1)) 719 | (jjjStrt, jjjEnd) = (max(j - 1, 0), min(j + 1, nCols - 1)) 720 | 721 | for iii in range(iiiStrt, iiiEnd + 1): 722 | for jjj in range(jjjStrt, jjjEnd + 1): 723 | connected = ((not fourConnected) or ((iii == i) or (jjj == j))) 724 | if connected: 725 | segNbr = seg[iii, jjj] 726 | if segSize[segNbr] > 1: 727 | # Euclidean distance in spectral space. Note that because 728 | # we are only interested in the order, we don't actually 729 | # need to do the sqrt (which is expensive) 730 | dSqr = ((img[:, i, j] - img[:, iii, jjj]) ** 2).sum() 731 | if minDsqr < 0 or dSqr < minDsqr: 732 | minDsqr = dSqr 733 | ii = iii 734 | jj = jjj 735 | 736 | return (ii, jj) 737 | 738 | 739 | @njit 740 | def relabelSegments(seg, segSize, minSegId): 741 | """ 742 | The given seg array is an image of segment labels, with some 743 | numbers unused, due to elimination of small segments. Go through 744 | and find the unused numbers, and re-label segments above 745 | these so that segment labels are contiguous. 746 | 747 | Modifies the seg array in place. The segSize array is not 748 | updated, and should be recomputed. 749 | 750 | Parameters 751 | ---------- 752 | seg : SegIdType ndarray (nRows, nCols) 753 | Segmentation image. Updated in place with new segment ID values 754 | segSize : int array (numSeg+1, ) 755 | Array of pixel counts for every segment 756 | minSegId : int 757 | Smallest valid segment ID number 758 | 759 | """ 760 | oldNumSeg = len(segSize) 761 | subtract = numpy.zeros(oldNumSeg, dtype=SegIdType) 762 | 763 | # For each segid with a count of zero (i.e. it is unused), we 764 | # increase the amount by which segid numbers above this should 765 | # be decremented 766 | for k in range(minSegId + 1, oldNumSeg): 767 | subtract[k] = subtract[k - 1] 768 | if segSize[k - 1] == 0: 769 | subtract[k] += 1 770 | 771 | # Now decrement the segid of every pixel 772 | (nRows, nCols) = seg.shape 773 | for i in range(nRows): 774 | for j in range(nCols): 775 | oldSegId = seg[i, j] 776 | newSegId = oldSegId - subtract[oldSegId] 777 | seg[i, j] = newSegId 778 | 779 | 780 | @njit 781 | def buildSegmentSpectra(seg, img, maxSegId): 782 | """ 783 | Build an array of the spectral statistics for each segment. 
784 | Return an array of the sums of the spectral values for each 785 | segment, for each band 786 | 787 | Parameters 788 | ---------- 789 | seg : SegIdType ndarray (nRows, nCols) 790 | Segmentation image 791 | img : Integer ndarray (nBands, nRows, nCols) 792 | Input multi-band image 793 | maxSegId : int 794 | Largest segment ID number in seg 795 | 796 | Returns 797 | ------- 798 | spectSum : float32 ndarray (numSegments+1, nBands) 799 | Sums of all pixel values. Element [i, j] is the sum of all 800 | values in img for the j-th band, which have segment ID i. The 801 | row for i==0 is unused, as zero is not a valid segment ID. 802 | 803 | """ 804 | (nBands, nRows, nCols) = img.shape 805 | spectSum = numpy.zeros((maxSegId + 1, nBands), dtype=numpy.float32) 806 | 807 | for i in range(nRows): 808 | for j in range(nCols): 809 | segid = seg[i, j] 810 | for k in range(nBands): 811 | spectSum[segid, k] += img[k, i, j] 812 | 813 | return spectSum 814 | 815 | 816 | spec = [('idx', types.uint32), ('rowcols', types.uint32[:, :])] 817 | 818 | 819 | @jitclass(spec) 820 | class RowColArray(object): 821 | """ 822 | This data structure is used to store the locations of 823 | every pixel in a given segment. It will be used for entries 824 | in a jit dictionary. This means we can quickly find all the 825 | pixels belonging to a particular segment. 826 | 827 | Attributes 828 | ---------- 829 | idx : int 830 | Index of most recently added pixel 831 | rowcols : uint32 ndarray (length, 2) 832 | Row and col numbers of pixels in the segment 833 | 834 | """ 835 | def __init__(self, length): 836 | """ 837 | Initialize the data structure 838 | 839 | Parameters 840 | ---------- 841 | length : int 842 | Number of pixels in the segment 843 | 844 | """ 845 | self.idx = 0 846 | self.rowcols = numpy.empty((length, 2), dtype=numpy.uint32) 847 | 848 | def append(self, row, col): 849 | """ 850 | Add the coordinates of a new pixel in the segment 851 | 852 | Parameters 853 | ---------- 854 | row : int 855 | Row number of pixel 856 | col : int 857 | Column number of pixel 858 | 859 | """ 860 | self.rowcols[self.idx, 0] = row 861 | self.rowcols[self.idx, 1] = col 862 | self.idx += 1 863 | 864 | def getSegmentIndices(self): 865 | """ 866 | Return the row and column numbers of the segment pixels 867 | as a tuple, suitable for indexing the image array. 868 | This supports selection of all pixels for a given segment. 869 | """ 870 | return (self.rowcols[:, 0], self.rowcols[:, 1]) 871 | 872 | 873 | if 'sphinx' not in sys.modules: 874 | RowColArray_Type = RowColArray.class_type.instance_type 875 | else: 876 | # We are running in a Sphinx documentation build, so fake this 877 | RowColArray_Type = None 878 | 879 | 880 | @njit 881 | def makeSegmentLocations(seg, segSize): 882 | """ 883 | Create a data structure to hold the locations of all pixels 884 | in all segments. 
885 | 886 | Parameters 887 | ---------- 888 | seg : SegIdType ndarray (nRows, nCols) 889 | Segment ID image array 890 | segSize : int ndarray (numSegments+1, ) 891 | Counts of pixels in each segment, indexed by segment ID 892 | 893 | Returns 894 | ------- 895 | segLoc : numba.typed.Dict 896 | Indexed by segment ID number, each entry is a RowColArray 897 | object, giving the pixel coordinates of all pixels for that 898 | segment 899 | 900 | """ 901 | d = Dict.empty(key_type=types.uint32, value_type=RowColArray_Type) 902 | numSeg = len(segSize) 903 | for segid in range(MINSEGID, numSeg): 904 | numPix = segSize[segid] 905 | obj = RowColArray(numPix) 906 | d[SegIdType(segid)] = obj 907 | 908 | (nRows, nCols) = seg.shape 909 | for row in range(nRows): 910 | for col in range(nCols): 911 | segid = seg[row, col] 912 | if segid != SEGNULLVAL: 913 | d[segid].append(row, col) 914 | 915 | return d 916 | 917 | 918 | @njit 919 | def eliminateSmallSegments(seg, img, maxSegId, minSegSize, maxSpectralDiff, 920 | fourConnected, minSegId): 921 | """ 922 | Eliminate small segments. Start with smallest, and merge 923 | them into spectrally most similar neighbour. Repeat for 924 | larger segments. 925 | 926 | Modifies seg array in place. 927 | 928 | Parameters 929 | ---------- 930 | seg : SegIdType ndarray (nRows, nCols) 931 | Segment ID image array. Modified in place as segments are merged. 932 | img : Integer ndarray (nBands, nRows, nCols) 933 | Input multi-band image 934 | maxSegId : SegIdType 935 | Largest segment ID number in seg 936 | minSegSize : int 937 | Size (in pixels) of the smallest segment which will NOT 938 | be eliminated 939 | maxSpectralDiff : float 940 | Limit on how different segments can be and still be merged. 941 | It is given in the units of the spectral space of img. 942 | fourConnected : bool 943 | If True, use four-way connectedness to judge neighbours, otherwise 944 | use eight-way. 945 | minSegId : SegIdType 946 | Minimum valid segment ID number 947 | 948 | Returns 949 | ------- 950 | numEliminated : int 951 | Number of segments eliminated 952 | 953 | """ 954 | spectSum = buildSegmentSpectra(seg, img, maxSegId) 955 | segSize = makeSegSize(seg) 956 | segLoc = makeSegmentLocations(seg, segSize) 957 | 958 | # A list of the segment ID numbers to merge with. The i-th 959 | # element is the segment ID to merge segment 'i' into 960 | mergeSeg = numpy.empty((maxSegId + 1), dtype=SegIdType) 961 | mergeSeg.fill(SEGNULLVAL) 962 | 963 | # Range of seg id numbers, as SegIdType, suitable as indexes into segloc 964 | segIdRange = numpy.arange(minSegId, (maxSegId + 1), dtype=SegIdType) 965 | 966 | # Start with smallest segments, move through to just 967 | # smaller than minSegSize (i.e. minSegSize is smallest 968 | # which will NOT be eliminated) 969 | numElim = 0 970 | for targetSize in range(1, minSegSize): 971 | countTargetSize = numpy.count_nonzero(segSize == targetSize) 972 | prevCount = -1 973 | # Use multiple passes to eliminate segments of this size. A 974 | # single pass can leave segments unmerged, due to the rule about 975 | # only merging with neighbours larger than current. 976 | # Note the use of MAXPASSES, just in case, as we hate infinite loops. 977 | # A very small number can still be left unmerged, if surrounded by 978 | # null segments. 
979 | (numPasses, MAXPASSES) = (0, 10) 980 | while (countTargetSize != prevCount) and (numPasses < MAXPASSES): 981 | prevCount = countTargetSize 982 | 983 | for segId in segIdRange: 984 | if segSize[segId] == targetSize: 985 | mergeSeg[segId] = findMergeSegment(segId, segLoc, 986 | seg, segSize, spectSum, maxSpectralDiff, fourConnected) 987 | 988 | # Carry out the merges found above 989 | for segId in segIdRange: 990 | if mergeSeg[segId] != SEGNULLVAL: 991 | doMerge(segId, mergeSeg[segId], seg, segSize, segLoc, 992 | spectSum) 993 | mergeSeg[segId] = SEGNULLVAL 994 | numElim += 1 995 | 996 | countTargetSize = numpy.count_nonzero(segSize == targetSize) 997 | numPasses += 1 998 | 999 | relabelSegments(seg, segSize, minSegId) 1000 | return numElim 1001 | 1002 | 1003 | @njit 1004 | def findMergeSegment(segId, segLoc, seg, segSize, spectSum, maxSpectralDiff, 1005 | fourConnected): 1006 | """ 1007 | For the given segId, find which neighboring segment it 1008 | should be merged with. The chosen merge segment is the one 1009 | which is spectrally most similar to the given one, as 1010 | measured by minimum Euclidean distance in spectral space. 1011 | 1012 | Called by eliminateSmallSegments(). 1013 | 1014 | Parameters 1015 | ---------- 1016 | segId : SegIdType 1017 | Segment ID number of segment to merge 1018 | segLoc : numba.typed.Dict 1019 | Dictionary of per-segment pixel coordinates. As computed by 1020 | makeSegmentLocations() 1021 | seg : SegIdType ndarray (nRows, nCols) 1022 | Segment ID image array 1023 | segSize : int ndarray (numSegments+1, ) 1024 | Counts of pixels in each segment, indexed by segment ID 1025 | spectSum : float32 ndarray (numSegments+1, nBands) 1026 | Sums of all pixel values. As computed by buildSegmentSpectra() 1027 | maxSpectralDiff : float 1028 | Limit on how different segments can be and still be merged. 1029 | It is given in the units of the spectral space of img 1030 | fourConnected : bool 1031 | If True, use four-way connectedness to judge neighbours, otherwise 1032 | use eight-way 1033 | 1034 | """ 1035 | bestNbrSeg = SEGNULLVAL 1036 | bestDistSqr = 0.0 # This value is never used 1037 | 1038 | (nRows, nCols) = seg.shape 1039 | segRowcols = segLoc[segId].rowcols 1040 | numPix = len(segRowcols) 1041 | # Mean spectral bands 1042 | spect = spectSum[segId] / numPix 1043 | 1044 | for k in range(numPix): 1045 | (i, j) = segRowcols[k] 1046 | for ii in range(max(i - 1, 0), min(i + 2, nRows)): 1047 | for jj in range(max(j - 1, 0), min(j + 2, nCols)): 1048 | connected = (not fourConnected) or (ii == i or jj == j) 1049 | nbrSegId = seg[ii, jj] 1050 | if (connected and (nbrSegId != segId) and 1051 | (nbrSegId != SEGNULLVAL) and 1052 | (segSize[nbrSegId] > segSize[segId])): 1053 | nbrSpect = spectSum[nbrSegId] / segSize[nbrSegId] 1054 | 1055 | distSqr = ((spect - nbrSpect) ** 2).sum() 1056 | if ((bestNbrSeg == SEGNULLVAL) or (distSqr < bestDistSqr)): 1057 | bestDistSqr = distSqr 1058 | bestNbrSeg = nbrSegId 1059 | 1060 | if bestDistSqr > maxSpectralDiff**2: 1061 | bestNbrSeg = SEGNULLVAL 1062 | 1063 | return bestNbrSeg 1064 | 1065 | 1066 | @njit 1067 | def doMerge(segId, nbrSegId, seg, segSize, segLoc, spectSum): 1068 | """ 1069 | Carry out a single segment merge. The segId segment is merged to the 1070 | neighbouring nbrSegId. Modifies seg, segSize, segLoc and 1071 | spectSum in place. 1072 | 1073 | Parameters 1074 | ---------- 1075 | segId : SegIdType 1076 | Segment ID of the segment to be merged. 
Its pixels in seg are relabelled to nbrSegId 1077 | nbrSegId : SegIdType 1078 | Segment ID of the segment into which segId will be merged. 1079 | This segment receives the merged pixels and statistics 1080 | seg : SegIdType ndarray (nRows, nCols) 1081 | Segment ID image array. Modified in place 1082 | segSize : int ndarray (numSegments+1, ) 1083 | Counts of pixels in each segment, indexed by segment ID. Modified 1084 | in place with new counts for both segments 1085 | segLoc : numba.typed.Dict 1086 | Dictionary of per-segment pixel coordinates. As computed by 1087 | makeSegmentLocations() 1088 | spectSum : float32 ndarray (numSegments+1, nBands) 1089 | Sums of all pixel values. As computed by buildSegmentSpectra(). 1090 | Updated in place with new sums for both segments 1091 | 1092 | """ 1093 | segRowcols = segLoc[segId].rowcols 1094 | numPix = len(segRowcols) 1095 | nbrSegRowcols = segLoc[nbrSegId].rowcols 1096 | nbrNumPix = len(nbrSegRowcols) 1097 | mergedNumPix = numPix + nbrNumPix 1098 | 1099 | # New segLoc entry for merged segment 1100 | mergedSegLoc = RowColArray(mergedNumPix) 1101 | # Copy over the existing rowcols 1102 | for k in range(nbrNumPix): 1103 | (r, c) = nbrSegRowcols[k] 1104 | mergedSegLoc.append(r, c) 1105 | 1106 | # Append the segment being merged 1107 | for k in range(numPix): 1108 | (r, c) = segRowcols[k] 1109 | seg[r, c] = nbrSegId 1110 | mergedSegLoc.append(r, c) 1111 | 1112 | # Replace the previous segLoc entry, and delete the one we merged 1113 | segLoc[nbrSegId] = mergedSegLoc 1114 | segLoc.pop(segId) 1115 | 1116 | # Update the spectral sums for the two segments 1117 | numBands = spectSum.shape[1] 1118 | for m in range(numBands): 1119 | spectSum[nbrSegId, m] += spectSum[segId, m] 1120 | spectSum[segId, m] = 0 1121 | # Update the segment sizes 1122 | segSize[nbrSegId] += segSize[segId] 1123 | segSize[segId] = 0 1124 | -------------------------------------------------------------------------------- /pyshepseg/subset.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains functionality for subsetting a large 3 | segmented image into a smaller one, e.g. for checking results. See 4 | :func:`subsetImage`. 5 | 6 | """ 7 | # Copyright 2021 Neil Flood and Sam Gillingham. All rights reserved. 8 | # 9 | # Permission is hereby granted, free of charge, to any person 10 | # obtaining a copy of this software and associated documentation 11 | # files (the "Software"), to deal in the Software without restriction, 12 | # including without limitation the rights to use, copy, modify, 13 | # merge, publish, distribute, sublicense, and/or sell copies of the 14 | # Software, and to permit persons to whom the Software is furnished 15 | # to do so, subject to the following conditions: 16 | # 17 | # The above copyright notice and this permission notice shall be 18 | # included in all copies or substantial portions of the Software. 19 | # 20 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 21 | # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 22 | # OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 23 | # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR 24 | # ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF 25 | # CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 26 | # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 27 | 28 | import numpy 29 | from numba import njit 30 | from numba.typed import Dict 31 | from osgeo import gdal 32 | 33 | from . import shepseg 34 | from . import tiling 35 | from . 
import tilingstats 36 | 37 | gdal.UseExceptions() 38 | 39 | 40 | def subsetImage(inname, outname, tlx, tly, newXsize, newYsize, outformat, 41 | creationOptions=[], origSegIdColName=None, maskImage=None): 42 | """ 43 | Subset an image and "compress" the RAT so only values that 44 | are in the new image are in the RAT. Note that the image values 45 | will be recoded in the process. 46 | 47 | gdal_translate seems to have a problem with files that 48 | have large RATs, so until that gets fixed, we do the subsetting 49 | in this function. 50 | 51 | Parameters 52 | ---------- 53 | inname : str or gdal.Dataset 54 | Filename of input raster, or open gdal Dataset object 55 | outname : str 56 | Filename of output raster 57 | tlx, tly : int 58 | The x & y pixel coordinates (i.e. col & row, respectively) of the 59 | top left pixel of the image subset to extract. 60 | newXsize, newYsize : int 61 | The x & y size (in pixels) of the image subset to extract 62 | outformat : str 63 | The GDAL short name of the format driver to use for the output 64 | raster file 65 | creationOptions : list of str 66 | The GDAL creation options for the output file 67 | origSegIdColName : str or None 68 | The name of a RAT column. If not None, this will be created 69 | in the output file, and will contain the original segment ID 70 | numbers so the new segment IDs can be linked back to the old 71 | maskImage : str or None 72 | If not None, then the filename of a mask raster. Only pixels 73 | which are non-zero in this mask image will be included in the 74 | subset. This image is assumed to match the shape and position 75 | of the output subset 76 | 77 | """ 78 | if isinstance(inname, gdal.Dataset): 79 | inds = inname 80 | else: 81 | inds = gdal.Open(inname, gdal.GA_Update) 82 | 83 | inband = inds.GetRasterBand(1) 84 | 85 | if (tlx + newXsize) > inband.XSize or (tly + newYsize) > inband.YSize: 86 | msg = 'Requested subset is not within input image' 87 | raise PyShepSegSubsetError(msg) 88 | 89 | driver = gdal.GetDriverByName(outformat) 90 | outds = driver.Create(outname, newXsize, newYsize, 1, inband.DataType, 91 | options=creationOptions) 92 | # set the output projection and transform 93 | outds.SetProjection(inds.GetProjection()) 94 | transform = list(inds.GetGeoTransform()) 95 | transform[0] = transform[0] + transform[1] * tlx 96 | transform[3] = transform[3] + transform[5] * tly 97 | outds.SetGeoTransform(transform) 98 | 99 | outband = outds.GetRasterBand(1) 100 | outband.SetMetadataItem('LAYER_TYPE', 'thematic') 101 | outRAT = outband.GetDefaultRAT() 102 | 103 | inRAT = inband.GetDefaultRAT() 104 | recodeDict = Dict.empty(key_type=tiling.segIdNumbaType, 105 | value_type=tiling.segIdNumbaType) # keyed on original ID - value is new row ID 106 | histogramDict = Dict.empty(key_type=tiling.segIdNumbaType, 107 | value_type=tiling.segIdNumbaType) # keyed on new ID - value is count 108 | 109 | # make sure the output file has the same columns as the input 110 | numIntCols, numFloatCols = copyColumns(inRAT, outRAT) 111 | 112 | # If a maskImage was specified then open it 113 | maskds = None 114 | maskBand = None 115 | maskData = None 116 | if maskImage is not None: 117 | maskds = gdal.Open(maskImage) 118 | maskBand = maskds.GetRasterBand(1) 119 | if maskBand.XSize != newXsize or maskBand.YSize != newYsize: 120 | msg = 'mask should match requested subset size if supplied' 121 | raise PyShepSegSubsetError(msg) 122 | 123 | # work out how many tiles we have 124 | tileSize = tiling.TILESIZE 125 | numXtiles = int(numpy.ceil(newXsize / tileSize)) 
126 |     numYtiles = int(numpy.ceil(newYsize / tileSize))
127 | 
128 |     minInVal = None
129 |     maxInVal = None
130 | 
131 |     for tileRow in range(numYtiles):
132 |         for tileCol in range(numXtiles):
133 |             leftPix = tlx + tileCol * tileSize
134 |             topLine = tly + tileRow * tileSize
135 |             xsize = min(tileSize, newXsize - leftPix + tlx)
136 |             ysize = min(tileSize, newYsize - topLine + tly)
137 | 
138 |             # extract the image data for this tile from the input file
139 |             inData = inband.ReadAsArray(leftPix, topLine, xsize, ysize)
140 | 
141 |             # work out the range of data for accessing the whole RAT (below)
142 |             inDataMasked = inData[inData != shepseg.SEGNULLVAL]
143 |             if len(inDataMasked) == 0:
144 |                 # no actual data in this tile
145 |                 continue
146 | 
147 |             minVal = inDataMasked.min()
148 |             maxVal = inDataMasked.max()
149 |             if minInVal is None or minVal < minInVal:
150 |                 minInVal = minVal
151 |             if maxInVal is None or maxVal > maxInVal:
152 |                 maxInVal = maxVal
153 | 
154 |             if maskBand is not None:
155 |                 # if a mask file was specified, read the corresponding data
156 |                 maskData = maskBand.ReadAsArray(tileCol * tileSize,
157 |                     tileRow * tileSize, xsize, ysize)
158 | 
159 |             # process this tile, obtaining the image of the 'new' segment ids
160 |             # and updating recodeDict as we go
161 |             outData = processSubsetTile(inData, recodeDict,
162 |                 histogramDict, maskData)
163 | 
164 |             # write out the new segment ids to the output
165 |             outband.WriteArray(outData, tileCol * tileSize, tileRow * tileSize)
166 | 
167 |     if minInVal is None or maxInVal is None:
168 |         # must be all shepseg.SEGNULLVAL
169 |         raise PyShepSegSubsetError('No valid data found in subset')
170 | 
171 |     # process the recodeDict, one page of the input at a time
172 | 
173 |     # fill this in as we go and write out each page when complete.
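    # (a paged RAT holds at most tilingstats.RAT_PAGE_SIZE rows per page, so only
    # a bounded slice of the output RAT is in memory at once; e.g. if RAT_PAGE_SIZE
    # were 1000, output row 2501 would live on the page with id 2000, at offset 501)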
174 |     outPagedRat = tilingstats.createPagedRat()
175 |     for startSegId in range(minInVal, maxInVal + 1, tilingstats.RAT_PAGE_SIZE):
176 |         # looping through in tilingstats.RAT_PAGE_SIZE pages (the + 1 above ensures maxInVal itself is included)
177 |         endSegId = min(startSegId + tilingstats.RAT_PAGE_SIZE - 1, maxInVal)
178 | 
179 |         # get this page in
180 |         inPage = readRATIntoPage(inRAT, numIntCols, numFloatCols,
181 |             startSegId, endSegId)
182 | 
183 |         # copy any in recodeDict into the new outPagedRat
184 |         copySubsettedSegmentsToNew(inPage, outPagedRat, recodeDict)
185 | 
186 |         writeCompletedPagesForSubset(inRAT, outRAT, outPagedRat)
187 | 
188 |     # write out the histogram we've been updating
189 |     histArray = numpy.empty(outRAT.GetRowCount(), dtype=numpy.float64)
190 |     setHistogramFromDictionary(histogramDict, histArray)
191 | 
192 |     colNum = outRAT.GetColOfUsage(gdal.GFU_PixelCount)
193 |     if colNum == -1:
194 |         outRAT.CreateColumn('Histogram', gdal.GFT_Real, gdal.GFU_PixelCount)
195 |         colNum = outRAT.GetColumnCount() - 1
196 |     outRAT.WriteArray(histArray, colNum)
197 |     del histArray
198 | 
199 |     # optional column with old segids
200 |     if origSegIdColName is not None:
201 |         # find or create column
202 |         colNum = -1
203 |         for n in range(outRAT.GetColumnCount()):
204 |             if outRAT.GetNameOfCol(n) == origSegIdColName:
205 |                 colNum = n
206 |                 break
207 | 
208 |         if colNum == -1:
209 |             outRAT.CreateColumn(origSegIdColName, gdal.GFT_Integer,
210 |                 gdal.GFU_Generic)
211 |             colNum = outRAT.GetColumnCount() - 1
212 | 
213 |         origSegIdArray = numpy.empty(outRAT.GetRowCount(), dtype=numpy.int32)
214 |         setSubsetRecodeFromDictionary(recodeDict, origSegIdArray)
215 |         outRAT.WriteArray(origSegIdArray, colNum)
216 | 
217 | 
218 | @njit
219 | def copySubsettedSegmentsToNew(inPage, outPagedRat, recodeDict):
220 |     """
221 |     Using the recodeDict, copy the rows of inPage across to the output RAT pages.
222 | 
223 |     Each row of inPage is processed and (taking inPage.startSegId into
224 |     account) the original input row number is found. This value is then
225 |     looked up in recodeDict to find the row in the output RAT to
226 |     which the input row should be copied.
227 | 
228 |     Parameters
229 |     ----------
230 |     inPage : tilingstats.RatPage
231 |         A page of RAT from the input file
232 |     outPagedRat : numba.typed.Dict
233 |         In-memory pages of the output RAT, as created by createPagedRat().
234 |         This is modified in-place, creating new pages as required.
235 |     recodeDict : numba.typed.Dict
236 |         Keyed by original segment ID, values are the corresponding
237 |         segment IDs in the subset
238 | 
239 |     """
240 |     numIntCols = inPage.intcols.shape[0]
241 |     numFloatCols = inPage.floatcols.shape[0]
242 |     maxSegId = len(recodeDict)
243 |     for inRowInPage in range(inPage.intcols.shape[1]):
244 |         inRow = tiling.segIdNumbaType(inPage.startSegId + inRowInPage)
245 |         if inRow not in recodeDict:
246 |             # this one is not in this subset, skip
247 |             continue
248 |         outRow = recodeDict[inRow]
249 | 
250 |         outPageId = tilingstats.getRatPageId(outRow)
251 |         outRowInPage = outRow - outPageId
252 |         if outPageId not in outPagedRat:
253 |             numSegThisPage = min(tilingstats.RAT_PAGE_SIZE, (maxSegId - outPageId + 1))
254 |             outPagedRat[outPageId] = tilingstats.RatPage(numIntCols, numFloatCols,
255 |                 outPageId, numSegThisPage)
256 |             if outPageId == shepseg.SEGNULLVAL:
257 |                 # nothing will get written to this one, but it needs to be
258 |                 # marked as complete so the whole page will be written
259 |                 outPagedRat[outPageId].setSegmentComplete(shepseg.SEGNULLVAL)
260 | 
261 |         outPage = outPagedRat[outPageId]
262 |         for n in range(numIntCols):
263 |             outPage.intcols[n, outRowInPage] = inPage.intcols[n, inRowInPage]
264 |         for n in range(numFloatCols):
265 |             outPage.floatcols[n, outRowInPage] = inPage.floatcols[n, inRowInPage]
266 | 
267 |         # we mark this as complete, as we have copied the row over
268 |         outPage.setSegmentComplete(outRow)
269 | 
270 | 
271 | @njit
272 | def setHistogramFromDictionary(dictn, histArray):
273 |     """
274 |     Given a dictionary of pixel counts keyed on index,
275 |     write these values to the array.
276 |     """
277 |     for idx in dictn:
278 |         histArray[idx] = dictn[idx]
279 |     histArray[shepseg.SEGNULLVAL] = 0
280 | 
281 | 
282 | @njit
283 | def setSubsetRecodeFromDictionary(dictn, array):
284 |     """
285 |     Given the recodeDict, write the original values to the array
286 |     at the new indices.
287 |     """
288 |     for idx in dictn:
289 |         array[dictn[idx]] = idx
290 |     array[shepseg.SEGNULLVAL] = 0
291 | 
292 | 
293 | @njit
294 | def readColDataIntoPage(page, data, idx, colType, minVal):
295 |     """
296 |     Numba function to quickly read a column returned by
297 |     rat.ReadAsArray() into a RatPage.
298 |     """
299 |     for i in range(data.shape[0]):
300 |         page.setRatVal(i + minVal, colType, idx, data[i])
301 | 
302 | 
303 | def readRATIntoPage(rat, numIntCols, numFloatCols, minVal, maxVal):
304 |     """
305 |     Create a new RatPage() object that represents the section of the RAT
306 |     for a tile of an image. The part of the RAT between minVal and maxVal
307 |     is read in and a RatPage() instance returned with the startSegId param
308 |     set to minVal.
309 | 
310 |     """
311 |     minVal = int(minVal)
312 |     nrows = int(maxVal - minVal) + 1
313 |     page = tilingstats.RatPage(numIntCols, numFloatCols, minVal, nrows)
314 | 
315 |     intColIdx = 0
316 |     floatColIdx = 0
317 |     for col in range(rat.GetColumnCount()):
318 |         dtype = rat.GetTypeOfCol(col)
319 |         data = rat.ReadAsArray(col, start=minVal, length=nrows)
320 |         if dtype == gdal.GFT_Integer:
321 |             readColDataIntoPage(page, data, intColIdx,
322 |                 tilingstats.STAT_DTYPE_INT, minVal)
323 |             intColIdx += 1
324 |         else:
325 |             readColDataIntoPage(page, data, floatColIdx,
326 |                 tilingstats.STAT_DTYPE_FLOAT, minVal)
327 |             floatColIdx += 1
328 | 
329 |     return page
330 | 
331 | 
332 | def copyColumns(inRat, outRat):
333 |     """
334 |     Copy column structure from inRat to outRat. Note that this just creates
335 |     the empty columns in outRat; it does not copy any data.
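    Only integer and float columns are handled; a string column in inRat raises TypeError.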
336 | 
337 |     Parameters
338 |     ----------
339 |     inRat, outRat : gdal.RasterAttributeTable
340 |         Columns found in inRat are created on outRat
341 | 
342 |     Returns
343 |     -------
344 |     numIntCols : int
345 |         Number of integer columns found
346 |     numFloatCols : int
347 |         Number of float columns found
348 | 
349 |     """
350 |     numIntCols = 0
351 |     numFloatCols = 0
352 |     for col in range(inRat.GetColumnCount()):
353 |         dtype = inRat.GetTypeOfCol(col)
354 |         usage = inRat.GetUsageOfCol(col)
355 |         name = inRat.GetNameOfCol(col)
356 |         outRat.CreateColumn(name, dtype, usage)
357 |         if dtype == gdal.GFT_Integer:
358 |             numIntCols += 1
359 |         elif dtype == gdal.GFT_Real:
360 |             numFloatCols += 1
361 |         else:
362 |             raise TypeError("String columns not supported")
363 | 
364 |     return numIntCols, numFloatCols
365 | 
366 | 
367 | @njit
368 | def processSubsetTile(tile, recodeDict, histogramDict, maskData):
369 |     """
370 |     Process a tile of the subset area. Returns a new tile with the new codes.
371 |     Fills in the recodeDict as it goes and also updates histogramDict.
372 | 
373 |     Parameters
374 |     ----------
375 |     tile : shepseg.SegIdType ndarray (tileNrows, tileNcols)
376 |         Input tile of segment IDs
377 |     recodeDict : numba.typed.Dict
378 |         Keyed by original segment ID, values are the corresponding
379 |         segment IDs in the subset
380 |     histogramDict : numba.typed.Dict
381 |         Histogram counts in the subset, keyed by new segment ID
382 |     maskData : None or int ndarray (tileNrows, tileNcols)
383 |         If not None, then this is a raster mask. Only pixels which are
384 |         non-zero in the mask will be included in the subset
385 | 
386 |     Returns
387 |     -------
388 |     outData : shepseg.SegIdType ndarray (tileNrows, tileNcols)
389 |         Recoded copy of the input tile.
390 | 
391 |     """
392 |     outData = numpy.zeros_like(tile)
393 | 
394 |     ysize, xsize = tile.shape
395 |     # go through each pixel
396 |     for y in range(ysize):
397 |         for x in range(xsize):
398 |             segId = tile[y, x]
399 |             if maskData is not None and maskData[y, x] == 0:
400 |                 # if this one is masked out - skip
401 |                 outData[y, x] = shepseg.SEGNULLVAL
402 |                 continue
403 | 
404 |             if segId == shepseg.SEGNULLVAL:
405 |                 # null segment - skip
406 |                 outData[y, x] = shepseg.SEGNULLVAL
407 |                 continue
408 | 
409 |             if segId not in recodeDict:
410 |                 # haven't encountered this segment before, generate a new id for it
411 |                 outSegId = len(recodeDict) + 1
412 | 
413 |                 # add this new value to our recode dictionary
414 |                 recodeDict[segId] = tiling.segIdNumbaType(outSegId)
415 | 
416 |             # write this new value to the output image
417 |             newval = recodeDict[segId]
418 |             outData[y, x] = newval
419 |             # update histogram
420 |             if newval not in histogramDict:
421 |                 histogramDict[newval] = tiling.segIdNumbaType(0)
422 |             histogramDict[newval] = tiling.segIdNumbaType(histogramDict[newval] + 1)
423 | 
424 |     return outData
425 | 
426 | 
427 | def writeCompletedPagesForSubset(inRAT, outRAT, outPagedRat):
428 |     """
429 |     For the subset operation. Writes out any completed pages to outRAT
430 |     using the inRAT as a template.
431 | 
432 |     Parameters
433 |     ----------
434 |     inRAT, outRAT : gdal.RasterAttributeTable
435 |         The input and output raster attribute tables.
436 |     outPagedRat : numba.typed.Dict
437 |         The paged RAT in memory, as created by createPagedRat()
438 | 
439 |     """
440 |     # Make an array of the pageId values, with the correct type (SegIdType)
441 |     pagedRatKeys = numpy.empty(len(outPagedRat), dtype=shepseg.SegIdType)
442 |     i = 0
443 |     for pageId in outPagedRat:
444 |         pagedRatKeys[i] = pageId
445 |         i += 1
446 | 
447 |     for pageId in pagedRatKeys:
448 |         ratPage = outPagedRat[pageId]
449 |         if ratPage.pageComplete():
450 |             # this one is complete. Grow the RAT if required
451 |             maxRow = ratPage.startSegId + ratPage.intcols.shape[1]
452 |             if outRAT.GetRowCount() < maxRow:
453 |                 outRAT.SetRowCount(maxRow)
454 | 
455 |             # loop through the input RAT, using the type info
456 |             # of each column to decide intcols/floatcols etc
457 |             intColIdx = 0
458 |             floatColIdx = 0
459 |             for col in range(inRAT.GetColumnCount()):
460 |                 dtype = inRAT.GetTypeOfCol(col)
461 |                 if dtype == gdal.GFT_Integer:
462 |                     data = ratPage.intcols[intColIdx]
463 |                     intColIdx += 1
464 |                 else:
465 |                     data = ratPage.floatcols[floatColIdx]
466 |                     floatColIdx += 1
467 | 
468 |                 outRAT.WriteArray(data, col, start=ratPage.startSegId)
469 | 
470 |             # this one is done
471 |             outPagedRat.pop(pageId)
472 | 
473 | 
474 | class PyShepSegSubsetError(Exception):
475 |     pass
476 | 
--------------------------------------------------------------------------------
/pyshepseg/timinghooks.py:
--------------------------------------------------------------------------------
1 | """
2 | A class to support placement of timing points in a Python application.
3 | """
4 | import time
5 | import threading
6 | import contextlib
7 | import unittest
8 | 
9 | try:
10 |     import numpy
11 | except ImportError:
12 |     numpy = None
13 | 
14 | 
15 | __version__ = "1.0.0"
16 | 
17 | 
18 | class Timers:
19 |     """
20 |     Manage multiple named timers. See interval() method for example
21 |     usage. The makeSummaryDict() method can be used to generate
22 |     summary statistics on the timings.
23 | 
24 |     Maintains a dictionary of pairs of start/finish times, before and
25 |     after particular operations. These are grouped by operation names,
26 |     and for each name, a list is accumulated of the pairs, for every
27 |     time when this operation was carried out.
28 | 
29 |     The operation names are arbitrary strings chosen by the user at each
30 |     point where a timer is embedded in the application code.
31 | 
32 |     Timing intervals can be safely nested, so some intervals can be
33 |     contained within others.
34 | 
35 |     The object is thread-safe, so multiple threads can accumulate to
36 |     the same names. The object is also pickle-able.
37 | 
38 |     Example Usage::
39 | 
40 |         timings = Timers()
41 |         with timings.interval('walltime'):
42 |             for i in range(count):
43 |                 # Some code with no specific timer
44 | 
45 |                 with timings.interval('reading'):
46 |                     # Code to do reading operations
47 | 
48 |                 with timings.interval('computation'):
49 |                     # Code to do computation
50 | 
51 |         summary = timings.makeSummaryDict()
52 |         print(summary)
53 | 
54 |     The resulting summary dictionary would be something like::
55 | 
56 |         {'walltime': {'total': 12.345, 'min': 1.234, ......},
57 |          'reading': {'total': 3.456, 'min': 0.234, ......},
58 |          'computation': {'total': 9.123, 'min': 1.012, ......}
59 |         }
60 | 
61 |     All times are in seconds.
62 | 
63 |     These 'with interval' blocks can be scattered through an application's
64 |     code, all using the same timings object. The summary dictionary can be
65 |     used to generate a report at the end of the application to present to a
66 |     user, showing how the key parts of the application compare in time taken.
67 | 
68 |     """
69 |     def __init__(self):
70 |         self.pairs = {}
71 |         self.lock = threading.Lock()
72 | 
73 |     @contextlib.contextmanager
74 |     def interval(self, intervalName):
75 |         """
76 |         Use as a context manager to time a particular named interval.
77 | 
78 |         Example::
79 | 
80 |             timings = Timers()
81 |             with timings.interval('some_action'):
82 |                 # Code block required to perform the action
83 | 
84 |         After exit from the `with` statement, the timings object will have
85 |         accumulated the start and end times around the code block. These
86 |         will then contribute to the reporting of time intervals.
87 | 
88 |         """
89 |         startTime = time.time()
90 |         yield
91 |         endTime = time.time()
92 |         with self.lock:
93 |             if intervalName not in self.pairs:
94 |                 self.pairs[intervalName] = []
95 |             self.pairs[intervalName].append((startTime, endTime))
96 | 
97 |     def getDurationsForName(self, intervalName):
98 |         """
99 |         For the given interval name, turns that list of start/end times
100 |         into a list of durations (end - start), in seconds.
101 | 
102 |         Returns the list of durations.
103 |         """
104 |         if intervalName in self.pairs:
105 |             intervals = [(p[1] - p[0]) for p in self.pairs[intervalName]]
106 |         else:
107 |             intervals = None
108 |         return intervals
109 | 
110 |     def merge(self, other):
111 |         """
112 |         Merge another Timers object into this one
113 |         """
114 |         with self.lock:
115 |             for intervalName in other.pairs:
116 |                 if intervalName in self.pairs:
117 |                     self.pairs[intervalName].extend(other.pairs[intervalName])
118 |                 else:
119 |                     self.pairs[intervalName] = other.pairs[intervalName]
120 | 
121 |     def makeSummaryDict(self):
122 |         """
123 |         Make some summary statistics, and return them in a dictionary
124 |         """
125 |         if numpy is None:
126 |             print("Timers.makeSummaryDict() requires numpy")
127 |             return
128 | 
129 |         d = {}
130 |         for name in self.pairs:
131 |             intervals = numpy.array(self.getDurationsForName(name))
132 |             tot = float(intervals.sum())
133 |             minVal = float(intervals.min())
134 |             maxVal = float(intervals.max())
135 |             meanVal = float(intervals.mean())
136 |             pcnt25 = float(numpy.percentile(intervals, 25))
137 |             pcnt50 = float(numpy.percentile(intervals, 50))
138 |             pcnt75 = float(numpy.percentile(intervals, 75))
139 |             d[name] = {'total': tot, 'min': minVal, 'max': maxVal,
140 |                 'lowerq': pcnt25, 'median': pcnt50, 'upperq': pcnt75,
141 |                 'mean': meanVal, 'count': len(intervals)}
142 |         return d
143 | 
144 |     def __getstate__(self):
145 |         """
146 |         Ensure pickleability by omitting the lock attribute
147 |         """
148 |         d = {}
149 |         with self.lock:
150 |             d.update(self.__dict__)
151 |         d.pop('lock')
152 |         return d
153 | 
154 |     def __setstate__(self, state):
155 |         """
156 |         For unpickling, add a new lock attribute
157 |         """
158 |         self.lock = threading.Lock()
159 |         with self.lock:
160 |             self.__dict__.update(state)
161 | 
162 | 
163 | class AllTests(unittest.TestCase):
164 |     """
165 |     Run all tests
166 |     """
167 |     places = 2
168 | 
169 |     def test_single(self):
170 |         t = Timers()
171 |         with t.interval('test1'):
172 |             time.sleep(2)
173 |         summ = t.makeSummaryDict()
174 |         self.assertAlmostEqual(summ['test1']['total'], 2, places=self.places)
175 | 
176 |     def test_multiple(self):
177 |         t = Timers()
178 |         with t.interval('test2'):
179 |             time.sleep(1)
180 | 
181 |         with t.interval('test2'):
182 |             time.sleep(2)
183 | 
184 |         with t.interval('test3'):
185 |             time.sleep(2.5)
186 | 
187 |         summ = t.makeSummaryDict()
188 |         self.assertAlmostEqual(summ['test2']['total'], 3, places=self.places)
189 |         self.assertAlmostEqual(summ['test3']['total'], 2.5, places=self.places)
190 | 
191 |     def test_nested(self):
192 |         t = Timers()
193 |         with t.interval('test1'):
194 |             time.sleep(1)
195 |             with t.interval('test2'):
196 |                 time.sleep(2)
197 | 
198 |         summ = t.makeSummaryDict()
199 |         self.assertAlmostEqual(summ['test1']['total'], 3, places=self.places)
200 |         self.assertAlmostEqual(summ['test2']['total'], 2, places=self.places)
201 | 
202 | 
203 | def mainCmd():
204 |     unittest.main(module='timinghooks')
205 | 
206 | 
207 | if __name__ == "__main__":
208 |     mainCmd()
209 | 
--------------------------------------------------------------------------------
/pyshepseg/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | 
3 | Utility functions for working with segmented data.
4 | 
5 | """
6 | 
7 | # Copyright 2021 Neil Flood and Sam Gillingham. All rights reserved.
8 | #
9 | # Permission is hereby granted, free of charge, to any person
10 | # obtaining a copy of this software and associated documentation
11 | # files (the "Software"), to deal in the Software without restriction,
12 | # including without limitation the rights to use, copy, modify,
13 | # merge, publish, distribute, sublicense, and/or sell copies of the
14 | # Software, and to permit persons to whom the Software is furnished
15 | # to do so, subject to the following conditions:
16 | #
17 | # The above copyright notice and this permission notice shall be
18 | # included in all copies or substantial portions of the Software.
19 | #
20 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
21 | # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
22 | # OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
23 | # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR
24 | # ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
25 | # CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
26 | # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
27 | 
28 | # Just in case anyone is trying to use this with Python-2
29 | from __future__ import print_function, division
30 | 
31 | import sys
32 | import inspect
33 | import traceback
34 | 
35 | import numpy
36 | from . import shepseg
37 | 
38 | from osgeo import gdal
39 | 
40 | gdal.UseExceptions()
41 | 
42 | DEFAULT_MINOVERVIEWDIM = 100
43 | DEFAULT_OVERVIEWLEVELS = [4, 8, 16, 32, 64, 128, 256, 512]
44 | gdalFloatTypes = set([gdal.GDT_Float32, gdal.GDT_Float64])
45 | 
46 | 
47 | def estimateStatsFromHisto(bandObj, hist):
48 |     """
49 |     As a shortcut to calculating stats with GDAL, use the histogram
50 |     that we already have from calculating the RAT and calculate the stats
51 |     from that.
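    The histogram is assumed to be a 'direct' one, i.e. bin i holds the
    count of pixels with value i, starting from zero.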
52 | """ 53 | # https://stackoverflow.com/questions/47269390/numpy-how-to-find-first-non-zero-value-in-every-column-of-a-numpy-array 54 | mask = hist > 0 55 | nVals = hist.sum() 56 | minVal = mask.argmax() 57 | maxVal = hist.shape[0] - numpy.flip(mask).argmax() - 1 58 | 59 | values = numpy.arange(hist.shape[0]) 60 | 61 | meanVal = (values * hist).sum() / nVals 62 | 63 | stdDevVal = (hist * numpy.power(values - meanVal, 2)).sum() / nVals 64 | stdDevVal = numpy.sqrt(stdDevVal) 65 | 66 | modeVal = numpy.argmax(hist) 67 | # estimate the median - bin with the middle number 68 | middlenum = hist.sum() / 2 69 | gtmiddle = hist.cumsum() >= middlenum 70 | medianVal = gtmiddle.nonzero()[0][0] 71 | 72 | if bandObj.DataType in gdalFloatTypes: 73 | # convert values away from numpy scalars as they have a repr() 74 | # in the form np.float... Band is float so conver to floats 75 | minVal = float(minVal) 76 | maxVal = float(maxVal) 77 | modeVal = float(modeVal) 78 | medianVal = float(medianVal) 79 | else: 80 | # convert to ints 81 | minVal = int(minVal) 82 | maxVal = int(maxVal) 83 | modeVal = int(modeVal) 84 | medianVal = int(medianVal) 85 | # mean and standard deviation stay as floats 86 | 87 | bandObj.SetMetadataItem("STATISTICS_MINIMUM", repr(minVal)) 88 | bandObj.SetMetadataItem("STATISTICS_MAXIMUM", repr(maxVal)) 89 | bandObj.SetMetadataItem("STATISTICS_MEAN", repr(float(meanVal))) 90 | bandObj.SetMetadataItem("STATISTICS_STDDEV", repr(float(stdDevVal))) 91 | bandObj.SetMetadataItem("STATISTICS_MODE", repr(modeVal)) 92 | bandObj.SetMetadataItem("STATISTICS_MEDIAN", repr(medianVal)) 93 | bandObj.SetMetadataItem("STATISTICS_SKIPFACTORX", "1") 94 | bandObj.SetMetadataItem("STATISTICS_SKIPFACTORY", "1") 95 | bandObj.SetMetadataItem("STATISTICS_HISTOBINFUNCTION", "direct") 96 | 97 | 98 | def addOverviews(ds): 99 | """ 100 | Add raster overviews to the given file. 101 | Mimic rios.calcstats behaviour to decide how many overviews. 102 | 103 | Parameters 104 | ---------- 105 | ds : gdal.Dataset 106 | Open Dataset for the raster file 107 | 108 | """ 109 | # first we work out how many overviews to build based on the size 110 | if ds.RasterXSize < ds.RasterYSize: 111 | mindim = ds.RasterXSize 112 | else: 113 | mindim = ds.RasterYSize 114 | 115 | nOverviews = 0 116 | for i in DEFAULT_OVERVIEWLEVELS: 117 | if (mindim // i) > DEFAULT_MINOVERVIEWDIM: 118 | nOverviews = nOverviews + 1 119 | 120 | ds.BuildOverviews("NEAREST", DEFAULT_OVERVIEWLEVELS[:nOverviews]) 121 | 122 | 123 | def writeRandomColourTable(outBand, nRows): 124 | """ 125 | Attach a randomly-generated colour table to the given segmentation 126 | image. Mainly useful so the segmentation boundaries can be viewed, 127 | without any regard to the meaning of the segments. 128 | 129 | Parameters 130 | ---------- 131 | outBand : gdal.Band 132 | Open GDAL Band object for the segmentation image 133 | nRows : int 134 | Number of rows in the attribute table, equal to the 135 | number of segments + 1. 
136 | 
137 |     """
138 |     nRows = int(nRows)
139 |     colNames = ["Blue", "Green", "Red"]
140 |     colUsages = [gdal.GFU_Blue, gdal.GFU_Green, gdal.GFU_Red]
141 | 
142 |     attrTbl = outBand.GetDefaultRAT()
143 |     attrTbl.SetRowCount(nRows)
144 | 
145 |     for band in range(3):
146 |         colNum = attrTbl.GetColOfUsage(colUsages[band])
147 |         if colNum == -1:
148 |             attrTbl.CreateColumn(colNames[band], gdal.GFT_Integer, colUsages[band])
149 |             colNum = attrTbl.GetColumnCount() - 1
150 |         colour = numpy.random.randint(0, 256, size=nRows)  # random values in 0..255 (randint upper bound is exclusive)
151 |         attrTbl.WriteArray(colour, colNum)
152 | 
153 |     alpha = numpy.full((nRows,), 255, dtype=numpy.uint8)
154 |     alpha[shepseg.SEGNULLVAL] = 0
155 |     colNum = attrTbl.GetColOfUsage(gdal.GFU_Alpha)
156 |     if colNum == -1:
157 |         attrTbl.CreateColumn('Alpha', gdal.GFT_Integer, gdal.GFU_Alpha)
158 |         colNum = attrTbl.GetColumnCount() - 1
159 |     attrTbl.WriteArray(alpha, colNum)
160 | 
161 | 
162 | def writeColorTableFromRatColumns(segfile, redColName, greenColName,
163 |         blueColName):
164 |     """
165 |     Use the values in the given columns in the raster attribute
166 |     table (RAT) to create corresponding color table columns, so that
167 |     the segmented image will display similarly to the same bands of
168 |     the original image.
169 | 
170 |     The general idea is that the given columns would be the per-segment
171 |     mean values of the desired bands (see tiling.calcPerSegmentStatsTiled()
172 |     to create such columns).
173 | 
174 |     Parameters
175 |     ----------
176 |     segfile : str or gdal.Dataset
177 |         Filename of the completed segmentation image, with RAT columns
178 |         already written. Can be either the file name string, or
179 |         an open Dataset object.
180 |     redColName : str
181 |         Name of the column in the RAT to use for the red color
182 |     greenColName : str
183 |         Name of the column in the RAT to use for the green color
184 |     blueColName : str
185 |         Name of the column in the RAT to use for the blue color
186 | 
187 |     """
188 |     colList = [redColName, greenColName, blueColName]
189 |     colorColList = ['Red', 'Green', 'Blue']
190 |     usageList = [gdal.GFU_Red, gdal.GFU_Green, gdal.GFU_Blue]
191 | 
192 |     if isinstance(segfile, gdal.Dataset):
193 |         ds = segfile
194 |     else:
195 |         ds = gdal.Open(segfile, gdal.GA_Update)
196 | 
197 |     band = ds.GetRasterBand(1)
198 |     attrTbl = band.GetDefaultRAT()
199 |     colNameList = [attrTbl.GetNameOfCol(i)
200 |         for i in range(attrTbl.GetColumnCount())]
201 | 
202 |     for i in range(3):
203 |         n = colNameList.index(colList[i])
204 |         colVals = attrTbl.ReadAsArray(n)
205 | 
206 |         # If the corresponding color column does not yet exist, then create it
207 |         if colorColList[i] not in colNameList:
208 |             attrTbl.CreateColumn(colorColList[i], gdal.GFT_Integer, usageList[i])
209 |             clrColNdx = attrTbl.GetColumnCount() - 1
210 |         else:
211 |             clrColNdx = colNameList.index(colorColList[i])
212 | 
213 |         # Use the column values to create a color column of values in
214 |         # the range 0-255. Stretch to the 5th and 95th percentiles, to
215 |         # avoid extreme values causing washed-out colors.
216 |         colMin = numpy.percentile(colVals, 5)
217 |         colMax = numpy.percentile(colVals, 95)
218 |         clr = (255 * ((colVals - colMin) / (colMax - colMin)).clip(0, 1))
219 | 
220 |         # Write the color column
221 |         attrTbl.WriteArray(clr.astype(numpy.uint8), clrColNdx)
222 | 
223 |     # Now write the opacity (alpha) column. Set to full opacity.
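    # (a constant 255 means every segment is drawn fully opaque)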
224 |     alpha = numpy.full(len(colVals), 255, dtype=numpy.uint8)
225 |     if 'Alpha' not in colNameList:
226 |         attrTbl.CreateColumn('Alpha', gdal.GFT_Integer, gdal.GFU_Alpha)
227 |         i = attrTbl.GetColumnCount() - 1
228 |     else:
229 |         i = colNameList.index('Alpha')
230 |     attrTbl.WriteArray(alpha, i)
231 | 
232 | 
233 | deprecationAlreadyWarned = set()
234 | 
235 | 
236 | def deprecationWarning(msg, stacklevel=2):
237 |     """
238 |     Print a deprecation warning to stderr. Includes the filename
239 |     and line number of the call to the function which called this.
240 |     The stacklevel argument controls how many stack levels above this
241 |     gives the line number.
242 | 
243 |     Implemented in mimicry of warnings.warn(), which seems very flaky.
244 |     Sometimes it prints, and sometimes not, unless PYTHONWARNINGS is set
245 |     (or -W is used). This function at least seems to work consistently.
246 | 
247 |     """
248 |     frame = inspect.currentframe()
249 |     for i in range(stacklevel):
250 |         if frame is not None:
251 |             frame = frame.f_back
252 | 
253 |     if frame is None:
254 |         filename = "sys"
255 |         lineno = 1
256 |     else:
257 |         filename = frame.f_code.co_filename
258 |         lineno = frame.f_lineno
259 | 
260 |     key = (filename, lineno)
261 |     if key not in deprecationAlreadyWarned:
262 |         print("{} (line {}):\n WARNING: {}".format(filename, lineno, msg),
263 |             file=sys.stderr)
264 |         deprecationAlreadyWarned.add(key)
265 | 
266 | 
267 | class WorkerErrorRecord:
268 |     """
269 |     Hold a record of an exception raised in a remote worker.
270 |     """
271 |     def __init__(self, exc, workerType):
272 |         self.exc = exc
273 |         self.workerType = workerType
274 |         self.formattedTraceback = traceback.format_exception(exc)
275 | 
276 |     def __str__(self):
277 |         headLine = "Error in {} worker".format(self.workerType)
278 |         lines = [headLine]
279 |         lines.extend([line.strip('\n') for line in self.formattedTraceback])
280 |         s = '\n'.join(lines) + '\n'
281 |         return s
282 | 
283 | 
284 | def reportWorkerException(exceptionRecord):
285 |     """
286 |     Report the given WorkerErrorRecord object to stderr
287 |     """
288 |     print(exceptionRecord, file=sys.stderr)
289 | 
290 | 
291 | def formatTimingRpt(summaryDict):
292 |     """
293 |     Format a report on timings, given the output of Timers.makeSummaryDict()
294 | 
295 |     Return a single string of the formatted report.
296 |     """
297 |     # Make a list of individual timers, hopefully in a sensible order
298 |     isSeg = ('spectralclusters' in summaryDict)
299 |     isStats = ('statscompletion' in summaryDict)
300 |     if isSeg:
301 |         hdr = "Segmentation Timings (sec)"
302 |         timerList = ['spectralclusters', 'startworkers', 'reading',
303 |             'segmentation', 'stitchtiles']
304 |     elif isStats:
305 |         hdr = "Per-segment Stats Timings (sec)"
306 |         timerList = ['reading', 'accumulation', 'statscompletion', 'writing']
307 |     else:
308 |         # Some unknown set of timers, do something sensible
309 |         hdr = "Timers (unknown set) (sec)"
310 |         timerList = sorted(list(summaryDict.keys()))
311 |     # Remove any which are not present in summaryDict
312 |     timerList = [t for t in timerList if t in summaryDict]
313 | 
314 |     lines = [hdr]
315 |     walltimeDict = summaryDict.get('walltime')
316 |     if walltimeDict is not None:
317 |         walltime = walltimeDict['total']
318 |         lines.append(f"Walltime: {walltime:.2f}")
319 |     lines.append("")
320 | 
321 |     # Work out column widths and format strings. Very tedious, but neater output.
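    # (fldWidth1 fits the longest timer name; fldWidth2 fits the largest total,
    # allowing 3 extra characters for the decimal point and two decimal places)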
322 |     fldWidth1 = max([len(t) for t in timerList])
323 |     maxTime = max([summaryDict[t]['total'] for t in timerList])
324 |     logMaxTime = numpy.log10(maxTime)
325 |     if int(logMaxTime) == logMaxTime:
326 |         # maxTime is an exact power of 10, so force ceil() to go up anyway
327 |         logMaxTime += 0.1
328 |     fldWidth2 = 3 + int(numpy.ceil(logMaxTime))
329 |     colHdrFmt = "{:" + str(fldWidth1) + "s}   {:>" + str(fldWidth2) + "s}"
330 |     lines.append(colHdrFmt.format("Timer", "Total"))
331 |     lines.append((3 + fldWidth1 + fldWidth2) * '-')
332 |     colFmt = "{:" + str(fldWidth1) + "s}   {:" + str(fldWidth2) + ".2f}"
333 | 
334 |     # Now add the table of timings.
335 |     for t in timerList:
336 |         line = colFmt.format(t, summaryDict[t]['total'])
337 |         lines.append(line)
338 | 
339 |     outStr = '\n'.join(lines)
340 |     return outStr
341 | 
--------------------------------------------------------------------------------
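A minimal usage sketch tying the modules above together (not from the repository
itself; the file names and subset window are hypothetical, and any GDAL format
with RAT support, such as KEA, would do)::

    from pyshepseg import subset

    # Extract a 512 x 512 window starting at pixel (1000, 2000), recoding the
    # segment IDs and keeping the originals in a RAT column called 'orig_segid'
    subset.subsetImage('segmentation.kea', 'subset.kea',
        tlx=1000, tly=2000, newXsize=512, newYsize=512,
        outformat='KEA', origSegIdColName='orig_segid')

The timing helpers combine in a similar way; here the sleeps merely stand in
for real work::

    import time
    from pyshepseg.timinghooks import Timers
    from pyshepseg.utils import formatTimingRpt

    timings = Timers()
    with timings.interval('walltime'):
        with timings.interval('reading'):
            time.sleep(0.5)
        with timings.interval('computation'):
            time.sleep(0.25)

    # Summarise the accumulated intervals and print a formatted table
    print(formatTimingRpt(timings.makeSummaryDict()))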