├── .github └── workflows │ ├── black.yml │ └── isort.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── checkpoints └── .gitignore ├── data └── .gitignore ├── examples ├── predict_detections.ipynb └── predict_detections_sam2.ipynb ├── poetry.lock ├── pyproject.toml ├── sandbox ├── .gitignore ├── detection │ ├── detect_raw_drone_images.ipynb │ ├── detectree2_detections.ipynb │ ├── geometric_detections.ipynb │ ├── plot_detections.ipynb │ └── random_detections.ipynb ├── evaluation │ ├── README.md │ ├── dtree2_benchmark.ipynb │ ├── neon_benchmark.ipynb │ ├── sam2_dtree2_benchmark.ipynb │ └── sam2_neon_benchmark.ipynb ├── postprocessing │ ├── nms_experiments.ipynb │ └── specify_postprocessing_in_detector.ipynb └── preprocessing │ ├── chip_ortho.py │ ├── deepforest_train_and_predict.ipynb │ ├── load_raw_drone_images.ipynb │ └── test-dataloader-with-dtree2.ipynb └── tree_detection_framework ├── __init__.py ├── constants.py ├── detection ├── SAM2_detector.py ├── __init__.py ├── detector.py ├── models.py └── region_detections.py ├── entrypoints ├── __init__.py ├── generate_predictions.py └── tile_data.py ├── evaluation ├── __init__.py └── evaluate.py ├── postprocessing ├── __init__.py └── postprocessing.py ├── preprocessing ├── __init__.py ├── derived_geodatasets.py └── preprocessing.py └── utils ├── __init__.py ├── benchmarking.py ├── detection.py ├── geometric.py ├── geospatial.py └── raster.py /.github/workflows/black.yml: -------------------------------------------------------------------------------- 1 | name: Black Linter 2 | 3 | # This allows Black to be called from the isort workflow 4 | on: 5 | workflow_call: 6 | 7 | jobs: 8 | black-linter: 9 | name: Run Black 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - name: Determine Branch Name 14 | run: | 15 | if [ "${{ github.event_name }}" = "pull_request" ]; then 16 | echo "BRANCH_NAME=${{ github.head_ref }}" >> $GITHUB_ENV 17 | else 18 | BRANCH_NAME=$(echo ${GITHUB_REF#refs/heads/}) 19 | echo "BRANCH_NAME=$BRANCH_NAME" >> $GITHUB_ENV 20 | fi 21 | 22 | - name: Checkout code 23 | uses: actions/checkout@v3 24 | with: 25 | ref: ${{ env.BRANCH_NAME }} 26 | 27 | - name: Pull latest changes 28 | run: git pull origin ${{ env.BRANCH_NAME }} 29 | 30 | - name: Run Black formatter 31 | uses: psf/black@stable 32 | with: 33 | options: "--verbose" 34 | src: "." 35 | jupyter: true 36 | 37 | - name: Push changes 38 | uses: stefanzweifel/git-auto-commit-action@v4 39 | with: 40 | commit_message: Apply black formatting changes -------------------------------------------------------------------------------- /.github/workflows/isort.yml: -------------------------------------------------------------------------------- 1 | name: Code Formatter 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | 9 | jobs: 10 | isort: 11 | name: Run Isort 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - name: Determine Branch Name 16 | run: | 17 | if [ "${{ github.event_name }}" = "pull_request" ]; then 18 | echo "BRANCH_NAME=${{ github.head_ref }}" >> $GITHUB_ENV 19 | else 20 | BRANCH_NAME=$(echo ${GITHUB_REF#refs/heads/}) 21 | echo "BRANCH_NAME=$BRANCH_NAME" >> $GITHUB_ENV 22 | fi 23 | 24 | - name: Checkout code 25 | uses: actions/checkout@v3 26 | with: 27 | ref: ${{ env.BRANCH_NAME }} 28 | 29 | - name: Set up Python 30 | uses: actions/setup-python@v2 31 | with: 32 | python-version: 3.9 33 | 34 | - name: Run Isort 35 | run: pip install isort==5.13.2 && isort --profile black . 
36 | 37 | - name: Push changes 38 | uses: stefanzweifel/git-auto-commit-action@v5 39 | with: 40 | commit_message: Apply isort formatting changes 41 | 42 | # Ensure Black runs after Isort has completed 43 | black: 44 | needs: isort 45 | uses: ./.github/workflows/black.yml 46 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Formatting and style 2 | This repository uses `isort` for import management and `black` for general formatting. Both can be installed using `pip` or `conda`, and on a Linux system they can also be installed at the system level using `apt`. For example, using `pip`, you'd simply do 3 | ``` 4 | pip install isort==5.13.2 5 | pip install black==24.1.0 6 | ``` 7 | We use the default arguments so the tools can be executed by simply pointing them at all files in the repository. From the root directory of the project, you should execute the following: 8 | ``` 9 | isort . 10 | black . 11 | ``` 12 | Note that it's important to run `isort` first, because its output isn't always consistent with `black`'s formatting; running `black` afterward resolves any differences. 13 | 14 | If you push changes to main or create a pull request, please be aware that Github Actions will trigger a workflow that runs `isort` and `black` on the code. This will take a few seconds to run and the workflow may automatically push formatting changes to the repository. To ensure your local repository is up to date with the remote repository, wait for a few seconds and pull the latest changes. 15 | 16 | # Branch naming 17 | If you are adding a branch to this repository, please use the following convention: `{feature, bugfix, hotfix, release, docs}/{developer initials}/{short-hyphenated-description}`. For example, `docs/DR/add-branch-naming-convention` for this change. For a description of the prefixes, please see [here](https://medium.com/@abhay.pixolo/naming-conventions-for-git-branches-a-cheatsheet-8549feca2534). 18 | 19 | # Docstrings 20 | For documentation, we use the [Google](https://github.com/NilsJPWerner/autoDocstring/blob/HEAD/docs/google.md) format. I personally use the [VSCode autoDocstring](https://marketplace.visualstudio.com/items?itemName=njpwerner.autodocstring) plugin for templating. Keeping the docstrings up-to-date is essential because we automatically integrate them into our documentation. If the docstrings are outdated, the docstrings shown in the documentation will also be outdated. 21 | 22 | # Type hints 23 | Type hints, as introduced by [PEP 484](https://peps.python.org/pep-0484/), are strongly encouraged. This helps provide additional documentation and allows some code editors to offer better autocompletion.
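For example, a fully hinted function might look like this (an illustrative sketch, not code from this repository):
```
from typing import Optional


def count_detections(scores: list[float], threshold: Optional[float] = None) -> int:
    """Count detections, optionally keeping only those at or above a score threshold."""
    if threshold is None:
        return len(scores)
    return sum(score >= threshold for score in scores)
```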
24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024, Open Forest Observatory 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tree-detection-framework 2 | This project has three main goals: 3 | * Enable tree detection on realistic-scale, geospatial raster data with minimal boilerplate, using existing (external) tree detection/segmentation models 4 | * Facilitate direct comparison of multiple algorithms 5 | * Rely on modern libraries and software best practices for a robust, performant, and modular tool 6 | 7 | This project does not, itself, provide tree detection/segmentation algorithms (with the exception of a geometric algorithm). Instead, it provides a standardized interface for performing training, inference, and evaluation using existing tree detection models and algorithms. The project currently supports the external computer vision models DeepForest, Detectree2, and SAM2, as well as a geometric canopy height model segmentor implemented within TDF. Support for other external models can be added by implementing a new `Detector` class. 8 | 9 | We use the `torchgeo` package to perform data loading and standardization using standard geospatial input formats. This library allows us to generate chips of a given size, stride, and spatial resolution on the fly. Training and inference are done with modular detectors that can be based on existing models and algorithms. We have preliminary support for using `PyTorch Lightning` to minimize boilerplate around model training and prediction. 
Region-level non-max suppression (NMS) is done using the `PolyGoneNMS` library, which is efficient for large images. Visualization and saving of the predictions are done using `geopandas`, a common library for geospatial data. 10 | 11 | This project is under active development by the [Open Forest Observatory](https://openforestobservatory.org/). We welcome contributions and suggestions for improvement. 12 | 13 | ## Tree detection models supported 14 | TDF currently supports the following tree detection/segmentation algorithms. 15 | 16 | ### DeepForest 17 | - [Github](https://github.com/weecology/DeepForest) 18 | - Uses RGB input data. Predicts tree crowns with rectangular bounding boxes. 19 | - Provides a RetinaNet model trained on a large number of semi-supervised tree crown annotations and a smaller set of manual annotations. 20 | - Trained using data from the US only but representing diverse regions. The model has been applied successfully to data from outside the US. 21 | 22 | ### Detectree2 23 | - [Github](https://github.com/PatBall1/detectree2) 24 | - Uses RGB input data. Predicts tree crowns with polygon boundaries. 25 | - Provides a Mask R-CNN model trained on manually labeled tree crowns from four sites. 26 | - Trained using data from tropical forests. 27 | 28 | ### Segment Anything Model 2 (SAM2) 29 | - [Github](https://github.com/facebookresearch/sam2) 30 | - Uses RGB input data. Predicts objects with polygon boundaries. 31 | - Utilizes the Segment Anything Model (SAM 2.1 Hiera Large) checkpoint with tuned parameters for mask generation optimized for tree crown delineation. 32 | - Does not rely on tree-specific supervised training, but generalizes well due to SAM's zero-shot nature; however, non-tree objects are also detected and included in predictions. 33 | 34 | ### Geometric Detector 35 | - Implementation of the variable window filter algorithm of [Popescu and Wynne 36 | (2004)](https://www.ingentaconnect.com/content/asprs/pers/2004/00000070/00000005/art00003) for 37 | tree top detection, combined with the algorithm of [Silva et al. 38 | (2016)](https://www.tandfonline.com/doi/full/10.1080/07038992.2016.1196582#abstract) for crown 39 | segmentation. 40 | - Uses canopy height model (CHM) input data. Predicts tree crowns with polygon boundaries. 41 | - This is a learning-free tree detection algorithm. It is the only algorithm implemented within TDF itself, as opposed to wrapping an existing external model/algorithm. 42 | 43 | ## Software architecture 44 | The `tree-detection-framework` is organized into modular components to facilitate extension, including the integration of additional detection models. The main components are: 45 | 46 | 1. **`preprocessing.py`**<br>
47 | The `create_dataloader()` method accepts one or more orthomosaic inputs. Alternatively, 48 | `create_image_dataloader()` accepts a folder containing raw drone imagery. The methods tile the 49 | input images based on user-specified parameters such as tile size, stride, and resolution, and 50 | return a PyTorch-compatible dataloader for inference. 51 | 2. **`Detector` Base Class**<br>
52 | All detectors in the framework (e.g., DeepForestDetector, Detectree2Detector) inherit from the 53 | `Detector` base class. The base class defines the core logic for generating predictions and 54 | geospatially referencing image tiles, while model-specific detectors translate the inputs to the 55 | format expected by the respective model. This design allows all detectors to plug into the same 56 | pipeline with minimal code changes. 57 | 3. **`RegionDetectionsSet` and `RegionDetections`**
58 | These classes standardize model outputs. A `RegionDetectionsSet` is a collection of `RegionDetections`, where each `RegionDetections` object represents the detections in a single image tile. This abstraction allows postprocessing components to operate uniformly across different detectors. These outputs can be saved as `.gpkg` or `.geojson` files. 59 | 4. **`postprocessing.py`**<br>
60 | Implements a set of postprocessing functions for cleaning the detections via Non-Maximum Suppression (NMS), polygon hole suppression, tile boundary suppression, and removal of out-of-bounds detections. Most of these methods operate on standardized output types (`RegionDetections` / `RegionDetectionsSet`). 61 | 62 | ## Install 63 | Some of the dependencies are managed by a tool called [Poetry](https://python-poetry.org/). I've found it 64 | easiest to install this using the "official installer" option as follows. Note that this should be run 65 | in the base conda environment or with no environment active. 66 | ``` 67 | curl -sSL https://install.python-poetry.org | python3 - 68 | ``` 69 | Now create and activate a conda environment for the dependencies of this project. 70 | ``` 71 | conda create -n tree-detection-framework python=3.10 -y 72 | conda activate tree-detection-framework 73 | ``` 74 | 75 | Now, from the root directory of the project, run the following command. Note that on Jetstream2, you 76 | may need to run this in a graphical session and respond to a keyring popup menu. 77 | ``` 78 | poetry install 79 | ``` 80 | Finally, choose to install either the Detectron2 or the SAM2 detection framework. 81 | 82 | **Detectron2:** 83 | The Detectron2 library is not compatible with `poetry`, so it must be installed directly with pip 84 | ``` 85 | # https://detectron2.readthedocs.io/en/latest/tutorials/install.html#build-detectron2-from-source 86 | pip install git+https://github.com/facebookresearch/detectron2.git 87 | ``` 88 | Download the detectree2 checkpoint weights. 89 | ``` 90 | cd checkpoints 91 | mkdir detectree2 92 | cd detectree2 93 | wget https://zenodo.org/records/10522461/files/230103_randresize_full.pth 94 | ``` 95 | **SAM2:** 96 | Clone the SAM2 repository and install it, which also provides the necessary config files 97 | ``` 98 | git clone https://github.com/facebookresearch/sam2.git && cd sam2 99 | 100 | pip install -e . 101 | ``` 102 | Then download the associated checkpoints 103 | ``` 104 | cd checkpoints && \ 105 | ./download_ckpts.sh && \ 106 | cd .. 107 | ``` 108 | Then move the downloaded checkpoints into this repo 109 | ``` 110 | mv checkpoints ../tree-detection-framework 111 | ``` 112 | 113 | 114 | ## Use 115 | The module code is in the `tree_detection_framework` folder. Once installed using the `poetry` 116 | command above, this code can be imported into scripts or notebooks under the name 117 | `tree_detection_framework`, the same as you would for any other library. 118 | 119 | ## Examples 120 | To begin with, you can access example geospatial data 121 | [here](https://ucdavis.box.com/v/tdf-example-data), which should be downloaded and placed in the `data` folder at the top level of this project. Our goal is to have high-quality, 122 | up-to-date examples in the `examples` folder. We also have work-in-progress or one-off code in 123 | `sandbox`, which still may provide some insight but is not guaranteed to be current or generalizable. 124 | Finally, the `tree_detection_framework/entrypoints` folder has command line scripts that can be run 125 | to complete tasks. 126 |
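## Example usage sketch
The snippet below illustrates how the pieces described above fit together, using the SAM2 detector. It is a hedged sketch rather than a maintained recipe: in real use the batch would come from the dataloaders produced by the preprocessing module (e.g. `create_dataloader()`), and the notebooks in `examples/` are the authoritative reference.
```
# Illustrative only -- see examples/predict_detections_sam2.ipynb for a maintained version.
import torch

from tree_detection_framework.detection.SAM2_detector import SAMV2Detector

# Loads checkpoints/sam2.1_hiera_large.pt by default (see the install instructions above)
detector = SAMV2Detector()

# A dataloader batch is a dict holding an "image" tensor of shape (batch, channels, height, width);
# here one tile of random RGB data stands in for a real orthomosaic chip.
batch = {"image": torch.randint(0, 256, (1, 3, 1024, 1024), dtype=torch.uint8)}

# geometries: per-tile lists of shapely (multi)polygons
# data_dicts: per-tile dicts with "score" and "bbox" lists
geometries, data_dicts = detector.predict_batch(batch)
```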
129 | ``` 130 | wget -O annotations.zip "https://zenodo.org/records/5914554/files/annotations.zip?download=1" 131 | unzip annotations.zip 132 | wget -O evaluation.zip "https://zenodo.org/records/5914554/files/evaluation.zip?download=1" 133 | unzip -j evaluation.zip "evaluation/RGB/*" -d RGB 134 | rm annotations.zip 135 | rm evaluation.zip 136 | ``` 137 | Follow the steps in `tree-detection-framework/sandbox/evaluation/neon_benchmark.ipynb` for detectors `DeepForest` & `Detectree2`, and `tree-detection-framework/sandbox/evaluation/sam2_neon_benchmark.ipynb` to use `SAM2`. 138 | 139 | ## Evaluation and benchmark with Detectree2 datasets 140 | 1. Download the dataset. There are two ways to get the dataset: 141 | Download the site-specific .tif (for orthomosaic) and .gpkg (for ground truth polygons) files from https://zenodo.org/records/8136161. Then, follow steps in https://github.com/PatBall1/detectree2/blob/master/notebooks/colab/tilingJB.ipynb to do the the tiling. 142 | (OR) 143 | Download our pre-tiled dataset from https://ucdavis.box.com/s/thjmaane9d38opw1bhnyxrsrtt90j37m 144 | 3. Add the tiled dataset folder to the `data` folder in this repo. 145 | 4. For benchmark and evaluation see steps in `tree-detection-framework/sandbox/evaluation/dtree2_benchmark.ipynb` and `tree-detection-framework/sandbox/evaluation/sam2_dtree2_benchmark.ipynb` 146 | -------------------------------------------------------------------------------- /checkpoints/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "tree-detection-framework" 3 | version = "0.1.0" 4 | description = "Framework for detecting trees in remote sensing data" 5 | authors = ["David Russell "] 6 | license = "BSD-3" 7 | readme = "README.md" 8 | packages = [{include = "tree_detection_framework"}] 9 | 10 | [tool.poetry.dependencies] 11 | python = ">=3.10,<3.12" 12 | torchgeo = "^0.6.0" 13 | deepforest = "^1.3.3" 14 | ipykernel = "^6.29.5" 15 | polygonenms-ofo-fork = "^0.1.0" 16 | 17 | 18 | [build-system] 19 | requires = ["poetry-core"] 20 | build-backend = "poetry.core.masonry.api" 21 | -------------------------------------------------------------------------------- /sandbox/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-forest-observatory/tree-detection-framework/661fded8c3cd3ac7585ec9bb5d4606778bca7494/sandbox/.gitignore -------------------------------------------------------------------------------- /sandbox/evaluation/README.md: -------------------------------------------------------------------------------- 1 | # Zero-Shot Tree Detection and Segmentation from Aerial Forest Imagery Experiments 2 | These notebooks provide a reimplementation of experiments presented in this [paper](https://ml-for-rs.github.io/iclr2025/camera_ready/papers/3.pdf) presented at the ([ML4RS25](https://ml-for-rs.github.io/iclr2025/)) workshop. These experiments differ from the initial experiments in several notable ways, leading to different results. 
3 | - A subtlety in the original implementation of Detectree2 meant that in the original experiments, inference on the NEON datasets occurred with a flipped channel ordering; see the issue [here](https://github.com/PatBall1/detectree2/issues/197). Fixing this issue led to better performance. 4 | - The original experiments first converted the polygon predictions to axis-aligned bounding boxes prior to running non-max suppression. In the current experiments, this was changed to running NMS on the polygon representation prior to converting to boxes. This change keeps more detections, leading to higher precision and lower recall than the original experiments. 5 | - Aside from the polygon vs. box NMS consideration, SAM2 results differ for reasons that are not understood. 6 | 7 | The data for these experiments can be downloaded following the instructions for the [NEON](https://github.com/open-forest-observatory/tree-detection-framework?tab=readme-ov-file#evaluation-and-benchmark-with-neon) and [Detectree2](https://github.com/open-forest-observatory/tree-detection-framework?tab=readme-ov-file#evaluation-and-benchmark-with-detectree2-datasets) datasets. You will need to create two separate environments, one with the Detectree2 dependencies installed and the other with the SAM2 dependencies installed, following the instructions [here](https://github.com/open-forest-observatory/tree-detection-framework?tab=readme-ov-file#install). Then run the notebooks using the appropriate environment for the prediction model being used. 8 | -------------------------------------------------------------------------------- /sandbox/preprocessing/chip_ortho.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import random 5 | from pathlib import Path 6 | from typing import Any, Dict, Optional 7 | 8 | import fiona 9 | import fiona.transform 10 | import geopandas as gpd 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import pyproj 14 | import rasterio 15 | import shapely.geometry 16 | import torch 17 | from shapely.affinity import affine_transform 18 | from shapely.geometry import box 19 | from torch.utils.data import DataLoader 20 | from torchgeo.datasets import ( 21 | IntersectionDataset, 22 | RasterDataset, 23 | VectorDataset, 24 | stack_samples, 25 | unbind_samples, 26 | ) 27 | from torchgeo.datasets.utils import BoundingBox, array_to_tensor 28 | from torchgeo.samplers import GridGeoSampler, Units 29 | from torchvision.transforms import ToPILImage 30 | 31 | # Set up logging configuration 32 | logging.basicConfig( 33 | level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" 34 | ) 35 | 36 | 37 | class CustomRasterDataset(RasterDataset): 38 | """ 39 | Custom dataset class for orthomosaic raster images. This class extends the `RasterDataset` from `torchgeo`. 40 | 41 | Attributes: 42 | filename_glob (str): Glob pattern to match files in the directory. 43 | is_image (bool): Indicates that the data being loaded is image data. 44 | separate_files (bool): True if data is stored in a separate file for each band, else False. 45 | """ 46 | 47 | filename_glob: str = "*.tif" # To match all TIFF files 48 | is_image: bool = True 49 | separate_files: bool = False 50 | 51 | 52 | class CustomVectorDataset(VectorDataset): 53 | """ 54 | Custom dataset class for vector data which acts as labels for the raster data. This class extends the `VectorDataset` from `torchgeo`.
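In addition to the rasterized "mask", samples returned by `__getitem__` carry a "shapes" key with the vector features as shapely geometries in per-tile pixel coordinates (see the additions at the end of `__getitem__`).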
55 | """ 56 | 57 | def __getitem__(self, query: BoundingBox) -> dict[str, Any]: 58 | """Retrieve image/mask and metadata indexed by query. 59 | This function is largely based on the `__getitem__` method from torchgeo's `VectorDataset`, with custom modifications for this implementation. 60 | 61 | Args: 62 | query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index 63 | 64 | Returns: 65 | sample of image/mask and metadata at that index 66 | 67 | Raises: 68 | IndexError: if query is not found in the index 69 | """ 70 | hits = self.index.intersection(tuple(query), objects=True) 71 | filepaths = [hit.object for hit in hits] 72 | 73 | if not filepaths: 74 | raise IndexError( 75 | f"query: {query} not found in index with bounds: {self.bounds}" 76 | ) 77 | 78 | shapes = [] 79 | for filepath in filepaths: 80 | with fiona.open(filepath) as src: 81 | # We need to know the bounding box of the query in the source CRS 82 | (minx, maxx), (miny, maxy) = fiona.transform.transform( 83 | self.crs.to_dict(), 84 | src.crs, 85 | [query.minx, query.maxx], 86 | [query.miny, query.maxy], 87 | ) 88 | 89 | # Filter geometries to those that intersect with the bounding box 90 | for feature in src.filter(bbox=(minx, miny, maxx, maxy)): 91 | # Warp geometries to requested CRS 92 | shape = fiona.transform.transform_geom( 93 | src.crs, self.crs.to_dict(), feature["geometry"] 94 | ) 95 | label = self.get_label(feature) 96 | shapes.append((shape, label)) 97 | 98 | # Rasterize geometries 99 | width = (query.maxx - query.minx) / self.res 100 | height = (query.maxy - query.miny) / self.res 101 | transform = rasterio.transform.from_bounds( 102 | query.minx, query.miny, query.maxx, query.maxy, width, height 103 | ) 104 | if shapes: 105 | masks = rasterio.features.rasterize( 106 | shapes, out_shape=(round(height), round(width)), transform=transform 107 | ) 108 | else: 109 | # If no features are found in this query, return an empty mask 110 | # with the default fill value and dtype used by rasterize 111 | masks = np.zeros((round(height), round(width)), dtype=np.uint8) 112 | 113 | # Use array_to_tensor since rasterize may return uint16/uint32 arrays. 
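# (torch has no uint16/uint32 tensor types, so torch.from_numpy would fail on such arrays;
# torchgeo's array_to_tensor helper widens them to a supported dtype before conversion.)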
114 | masks = array_to_tensor(masks) 115 | 116 | masks = masks.to(self.dtype) 117 | 118 | # Beginning of additions made to this function 119 | 120 | # Invert the transform to convert geo coordinates to pixel values 121 | inverse_transform = ~transform 122 | 123 | # Convert `fiona` type shapes to `shapely` shape objects for easier manipulation 124 | shapely_shapes = [(shapely.geometry.shape(sh), i) for sh, i in shapes] 125 | 126 | # Apply the inverse transform to each shapely shape, converting geo coordinates to pixel coordinates 127 | pixel_transformed_shapes = [ 128 | (affine_transform(sh, inverse_transform.to_shapely()), i) 129 | for sh, i in shapely_shapes 130 | ] 131 | 132 | # Add 'shapes' containing polygons and corresponding ID values 133 | sample = { 134 | "mask": masks, 135 | "crs": self.crs, 136 | "bounds": query, 137 | "shapes": pixel_transformed_shapes, 138 | } 139 | 140 | if self.transforms is not None: 141 | sample = self.transforms(sample) 142 | 143 | return sample 144 | 145 | 146 | def chip_orthomosaics( 147 | raster_path: str, 148 | size: float, 149 | vector_path: Optional[str] = None, 150 | id_column_name: str = "treeID", 151 | stride: Optional[float] = None, 152 | overlap_percent: Optional[float] = None, 153 | res: Optional[float] = None, 154 | use_units_meters: bool = False, 155 | save_dir: Optional[str] = None, 156 | visualize_n: Optional[int] = None, 157 | ) -> DataLoader: 158 | """ 159 | Splits an orthomosaic image into smaller tiles with optional reprojection to a meters-based CRS. Tiles can be saved to a directory and visualized. 160 | 161 | Args: 162 | raster_path (str): Path to the folder containing the orthomosaic files. 163 | size (float): Tile size in units of pixels or meters, depending on `use_units_meters`. 164 | vector_path (str, optional): Path to the folder containing the vector data files. 165 | id_column_name (str): Column name in the vector dataframe containing IDs for the tree polygons. Defaults to 'treeID'. 166 | stride (float, optional): The distance between the start of one tile and the next in pixels or meters. 167 | overlap_percent (float, optional): Percentage overlap between consecutive tiles (0-100). Used to calculate stride if provided. 168 | res (float, optional): Resolution of the dataset in units of the CRS (if not specified, defaults to the resolution of the first image). 169 | use_units_meters (bool, optional): Whether to use meters instead of pixels for tile size and stride. 170 | save_dir (str, optional): Directory where the tiles and metadata should be saved. 171 | visualize_n (int, optional): Number of randomly selected tiles to visualize. 172 | 173 | Returns: 174 | A dataloader with chipped orthomosaic tiles. 175 | 176 | Raises: 177 | ValueError: If neither `stride` nor `overlap_percent` is provided.
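Example (illustrative paths):
    chip_orthomosaics("data/orthos", size=1000, overlap_percent=50, save_dir="data/chips")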
178 | """ 179 | 180 | # Stores image data 181 | raster_dataset = CustomRasterDataset(paths=raster_path, res=res) 182 | 183 | # Stores label data (hardcoded label_name for now) 184 | vector_dataset = ( 185 | CustomVectorDataset(paths=vector_path, res=res, label_name=id_column_name) 186 | if vector_path is not None 187 | else None 188 | ) 189 | 190 | units = Units.CRS if use_units_meters == True else Units.PIXELS 191 | logging.info(f"Units = {units}") 192 | 193 | if use_units_meters and raster_dataset.crs.is_geographic: 194 | # Reproject the dataset to a meters-based CRS 195 | logging.info("Projecting to meters-based CRS...") 196 | lat, lon = raster_dataset.bounds[2], raster_dataset.bounds[0] 197 | 198 | # Return a new projected CRS value with meters units 199 | projected_crs = get_projected_CRS(lat, lon) 200 | 201 | # Type conversion to rasterio.crs 202 | projected_crs = rasterio.crs.CRS.from_wkt(projected_crs.to_wkt()) 203 | 204 | # Recreating the raster and vector dataset objects with the new CRS value 205 | raster_dataset = CustomRasterDataset(paths=raster_path, crs=projected_crs) 206 | vector_dataset = ( 207 | CustomVectorDataset( 208 | paths=vector_path, crs=projected_crs, label_name=id_column_name 209 | ) 210 | if vector_path 211 | else None 212 | ) 213 | 214 | # Create an intersection dataset that combines raster and label data if given. Otherwise, proceed with just raster_dataset. 215 | final_dataset = ( 216 | IntersectionDataset(raster_dataset, vector_dataset) 217 | if vector_path is not None 218 | else raster_dataset 219 | ) 220 | 221 | # Calculate stride if overlap is provided 222 | if overlap_percent: 223 | stride = size * (1 - overlap_percent / 100.0) 224 | logging.info(f"Calculated stride based on overlap: {stride}") 225 | elif stride is None: 226 | raise ValueError("Either 'stride' or 'overlap' must be provided.") 227 | logging.info(f"Stride = {stride}") 228 | 229 | # GridGeoSampler to get contiguous tiles 230 | sampler = GridGeoSampler(final_dataset, size=size, stride=stride, units=units) 231 | dataloader = DataLoader(final_dataset, sampler=sampler, collate_fn=stack_samples) 232 | 233 | if visualize_n: 234 | # Randomly pick indices for visualizing tiles if visualize_n is specified 235 | visualize_indices = random.sample(range(len(sampler)), visualize_n) 236 | 237 | for i in visualize_indices: 238 | plot(get_sample_from_index(raster_dataset, sampler, i)) 239 | plt.axis("off") 240 | plt.show() 241 | 242 | if save_dir: 243 | # Creates save directory if it doesn't exist 244 | save_path = Path(save_dir) 245 | save_path.mkdir(parents=True, exist_ok=True) 246 | 247 | transform_to_pil = ToPILImage() 248 | for i, batch in enumerate(dataloader): 249 | sample = unbind_samples(batch)[0] 250 | 251 | image = sample["image"] 252 | image_tensor = torch.clamp(image / 255.0, min=0, max=1) 253 | pil_image = transform_to_pil(image_tensor) 254 | pil_image.save(Path(save_dir) / f"tile_{i}.png") 255 | 256 | # Prepare to save tile metadata 257 | metadata = { 258 | "crs": sample["crs"].to_string(), 259 | "bounds": list(sample["bounds"]), 260 | } 261 | 262 | if vector_path: 263 | # Extract shapes (polygons and tree IDs) 264 | shapes = sample["shapes"] 265 | 266 | crowns = [ 267 | {"ID": tree_id, "crown": polygon.wkt} for polygon, tree_id in shapes 268 | ] 269 | 270 | # Add crowns to the metadata 271 | metadata["crowns"] = crowns 272 | 273 | # Save tile metadata to a json file 274 | with open(Path(save_dir) / f"tile_{i}.json", "w") as f: 275 | json.dump(metadata, f, indent=4) 276 | 277 | 
logging.info(f"Saved {i + 1} tiles to {save_dir}") 278 | 279 | return dataloader 280 | 281 | 282 | # Helper functions 283 | 284 | 285 | def get_sample_from_index( 286 | dataset: CustomRasterDataset, sampler: GridGeoSampler, index: int 287 | ) -> Dict: 288 | # Access the specific index from the sampler containing bounding boxes 289 | sample_indices = list(sampler) 290 | sample_idx = sample_indices[index] 291 | 292 | # Get the sample from the dataset using this index 293 | sample = dataset[sample_idx] 294 | return sample 295 | 296 | 297 | def plot(sample: Dict) -> plt.Figure: 298 | image = sample["image"].permute(1, 2, 0) 299 | image = image.byte().numpy() 300 | fig, ax = plt.subplots() 301 | ax.imshow(image) 302 | return fig 303 | 304 | 305 | def get_projected_CRS( 306 | lat: float, lon: float, assume_western_hem: bool = True 307 | ) -> pyproj.CRS: 308 | if assume_western_hem and lon > 0: 309 | lon = -lon # Standard formula for the UTM EPSG code: 326xx in the northern hemisphere, 327xx in the southern, where xx is the UTM zone derived from longitude 310 | epsg_code = 32700 - round((45 + lat) / 90) * 100 + round((183 + lon) / 6) 311 | crs = pyproj.CRS.from_epsg(epsg_code) 312 | return crs 313 | 314 | 315 | def parse_args() -> argparse.Namespace: 316 | parser = argparse.ArgumentParser(description="Chipping orthomosaic images") 317 | parser.add_argument( 318 | "--raster-path", 319 | type=str, 320 | required=True, 321 | help="Path to folder containing single or multiple orthomosaic images.", 322 | ) 323 | parser.add_argument( 324 | "--vector-path", 325 | type=str, 326 | required=False, 327 | help="Path to folder containing single or multiple vector datafiles.", 328 | ) 329 | parser.add_argument( 330 | "--id-column-name", 331 | type=str, 332 | default="treeID", 333 | help="Column name in the vector dataframe containing IDs for the tree polygons. Defaults to 'treeID'.", 334 | ) 335 | parser.add_argument( 336 | "--res", 337 | type=float, 338 | required=False, 339 | help="Resolution of the dataset in units of CRS (defaults to the resolution of the first file found)", 340 | ) 341 | parser.add_argument( 342 | "--size", 343 | type=float, 344 | required=True, 345 | help="Single value used for height and width dim", 346 | ) 347 | parser.add_argument( 348 | "--stride", 349 | type=float, 350 | required=False, 351 | help="Distance to skip between each patch", 352 | ) 353 | parser.add_argument( 354 | "--overlap-percent", 355 | type=float, 356 | required=False, 357 | help="Percentage overlap between the tiles (0-100)", 358 | ) 359 | parser.add_argument( 360 | "--use-units-meters", 361 | action="store_true", 362 | help="Whether to set units for tile size and stride as meters", 363 | ) 364 | parser.add_argument( 365 | "--save-dir", type=str, required=False, help="Directory to save chips" 366 | ) 367 | parser.add_argument( 368 | "--visualize-n", type=int, required=False, help="Number of tiles to visualize" 369 | ) 370 | 371 | args = parser.parse_args() 372 | return args 373 | 374 | 375 | if __name__ == "__main__": 376 | args = parse_args() 377 | chip_orthomosaics(**args.__dict__) 378 | -------------------------------------------------------------------------------- /tree_detection_framework/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-forest-observatory/tree-detection-framework/661fded8c3cd3ac7585ec9bb5d4606778bca7494/tree_detection_framework/__init__.py -------------------------------------------------------------------------------- /tree_detection_framework/constants.py: -------------------------------------------------------------------------------- 1 | from pathlib
import Path 2 | from typing import Union 3 | 4 | import geopandas as gpd 5 | import numpy.typing 6 | import shapely 7 | import torch 8 | 9 | PATH_TYPE = Union[str, Path] 10 | BOUNDARY_TYPE = Union[ 11 | PATH_TYPE, shapely.Polygon, shapely.MultiPolygon, gpd.GeoDataFrame, gpd.GeoSeries 12 | ] 13 | ARRAY_TYPE = numpy.typing.ArrayLike 14 | 15 | DATA_FOLDER = Path(Path(__file__).parent, "..", "data").resolve() 16 | CHECKPOINTS_FOLDER = Path(Path(__file__).parent, "..", "checkpoints").resolve() 17 | 18 | DEFAULT_DEVICE = ( 19 | torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 20 | ) 21 | -------------------------------------------------------------------------------- /tree_detection_framework/detection/SAM2_detector.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | from shapely.geometry import box 5 | 6 | from tree_detection_framework.constants import CHECKPOINTS_FOLDER, DEFAULT_DEVICE 7 | from tree_detection_framework.detection.detector import Detector 8 | from tree_detection_framework.utils.geometric import mask_to_shapely 9 | 10 | try: 11 | from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator 12 | from sam2.build_sam import build_sam2 13 | 14 | SAM2_AVAILABLE = True 15 | except ImportError: 16 | SAM2_AVAILABLE = False 17 | raise ImportError( 18 | "SAM2 is not installed. Please install it using the instructions in the README." 19 | ) 20 | 21 | 22 | # follow README for download instructions 23 | class SAMV2Detector(Detector): 24 | 25 | def __init__( 26 | self, 27 | device=DEFAULT_DEVICE, 28 | sam2_checkpoint=Path(CHECKPOINTS_FOLDER, "sam2.1_hiera_large.pt"), 29 | model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml", 30 | postprocessors=None, 31 | ): 32 | """ 33 | Create a SAM2 detector. 34 | Args: 35 | device (torch.device): Device to run the model on. 36 | sam2_checkpoint (Path): Path to the SAM2 checkpoint. 37 | model_cfg (str): Path to the SAM2 model config. 38 | postprocessors (list, optional): See docstring for Detector class. Defaults to None. 
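Example (assumes SAM2 is installed and the checkpoint downloaded per the README):
    detector = SAMV2Detector()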
39 | """ 40 | super().__init__(postprocessors=postprocessors) 41 | 42 | self.device = device 43 | 44 | self.sam2 = build_sam2( 45 | model_cfg, sam2_checkpoint, device=self.device, apply_postprocessing=False 46 | ) 47 | # Create the mask generator with the optimal set of parameters found by 48 | # Michelle Chen & Jane Wu based on qualitative experiments 49 | self.mask_generator = SAM2AutomaticMaskGenerator( 50 | model=self.sam2, 51 | points_per_side=64, 52 | points_per_batch=128, 53 | pred_iou_thresh=0.7, 54 | stability_score_thresh=0.92, 55 | stability_score_offset=0.7, 56 | crop_n_layers=1, 57 | box_nms_thresh=0.7, 58 | crop_n_points_downscale_factor=2, 59 | min_mask_region_area=25.0, 60 | use_m2m=True, 61 | ) 62 | 63 | def call_predict(self, batch): 64 | """ 65 | Args: 66 | batch (Tensor): A 4-dimensional tensor whose first dimension is the number of images in the batch 67 | 68 | Returns: 69 | masks (List[List[Dict]]): A list per image of dictionaries, one per predicted mask 70 | """ 71 | 72 | with torch.no_grad(): 73 | masks = [] 74 | for original_image in batch: 75 | if original_image.shape[0] < 3: 76 | raise ValueError("Original image has fewer than 3 channels") 77 | 78 | original_image = original_image.permute(1, 2, 0) 79 | # If the pixel values exceed 1, assume a [0, 255] range and convert to uint8 80 | if original_image.max() > 1: 81 | original_image = original_image.byte().numpy() 82 | else: 83 | original_image = original_image.numpy() 84 | rgb_image = original_image[:, :, :3] 85 | mask = self.mask_generator.generate( 86 | rgb_image 87 | ) # model expects rgb 0-255 range (h, w, 3) 88 | # FUTURE TODO: Support batched predictions 89 | masks.append(mask) 90 | 91 | return masks 92 | 93 | def predict_batch(self, batch): 94 | """ 95 | Get predictions for a batch of images. 96 | 97 | Args: 98 | batch (dict): A batch from the dataloader 99 | 100 | Returns: 101 | all_geometries (List[List[shapely.MultiPolygon]]): 102 | A list of predictions, one per image in the batch. The predictions for each image 103 | are a list of shapely objects.
104 | all_data_dicts (Union[None, List[dict]]): 105 | Per-image dictionaries with prediction scores and bounding boxes 106 | """ 107 | images = batch["image"] 108 | 109 | # Run the model; this is the computational bottleneck 110 | batch_preds = self.call_predict(images) 111 | 112 | # To store all predicted polygons 113 | all_geometries = [] 114 | # To store other related information such as scores and labels 115 | all_data_dicts = [] 116 | 117 | # Iterate through predictions for each tile in the batch 118 | for pred in batch_preds: 119 | 120 | # Extract the binary segmentation masks from the prediction dictionaries 121 | segmentations = [dic["segmentation"].astype(float) for dic in pred] 122 | 123 | # Convert each mask to a shapely multipolygon 124 | shapely_objects = [ 125 | mask_to_shapely(pred_mask) for pred_mask in segmentations 126 | ] 127 | 128 | all_geometries.append(shapely_objects) 129 | 130 | # Compute axis-aligned bounding boxes as Polygon objects 131 | bounding_boxes = [box(*polygon.bounds) for polygon in shapely_objects] 132 | 133 | # Get prediction scores 134 | scores = [dic["stability_score"] for dic in pred] 135 | all_data_dicts.append({"score": scores, "bbox": bounding_boxes}) 136 | 137 | return all_geometries, all_data_dicts 138 | -------------------------------------------------------------------------------- /tree_detection_framework/detection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-forest-observatory/tree-detection-framework/661fded8c3cd3ac7585ec9bb5d4606778bca7494/tree_detection_framework/detection/__init__.py -------------------------------------------------------------------------------- /tree_detection_framework/detection/models.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Dict, List, Optional 3 | 4 | import lightning 5 | import torch 6 | import torchvision 7 | from deepforest import main as deepforest_main 8 | from torch import Tensor, optim 9 | from torchvision.models.detection.retinanet import ( 10 | AnchorGenerator, 11 | RetinaNet, 12 | RetinaNet_ResNet50_FPN_Weights, 13 | ) 14 | 15 | from tree_detection_framework.utils.detection import use_release_df 16 | 17 | try: 18 | from detectron2 import model_zoo 19 | from detectron2.config import get_cfg 20 | 21 | DETECTRON2_AVAILABLE = True 22 | except ImportError: 23 | DETECTRON2_AVAILABLE = False 24 | 25 | 26 | class RetinaNetModel: 27 | """A backbone class for DeepForest""" 28 | 29 | def __init__(self, param_dict): 30 | self.param_dict = param_dict 31 | 32 | def load_backbone(self): 33 | """A torchvision RetinaNet model""" 34 | backbone = torchvision.models.detection.retinanet_resnet50_fpn( 35 | weights=RetinaNet_ResNet50_FPN_Weights.COCO_V1 36 | ) 37 | 38 | return backbone 39 | 40 | def create_anchor_generator( 41 | self, sizes=((8, 16, 32, 64, 128, 256, 400),), aspect_ratios=((0.5, 1.0, 2.0),) 42 | ): 43 | """Create anchor box generator as a function of sizes and aspect ratios""" 44 | anchor_generator = AnchorGenerator(sizes=sizes, aspect_ratios=aspect_ratios) 45 | 46 | return anchor_generator 47 | 48 | def create_model(self): 49 | """Create a retinanet model 50 | 51 | Returns: 52 | model: a pytorch nn module 53 | """ 54 | resnet = self.load_backbone() 55 | backbone = resnet.backbone 56 | 57 | model = RetinaNet(backbone=backbone, num_classes=self.param_dict["num_classes"]) 58 | # TODO: do we want to set model.nms_thresh and model.score_thresh?
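# (torchvision's RetinaNet stores these as plain attributes, so they could be overridden
# here, e.g. `model.score_thresh = 0.05` and `model.nms_thresh = 0.5` -- the torchvision
# defaults, shown for illustration.)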
59 | 60 | return model 61 | 62 | 63 | class DeepForestModule(lightning.LightningModule): 64 | def __init__( 65 | self, 66 | use_hugging_face_weights: bool = True, 67 | param_dict: Optional[Dict[str, Any]] = {}, 68 | ): 69 | """Lightning module wrapping the DeepForest RetinaNet model for training and inference. 70 | 71 | Args: 72 | use_hugging_face_weights (bool, optional): 73 | Should the model and weights be downloaded from the Hugging Face repository. If 74 | True, setting param_dict will result in an error. Defaults to True. 75 | param_dict (Optional[Dict[str, Any]], optional): 76 | Configuration parameters to provide finer control over the model. Cannot be used 77 | with use_hugging_face_weights=True. Defaults to {}. 78 | 79 | Raises: 80 | ValueError: If both use_hugging_face_weights and param_dict are set 81 | """ 82 | super().__init__() 83 | self.param_dict = param_dict 84 | 85 | # Determine how to obtain the weights 86 | if use_hugging_face_weights: 87 | # Error if the user tried to use both options 88 | if len(param_dict) > 0: 89 | raise ValueError( 90 | "Setting the `param_dict` and `use_hugging_face_weights`=True are mutually exclusive. Please choose one option." 91 | ) 92 | # Create the architecture 93 | model = deepforest_main.deepforest() 94 | # Load a pretrained tree detection model from Hugging Face 95 | model.load_model(model_name="weecology/deepforest-tree", revision="main") 96 | # The definition of the model here contains additional metrics that do not match what we 97 | # need. Take only the prediction model portion. 98 | self.model = model.model 99 | 100 | else: 101 | if param_dict["backbone"] == "retinanet": 102 | retinanet = RetinaNetModel(param_dict) 103 | else: 104 | raise ValueError("Only 'retinanet' backbone is currently supported.") 105 | 106 | self.model = retinanet.create_model() 107 | self.use_release() 108 | 109 | def use_release(self, check_release=True): 110 | """Use the latest DeepForest model release from GitHub and load the model. 111 | Optionally download if release doesn't exist. 112 | Args: 113 | check_release (bool): whether to check GitHub for a recent model release. 114 | In cases where you are hitting the GitHub API rate limit, set to False and any local model will be downloaded. 115 | If no model has been downloaded, an error will be raised. 116 | """ 117 | # Download latest model from github release 118 | release_tag, self.release_state_dict = use_release_df( 119 | check_release=check_release 120 | ) 121 | self.model.load_state_dict(torch.load(self.release_state_dict)) 122 | 123 | # load saved model and tag release 124 | self.__release_version__ = release_tag 125 | print("Loading pre-built model: {}".format(release_tag)) 126 | 127 | def forward( 128 | self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None 129 | ) -> Dict[str, Tensor]: 130 | """Calls the model's forward method. 131 | Args: 132 | images (list[Tensor]): Images to be processed 133 | targets (list[Dict[Tensor]]): Ground-truth boxes present in the image (optional) 134 | 135 | Returns: 136 | result (list[BoxList] or dict[Tensor]): 137 | The output from the model. 138 | During training, it returns a dict[Tensor] which contains the losses.
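During evaluation, the underlying torchvision detection model instead returns one dict per image with "boxes", "scores", and "labels" entries.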
139 | """ 140 | # Move the data to the same device as the model 141 | images = images.to(self.device) 142 | return self.model.forward(images, targets=targets) # Model specific forward 143 | 144 | def training_step(self, batch): 145 | # Ensure model is in train mode 146 | self.model.train() 147 | device = next(self.model.parameters()).device 148 | 149 | # Image is expected to be a list of tensors, each of shape [C, H, W] in 0-1 range. 150 | image_batch = (batch["image"][:, :3, :, :] / 255.0).to(device) 151 | image_batch_list = [image for image in image_batch] 152 | 153 | # To store every image's target - a dictionary containing `boxes` and `labels` 154 | targets = [] 155 | for tile in batch["bounding_boxes"]: 156 | # Convert from list to FloatTensor[N, 4] 157 | boxes_tensor = torch.tensor(tile, dtype=torch.float32).to(device) 158 | # Need to remove boxes that go out-of-bounds. Has negative values. 159 | valid_mask = (boxes_tensor >= 0).all(dim=1) 160 | filtered_boxes_tensor = boxes_tensor[valid_mask] 161 | # Create a label tensor. Single class for now. 162 | class_labels = torch.zeros( 163 | filtered_boxes_tensor.shape[0], dtype=torch.int64 164 | ).to(device) 165 | # Dictionary for the tile 166 | d = {"boxes": filtered_boxes_tensor, "labels": class_labels} 167 | targets.append(d) 168 | 169 | loss_dict = self.forward(image_batch_list, targets=targets) 170 | 171 | final_loss = sum([loss for loss in loss_dict.values()]) 172 | print("loss: ", final_loss) 173 | return final_loss 174 | 175 | def configure_optimizers(self): 176 | # similar to the one in deepforest 177 | optimizer = optim.SGD( 178 | self.model.parameters(), lr=self.param_dict["train"]["lr"], momentum=0.9 179 | ) 180 | 181 | # TODO: Setup lr_scheduler 182 | # TODO: Return 'optimizer', 'lr_scheduler', 'monitor' when validation data is set 183 | 184 | return optimizer 185 | 186 | 187 | class Detectree2Module: 188 | def __init__(self, param_dict: Optional[Dict[str, Any]] = None): 189 | if DETECTRON2_AVAILABLE is False: 190 | raise ImportError( 191 | "detectron2 is not installed. Please install it to use this module." 192 | ) 193 | super().__init__() 194 | # If param_dict is not provided, ensure it is an empty dictionary 195 | self.param_dict = param_dict or {} 196 | self.cfg = self.setup_cfg(**self.param_dict) 197 | 198 | def setup_cfg( 199 | self, 200 | base_model: str = "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml", 201 | trains=("trees_train",), 202 | tests=("trees_val",), 203 | update_model=None, 204 | workers=2, 205 | ims_per_batch=2, 206 | gamma=0.1, 207 | backbone_freeze=3, 208 | warm_iter=120, 209 | momentum=0.9, 210 | batch_size_per_im=1024, 211 | base_lr=0.0003389, 212 | weight_decay=0.001, 213 | max_iter=1000, 214 | num_classes=1, 215 | eval_period=100, 216 | out_dir="./train_outputs", 217 | resize=True, 218 | ): 219 | """Set up config object. 
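Note: the keyword arguments below act as defaults; matching entries in self.param_dict take precedence. For example (illustrative values), Detectree2Module({"max_iter": 2000, "out_dir": "./runs"}) overrides only those two settings.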
220 | Args: 221 | base_model: base pre-trained model from detectron2 model_zoo 222 | trains: names of registered data to use for training 223 | tests: names of registered data to use for evaluating models 224 | update_model: updated pre-trained model from detectree2 model_garden 225 | workers: number of workers for dataloader 226 | ims_per_batch: number of images per batch 227 | gamma: gamma for learning rate scheduler 228 | backbone_freeze: backbone layer to freeze 229 | warm_iter: number of iterations for warmup 230 | momentum: momentum for optimizer 231 | batch_size_per_im: batch size per image 232 | base_lr: base learning rate 233 | weight_decay: weight decay for optimizer 234 | max_iter: maximum number of iterations 235 | num_classes: number of classes 236 | eval_period: number of iterations between evaluations 237 | out_dir: directory to save outputs 238 | resize: whether to resize input images 239 | """ 240 | # Initialize configuration 241 | cfg = get_cfg() 242 | cfg.merge_from_file(model_zoo.get_config_file(base_model)) 243 | 244 | # Assign values, prioritizing those in param_dict 245 | cfg.DATASETS.TRAIN = self.param_dict.get("trains", trains) 246 | cfg.DATASETS.TEST = self.param_dict.get("tests", tests) 247 | cfg.DATALOADER.NUM_WORKERS = self.param_dict.get("workers", workers) 248 | cfg.SOLVER.IMS_PER_BATCH = self.param_dict.get("ims_per_batch", ims_per_batch) 249 | cfg.SOLVER.GAMMA = self.param_dict.get("gamma", gamma) 250 | cfg.MODEL.BACKBONE.FREEZE_AT = self.param_dict.get( 251 | "backbone_freeze", backbone_freeze 252 | ) 253 | cfg.SOLVER.WARMUP_ITERS = self.param_dict.get("warm_iter", warm_iter) 254 | cfg.SOLVER.MOMENTUM = self.param_dict.get("momentum", momentum) 255 | cfg.MODEL.RPN.BATCH_SIZE_PER_IMAGE = self.param_dict.get( 256 | "batch_size_per_im", batch_size_per_im 257 | ) 258 | cfg.SOLVER.WEIGHT_DECAY = self.param_dict.get("weight_decay", weight_decay) 259 | cfg.SOLVER.BASE_LR = self.param_dict.get("base_lr", base_lr) 260 | cfg.OUTPUT_DIR = self.param_dict.get("out_dir", out_dir) 261 | cfg.SOLVER.MAX_ITER = self.param_dict.get("max_iter", max_iter) 262 | cfg.MODEL.ROI_HEADS.NUM_CLASSES = self.param_dict.get( 263 | "num_classes", num_classes 264 | ) 265 | cfg.TEST.EVAL_PERIOD = self.param_dict.get("eval_period", eval_period) 266 | cfg.RESIZE = self.param_dict.get("resize", resize) 267 | cfg.INPUT.MIN_SIZE_TRAIN = 1000 268 | 269 | # Create output directory if it doesn't exist 270 | Path(cfg.OUTPUT_DIR).mkdir(parents=True, exist_ok=True) 271 | 272 | # Set model weights 273 | cfg.MODEL.WEIGHTS = self.param_dict.get( 274 | "update_model", update_model 275 | ) or model_zoo.get_checkpoint_url(base_model) 276 | 277 | return cfg 278 | 279 | 280 | # future TODO: add module configs for sam2, currently implemented for default configs 281 | -------------------------------------------------------------------------------- /tree_detection_framework/detection/region_detections.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from pathlib import Path 3 | from typing import List, Optional, Union 4 | 5 | import geopandas as gpd 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import pandas as pd 9 | import pyproj 10 | import rasterio.plot 11 | import rasterio.transform 12 | import shapely 13 | from shapely.affinity import affine_transform 14 | 15 | from tree_detection_framework.constants import PATH_TYPE 16 | from tree_detection_framework.utils.raster import show_raster 17 | 18 | 19 | def plot_detections( 20 | 
--------------------------------------------------------------------------------
/tree_detection_framework/detection/region_detections.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from pathlib import Path
3 | from typing import List, Optional, Union
4 | 
5 | import geopandas as gpd
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import pandas as pd
9 | import pyproj
10 | import rasterio.control
11 | import rasterio.transform
12 | import shapely
13 | from shapely.affinity import affine_transform
14 | 
15 | from tree_detection_framework.constants import PATH_TYPE
16 | from tree_detection_framework.utils.raster import show_raster
17 | 
18 | 
19 | def plot_detections(
20 |     data_frame: gpd.GeoDataFrame,
21 |     bounds: gpd.GeoSeries,
22 |     CRS: Optional[pyproj.CRS] = None,
23 |     plt_ax: Optional[plt.axes] = None,
24 |     plt_show: bool = True,
25 |     visualization_column: Optional[str] = None,
26 |     bounds_color: Optional[Union[str, np.ndarray, pd.Series]] = None,
27 |     detection_kwargs: dict = {},
28 |     bounds_kwargs: dict = {},
29 |     raster_file: Optional[PATH_TYPE] = None,
30 |     raster_vis_downsample: float = 10.0,
31 | ) -> plt.axes:
32 |     """Plot the detections and the bounds of the region
33 | 
34 |     Args:
35 |         data_frame (gpd.GeoDataFrame):
36 |             The data representing the detections
37 |         bounds (gpd.GeoSeries):
38 |             The spatial bounds of the predicted region
39 |         CRS (Optional[pyproj.CRS], optional):
40 |             The CRS used when plotting the optional background raster. The
41 |             detections and bounds are plotted in the CRS of the provided data.
42 |             Defaults to None.
43 |         plt_ax (Optional[plt.axes], optional):
44 |             A pyplot axes to plot on. If not provided, one will be created. Defaults to None.
45 |         plt_show (bool, optional):
46 |             Whether to plot the result or just return it. Defaults to True.
47 |         visualization_column (Optional[str], optional):
48 |             Which column to visualize from the detections dataframe. Defaults to None.
49 |         bounds_color (Optional[Union[str, np.ndarray, pd.Series]], optional):
50 |             The color to plot the bounds. Must be accepted by the gpd.plot color argument.
51 |             Defaults to None.
52 |         detection_kwargs (dict, optional):
53 |             Additional keyword arguments to pass to the .plot method for the detections.
54 |             Defaults to {}.
55 |         bounds_kwargs (dict, optional):
56 |             Additional keyword arguments to pass to the .plot method for the bounds.
57 |             Defaults to {}.
58 |         raster_file (Optional[PATH_TYPE], optional):
59 |             A path to a raster file to visualize the detections over if provided. Defaults to None.
60 |         raster_vis_downsample (float, optional):
61 |             The raster file is downsampled by this fraction before visualization to avoid
62 |             excessive memory use or plotting time. Defaults to 10.0.
63 | 
64 |     Returns:
65 |         plt.axes: The axes that were plotted on
66 |     """
67 | 
68 |     # If no axes are provided, create new ones
69 |     if plt_ax is None:
70 |         _, plt_ax = plt.subplots()
71 | 
72 |     # Show the raster if provided
73 |     if raster_file is not None:
74 |         show_raster(
75 |             raster_file_path=raster_file,
76 |             downsample_factor=raster_vis_downsample,
77 |             plt_ax=plt_ax,
78 |             CRS=CRS,
79 |         )
80 | 
81 |     # Plot the detections dataframe and the bounds on the same axes
82 |     if "facecolor" not in detection_kwargs:
83 |         # Plot with transparent faces unless requested otherwise
84 |         detection_kwargs["facecolor"] = "none"
85 | 
86 |     data_frame.plot(
87 |         ax=plt_ax, column=visualization_column, **detection_kwargs, legend=True
88 |     )
89 |     # Use the .boundary attribute to plot just the border. This works since it's a geoseries,
90 |     # not a geodataframe
91 |     bounds.boundary.plot(ax=plt_ax, color=bounds_color, **bounds_kwargs)
92 | 
93 |     # Show if requested
94 |     if plt_show:
95 |         plt.show()
96 | 
97 |     # Return the axes in case they need to be used later
98 |     return plt_ax
99 | 
100 | 
101 | class RegionDetections:
102 |     detections: gpd.GeoDataFrame
103 |     pixel_to_CRS_transform: rasterio.transform.AffineTransformer
104 |     prediction_bounds_in_CRS: gpd.GeoSeries
105 | 
106 |     def __init__(
107 |         self,
108 |         detection_geometries: Optional[List[shapely.Geometry] | str],
109 |         data: Union[dict, pd.DataFrame] = {},
110 |         input_in_pixels: bool = True,
111 |         CRS: Optional[Union[pyproj.CRS, rasterio.CRS]] = None,
112 |         pixel_to_CRS_transform: Optional[rasterio.transform.AffineTransformer] = None,
113 |         pixel_prediction_bounds: Optional[
114 |             shapely.Polygon | shapely.MultiPolygon
115 |         ] = None,
116 |         geospatial_prediction_bounds: Optional[
117 |             shapely.Polygon | shapely.MultiPolygon
118 |         ] = None,
119 |     ):
120 |         """Create a region detections object
121 | 
122 |         Args:
123 |             detection_geometries (List[shapely.Geometry] | str | None):
124 |                 A list of shapely geometries for each detection. Alternatively, can be a string
125 |                 representing a key in data providing the same, or None if that key is named "geometry".
126 |                 The coordinates can either be provided in pixel coordinates or in the coordinates
127 |                 of a CRS. input_in_pixels should be set accordingly.
128 |             data (Optional[dict | pd.DataFrame], optional):
129 |                 A dictionary mapping from str names for an attribute to a list of values for that
130 |                 attribute, one value per detection. Or a pandas dataframe. Passed to the data
131 |                 argument of gpd.GeoDataFrame. Defaults to {}.
132 |             input_in_pixels (bool, optional):
133 |                 Whether the detection_geometries should be interpreted in pixels or geospatial
134 |                 coordinates. Defaults to True.
135 |             CRS (Optional[pyproj.CRS], optional):
136 |                 A coordinate reference system to interpret the data in. If input_in_pixels is False,
137 |                 then the input data will be interpreted as values in this CRS. If input_in_pixels is
138 |                 True and CRS is None, then the data will be interpreted as pixel coordinates with no
139 |                 georeferencing information. If input_in_pixels is True and CRS is not None, then
140 |                 an attempt will be made to geometrically transform the data into the CRS using
141 |                 either pixel_to_CRS_transform if set, or the relationship between the pixel and
142 |                 geospatial bounds of the region. Defaults to None.
143 |             pixel_to_CRS_transform (Optional[rasterio.transform.AffineTransformer], optional):
144 |                 An affine transformation mapping from the pixel coordinates to those of the CRS.
145 |                 Only meaningful if `CRS` is set as well. Defaults to None.
146 |             pixel_prediction_bounds (Optional[shapely.Polygon | shapely.MultiPolygon], optional):
147 |                 The pixel bounds of the region that predictions were generated for. For example, a
148 |                 square starting at (0, 0) and extending to the size in pixels of the tile.
149 |                 Defaults to None.
150 |             geospatial_prediction_bounds (Optional[shapely.Polygon | shapely.MultiPolygon], optional):
151 |                 Only meaningful if CRS is set. In that case, it represents the spatial bounds of the
152 |                 prediction region. If pixel_to_CRS_transform is None, and both pixel_- and
153 |                 geospatial_prediction_bounds are not None, then the two bounds will be used to
154 |                 compute the transform. Defaults to None.
155 | """ 156 | # Build a geopandas dataframe containing the geometries, additional attributes, and CRS 157 | self.detections = gpd.GeoDataFrame( 158 | data=data, geometry=detection_geometries, crs=CRS 159 | ) 160 | 161 | # If the pixel_to_CRS_transform is None but can be computed from the two bounds, do that 162 | if ( 163 | input_in_pixels 164 | and (pixel_to_CRS_transform is None) 165 | and (pixel_prediction_bounds is not None) 166 | and (geospatial_prediction_bounds is not None) 167 | ): 168 | # We assume that these two array are exactly corresponding, representing the same shape 169 | # in the two coordinate frames and also the same starting vertex. 170 | # Drop the last entry because it is a duplicate of the first one 171 | geospatial_corners_array = shapely.get_coordinates( 172 | geospatial_prediction_bounds 173 | )[:-1] 174 | pixel_corners_array = shapely.get_coordinates(pixel_prediction_bounds)[:-1] 175 | 176 | # If they don't have the same number of vertices, this can't be the case 177 | if len(geospatial_corners_array) != len(pixel_corners_array): 178 | raise ValueError("Bounds had different lengths") 179 | 180 | # Representing the correspondences as ground control points 181 | ground_control_points = [ 182 | rasterio.control.GroundControlPoint( 183 | col=pixel_vertex[0], 184 | row=pixel_vertex[1], 185 | x=geospatial_vertex[0], 186 | y=geospatial_vertex[1], 187 | ) 188 | for pixel_vertex, geospatial_vertex in zip( 189 | pixel_corners_array, geospatial_corners_array 190 | ) 191 | ] 192 | # Solve the affine transform that best transforms from the pixel to geospatial coordinates 193 | pixel_to_CRS_transform = rasterio.transform.from_gcps(ground_control_points) 194 | 195 | # Error checking 196 | if (pixel_to_CRS_transform is None) and (CRS is not None) and input_in_pixels: 197 | raise ValueError( 198 | "The input was in pixels and a CRS was specified but no geommetric transformation was provided to transform the pixel values to that CRS" 199 | ) 200 | 201 | # Set the transform 202 | self.pixel_to_CRS_transform = pixel_to_CRS_transform 203 | 204 | # If the inputs are provided in pixels, apply the transform to the predictions 205 | if input_in_pixels: 206 | # Get the transform in the format expected by shapely 207 | shapely_transform = pixel_to_CRS_transform.to_shapely() 208 | # Apply this transformation to the geometry of the dataframe 209 | self.detections.geometry = self.detections.geometry.affine_transform( 210 | matrix=shapely_transform 211 | ) 212 | 213 | # Handle the bounds 214 | # If the bounds are provided as geospatial coordinates, use those directly 215 | if geospatial_prediction_bounds is not None: 216 | prediction_bounds_in_CRS = geospatial_prediction_bounds 217 | # If the bounds are provided in pixels and a transform to geospatial is provided, use that 218 | # this assumes that a CRS is set based on previous checks 219 | elif (pixel_to_CRS_transform is not None) and ( 220 | pixel_prediction_bounds is not None 221 | ): 222 | prediction_bounds_in_CRS = affine_transform( 223 | geom=pixel_prediction_bounds, matrix=pixel_to_CRS_transform.to_shapely() 224 | ) 225 | # If there is no CRS and pixel bounds are provided, use these directly 226 | # The None CRS implies pixels, so this still has the intended meaning 227 | elif CRS is None and pixel_prediction_bounds is not None: 228 | prediction_bounds_in_CRS = pixel_prediction_bounds 229 | # Else set the bounds to None (unknown) 230 | else: 231 | prediction_bounds_in_CRS = None 232 | 233 | # Create a one-length geoseries for the 
237 | 
238 |     def subset_detections(self, detection_indices) -> "RegionDetections":
239 |         """Return a new RegionDetections object with only the detections indicated by the indices
240 | 
241 |         Args:
242 |             detection_indices:
243 |                 Which detections to include. Can be any type that can be passed to pd.iloc.
244 | 
245 |         Returns:
246 |             RegionDetections: The subset of detections corresponding to these indices
247 |         """
248 |         # Create a deep copy of the object
249 |         subset_rd = copy.deepcopy(self)
250 |         # Subset the detections dataframe to the requested rows
251 |         subset_rd.detections = subset_rd.detections.iloc[detection_indices, :]
252 | 
253 |         return subset_rd
254 | 
255 |     def save(self, save_path: PATH_TYPE):
256 |         """Saves the information to disk
257 | 
258 |         Args:
259 |             save_path (PATH_TYPE):
260 |                 Path to a geofile to save the data to. The containing folder will be created if it
261 |                 doesn't exist.
262 |         """
263 |         # Convert to a Path object and create the containing folder if not present
264 |         save_path = Path(save_path)
265 |         save_path.parent.mkdir(parents=True, exist_ok=True)
266 | 
267 |         # Save the detections to a file. Note that the bounds of the prediction region and
268 |         # information about the transform to pixel coordinates are currently lost.
269 |         self.detections.to_file(save_path)
270 | 
271 |     def get_data_frame(
272 |         self, CRS: Optional[pyproj.CRS] = None, as_pixels: Optional[bool] = False
273 |     ) -> gpd.GeoDataFrame:
274 |         """Get the detections, optionally specifying a CRS or pixel coordinates
275 | 
276 |         Args:
277 |             CRS (Optional[pyproj.CRS], optional):
278 |                 Requested CRS for the output detections. If un-set, the CRS of self.detections will
279 |                 be used. Defaults to None.
280 |             as_pixels (Optional[bool], optional):
281 |                 Whether to return the values in pixel coordinates. Defaults to False.
282 | 
283 |         Returns:
284 |             gpd.GeoDataFrame: Detections in the requested CRS or in pixel coordinates with a None .crs
285 |         """
286 |         # If the data is requested in pixel coordinates, transform it appropriately
287 |         if as_pixels:
288 |             if self.pixel_to_CRS_transform is None:
289 |                 if self.detections.crs is not None:
290 |                     raise ValueError(
291 |                         "Pixel coordinates were requested but data is in geospatial units with no transformation to pixels"
292 |                     )
293 |                 # With no CRS and no transform, the data is already in pixel coordinates
294 |                 return self.detections.copy()
295 |             # Compute the inverse transform using ~ to map from the CRS to pixels instead of the
296 |             # other way around. Also get this transform in the shapely convention.
297 |             CRS_to_pixel_transform_shapely = (~self.pixel_to_CRS_transform).to_shapely()
298 |             # Create a new geodataframe with the transformed coordinates
299 |             # Start by copying the old dataframe
300 |             pixel_coordinate_detections = self.detections.copy()
301 |             # Since the units are pixels, it no longer has a CRS, so set it to None
302 |             pixel_coordinate_detections.crs = None
303 |             # Transform the geometry to pixel coordinates
304 |             pixel_coordinate_detections.geometry = (
305 |                 pixel_coordinate_detections.geometry.affine_transform(
306 |                     CRS_to_pixel_transform_shapely
307 |                 )
308 |             )
309 |             return pixel_coordinate_detections
310 | 
311 |         # Return the data in geospatial coordinates
312 |         else:
313 |             # If no CRS is specified, return the data as-is, using the current CRS
314 |             if CRS is None:
315 |                 return self.detections.copy()
316 | 
317 |             # Transform the data to the requested CRS. Note that if no CRS is provided initially,
318 |             # this will error out
319 |             detections_in_new_CRS = self.detections.copy().to_crs(CRS)
320 |             return detections_in_new_CRS
321 | 
322 |     def convert_to_bboxes(self) -> "RegionDetections":
323 |         """
324 |         Return a copy of the RD with the geometry of all detections replaced by the minimum
325 |         axis-aligned bounding rectangle. Also, rows with an empty geometry are removed.
326 | 
327 |         Returns:
328 |             RegionDetections: An identical RD with all non-empty geometries represented as bounding boxes.
329 |         """
330 |         # Get the detections
331 |         detections_df = self.get_data_frame()
332 |         # Get non-empty rows in the dataframe, since conversion to bounding box only works for
333 |         # non-empty polygons
334 |         nonempty_rows = detections_df[~detections_df.geometry.is_empty].copy()
335 |         # Get the bounds and convert to shapely boxes
336 |         bounds = nonempty_rows.bounds
337 |         boxes = shapely.box(
338 |             xmin=bounds.minx, ymin=bounds.miny, xmax=bounds.maxx, ymax=bounds.maxy
339 |         )
340 |         # Update the geometry
341 |         # Since this is a copy, the geometry of the original detections is unchanged
342 |         nonempty_rows.geometry = boxes
343 |         # Create a new RegionDetections object and update the detections on it
344 |         bbox_rd = copy.deepcopy(self)
345 |         bbox_rd.detections = nonempty_rows
346 |         return bbox_rd
347 | 
348 |     def update_geometry_column(self, geometry_column: str) -> "RegionDetections":
349 |         """Update the geometry to another column in the dataframe that contains shapely data
350 | 
351 |         Args:
352 |             geometry_column (str): The name of a column containing shapely data
353 | 
354 |         Returns:
355 |             RegionDetections: An updated RD with the geometry specified by the data in `geometry_column`
356 |         """
357 |         # Create a copy of the detections
358 |         detections_df = self.get_data_frame().copy()
359 |         # Set the geometry column to the specified one
360 |         detections_df.geometry = detections_df[geometry_column]
361 | 
362 |         # Create a copy of the RD
363 |         updated_geometry_rd = copy.deepcopy(self)
364 |         # Update the detections and return
365 |         updated_geometry_rd.detections = detections_df
366 | 
367 |         return updated_geometry_rd
368 | 
369 |     def get_bounds(
370 |         self, CRS: Optional[pyproj.CRS] = None, as_pixels: Optional[bool] = False
371 |     ) -> gpd.GeoSeries:
372 |         if CRS is None:
373 |             # Get bounds in original CRS
374 |             bounds = self.prediction_bounds_in_CRS.copy()
375 |         else:
376 |             # Get bounds in requested CRS
377 |             bounds = self.prediction_bounds_in_CRS.to_crs(CRS)
378 | 
379 |         if as_pixels and (self.pixel_to_CRS_transform is not None):
380 |             # Invert the transform to map from the CRS to pixels
381 |             CRS_to_pixel_transform_shapely = (~self.pixel_to_CRS_transform).to_shapely()
382 |             # Transform the bounds into pixel coordinates
383 |             bounds = bounds.affine_transform(
384 |                 CRS_to_pixel_transform_shapely
385 |             )
386 |             # Set the CRS to None since the values are now pixels
387 |             # Note that this does not reproject
388 |             bounds = bounds.set_crs(None, allow_override=True)
389 | 
390 |         return bounds
391 | 
392 |     def get_CRS(self) -> Union[pyproj.CRS, None]:
393 |         """Return the CRS of the detections dataframe
394 | 
395 |         Returns:
396 |             Union[pyproj.CRS, None]: The CRS for the detections
397 |         """
398 |         return self.detections.crs
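    # Illustrative sketch: assuming `rd` was built with a valid pixel-to-CRS
    # transform, the detections can be read back in either coordinate frame
    # (the CRS value is hypothetical):
    #
    #   geo_df = rd.get_data_frame()                  # geospatial, in rd's own CRS
    #   utm_df = rd.get_data_frame(CRS="EPSG:32610")  # reprojected copy
    #   px_df = rd.get_data_frame(as_pixels=True)     # pixel coordinates, .crs is None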
399 | 
400 |     def plot(
401 |         self,
402 |         CRS: Optional[pyproj.CRS] = None,
403 |         as_pixels: bool = False,
404 |         plt_ax: Optional[plt.axes] = None,
405 |         plt_show: bool = True,
406 |         visualization_column: Optional[str] = None,
407 |         bounds_color: Optional[Union[str, np.ndarray, pd.Series]] = None,
408 |         detection_kwargs: dict = {},
409 |         bounds_kwargs: dict = {},
410 |         raster_file: Optional[PATH_TYPE] = None,
411 |         raster_vis_downsample: float = 10.0,
412 |     ) -> plt.axes:
413 |         """Plot the detections and the bounds of the region
414 | 
415 |         Args:
416 |             CRS (Optional[pyproj.CRS], optional):
417 |                 What CRS to use. Defaults to None.
418 |             as_pixels (bool, optional):
419 |                 Whether to display in pixel coordinates. Defaults to False.
420 |             plt_ax (Optional[plt.axes], optional):
421 |                 A pyplot axes to plot on. If not provided, one will be created. Defaults to None.
422 |             plt_show (bool, optional):
423 |                 Whether to plot the result or just return it. Defaults to True.
424 |             visualization_column (Optional[str], optional):
425 |                 Which column to visualize from the detections dataframe. Defaults to None.
426 |             bounds_color (Optional[Union[str, np.ndarray, pd.Series]], optional):
427 |                 The color to plot the bounds. Must be accepted by the gpd.plot color argument.
428 |                 Defaults to None.
429 |             detection_kwargs (dict, optional):
430 |                 Additional keyword arguments to pass to the .plot method for the detections.
431 |                 Defaults to {}.
432 |             bounds_kwargs (dict, optional):
433 |                 Additional keyword arguments to pass to the .plot method for the bounds.
434 |                 Defaults to {}.
435 |             raster_file (Optional[PATH_TYPE], optional):
436 |                 A path to a raster file to visualize the detections over if provided. Defaults to None.
437 |             raster_vis_downsample (float, optional):
438 |                 The raster file is downsampled by this fraction before visualization to avoid
439 |                 excessive memory use or plotting time. Defaults to 10.0.
440 | 
441 |         Returns:
442 |             plt.axes: The axes that were plotted on
443 |         """
444 | 
445 |         # Get the dataframe and the bounds
446 |         data_frame = self.get_data_frame(CRS=CRS, as_pixels=as_pixels)
447 |         bounds = self.get_bounds(CRS, as_pixels=as_pixels)
448 | 
449 |         # Perform plotting and return the axes
450 |         return plot_detections(
451 |             data_frame=data_frame,
452 |             bounds=bounds,
453 |             CRS=data_frame.crs,
454 |             plt_ax=plt_ax,
455 |             plt_show=plt_show,
456 |             visualization_column=visualization_column,
457 |             detection_kwargs=detection_kwargs,
458 |             bounds_kwargs=bounds_kwargs,
459 |             raster_file=raster_file,
460 |             raster_vis_downsample=raster_vis_downsample,
461 |             bounds_color=bounds_color,
462 |         )
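# Minimal usage sketch: a RegionDetections built purely in pixel coordinates
# (no CRS, no transform) keeps its values as pixels and can be plotted as such.
# The box coordinates are arbitrary.
if __name__ == "__main__":
    example_rd = RegionDetections(
        detection_geometries=[shapely.box(10, 10, 40, 40), shapely.box(50, 60, 90, 95)],
        data={"score": [0.9, 0.7]},
        input_in_pixels=True,
        pixel_prediction_bounds=shapely.box(0, 0, 100, 100),
    )
    example_rd.plot(as_pixels=True, bounds_color="red")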
463 | 
464 | 
465 | class RegionDetectionsSet:
466 |     region_detections: List[RegionDetections]
467 | 
468 |     def __init__(self, region_detections: List[RegionDetections]):
469 |         """Create a set of detections to conveniently perform operations on all of them, for example
470 |         merging all regions into a single dataframe with an additional column indicating which
471 |         region each detection belongs to.
472 | 
473 |         Args:
474 |             region_detections (List[RegionDetections]): A list of individual detections
475 |         """
476 |         self.region_detections = region_detections
477 | 
478 |     def all_regions_have_CRS(self) -> bool:
479 |         """Check whether all sub-regions have a non-None CRS
480 | 
481 |         Returns:
482 |             bool: Whether all sub-regions have a valid CRS.
483 |         """
484 |         # Get the CRS for each sub-region
485 |         regions_CRS_values = [rd.detections.crs for rd in self.region_detections]
486 |         # Only valid if no CRS value is None
487 |         valid = not (None in regions_CRS_values)
488 | 
489 |         return valid
490 | 
491 |     def get_default_CRS(self, check_all_have_CRS=True) -> pyproj.CRS:
492 |         """Find the CRS of the first sub-region to use as a default
493 | 
494 |         Args:
495 |             check_all_have_CRS (bool, optional):
496 |                 Should an error be raised if not all regions have a CRS set. Defaults to True.
497 | 
498 |         Returns:
499 |             pyproj.CRS: The CRS given by the first sub-region.
500 |         """
501 |         if check_all_have_CRS and not self.all_regions_have_CRS():
502 |             raise ValueError(
503 |                 "Not all regions have a CRS set and a default one was requested"
504 |             )
505 |         # Get the CRS for each sub-region
506 |         regions_CRS_values = [rd.detections.crs for rd in self.region_detections]
507 |         # The default is the CRS of the first region
508 |         # TODO in the future it could be something else like the most common
509 |         CRS = regions_CRS_values[0]
510 | 
511 |         return CRS
512 | 
513 |     def get_region_detections(self, index: int) -> RegionDetections:
514 |         """Get a single region detections object by index
515 | 
516 |         Args:
517 |             index (int): Which one to select
518 | 
519 |         Returns:
520 |             RegionDetections: The RegionDetections object from the list of objects in the set
521 |         """
522 |         return self.region_detections[index]
523 | 
524 |     def merge(
525 |         self,
526 |         region_ID_key: Optional[str] = "region_ID",
527 |         CRS: Optional[pyproj.CRS] = None,
528 |     ) -> RegionDetections:
529 |         """Get the merged detections across all regions with an additional field specifying which region
530 |         the detection came from.
531 | 
532 |         Args:
533 |             region_ID_key (Optional[str], optional):
534 |                 Create this column in the output dataframe identifying which region that data came
535 |                 from using a zero-indexed integer. Defaults to "region_ID".
536 |             CRS (Optional[pyproj.CRS], optional):
537 |                 Requested CRS for merged detections. If un-set, the CRS of the first region will
538 |                 be used. Defaults to None.
539 | 
540 |         Returns:
541 |             RegionDetections: The merged detections in the requested CRS
542 |         """
543 | 
544 |         # Get the detections from each region detection object as geodataframes
545 |         detection_geodataframes = [
546 |             rd.get_data_frame(CRS=CRS) for rd in self.region_detections
547 |         ]
548 | 
549 |         # Add a column to each geodataframe identifying which region detection object it came from
550 |         # Note that dataframes in the original list are updated
551 |         # TODO consider a more sophisticated ID
552 |         for ID, gdf in enumerate(detection_geodataframes):
553 |             gdf[region_ID_key] = ID
554 | 
555 |         # Concatenate the geodataframes together
556 |         concatenated_geodataframes = pd.concat(detection_geodataframes)
557 | 
558 |         ## Merge_bounds
559 |         # Get the merged bounds
560 |         merged_bounds = self.get_bounds(CRS=CRS)
561 |         # Convert to a single shapely object
562 |         merged_bounds_shapely = merged_bounds.geometry[0]
563 | 
564 |         # Use the geometry column of the concatenated dataframes
565 |         merged_region_detections = RegionDetections(
566 |             detection_geometries="geometry",
567 |             data=concatenated_geodataframes,
568 |             CRS=CRS,
569 |             input_in_pixels=False,
570 |             geospatial_prediction_bounds=merged_bounds_shapely,
571 |         )
572 | 
573 |         return merged_region_detections
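    # Illustrative sketch: merging per-tile detections. Assuming rd_a and rd_b
    # are RegionDetections objects sharing a CRS:
    #
    #   rds = RegionDetectionsSet([rd_a, rd_b])
    #   merged = rds.merge()
    #   merged.get_data_frame()["region_ID"]  # 0 for rd_a rows, 1 for rd_b rows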
574 | 
575 |     def get_data_frame(
576 |         self,
577 |         CRS: Optional[pyproj.CRS] = None,
578 |         merge: bool = False,
579 |         region_ID_key: str = "region_ID",
580 |     ) -> gpd.GeoDataFrame | List[gpd.GeoDataFrame]:
581 |         """Get the detections, optionally specifying a CRS
582 | 
583 |         Args:
584 |             CRS (Optional[pyproj.CRS], optional):
585 |                 Requested CRS for the output detections. If un-set, the CRS of self.detections will
586 |                 be used. Defaults to None.
587 |             merge (bool, optional):
588 |                 If True, return one dataframe. Else, return a list of individual dataframes.
589 |             region_ID_key (str, optional):
590 |                 Use this column to identify which region each detection came from. Defaults to
591 |                 "region_ID"
592 | 
593 |         Returns:
594 |             gpd.GeoDataFrame | List[gpd.GeoDataFrame]:
595 |                 If merge=True, then one dataframe with an additional column specifying which region each
596 |                 detection came from. If merge=False, then a list of dataframes for each region.
597 |         """
598 |         if merge:
599 |             # Merge all of the detections into one RegionDetections
600 |             merged_detections = self.merge(region_ID_key=region_ID_key, CRS=CRS)
601 |             # Get the dataframe. It is already in the requested CRS in the current implementation.
602 |             data_frame = merged_detections.get_data_frame()
603 |             return data_frame
604 | 
605 |         # Get a list of dataframes from each region
606 |         list_of_region_data_frames = [
607 |             rd.get_data_frame(CRS=CRS) for rd in self.region_detections
608 |         ]
609 |         return list_of_region_data_frames
610 | 
611 |     def convert_to_bboxes(self) -> "RegionDetectionsSet":
612 |         """Convert all the RegionDetections to bounding box representations
613 | 
614 |         Returns:
615 |             RegionDetectionsSet:
616 |                 A new RDS where each RD has all empty geometries dropped and all remaining ones
617 |                 represented by an axis-aligned rectangle.
618 |         """
619 |         # Convert each detection
620 |         converted_detections = [rd.convert_to_bboxes() for rd in self.region_detections]
621 |         # Return the new RDS
622 |         bboxes_rds = RegionDetectionsSet(converted_detections)
623 |         return bboxes_rds
624 | 
625 |     def update_geometry_column(self, geometry_column: str) -> "RegionDetectionsSet":
626 |         """
627 |         Update the geometry to another column in the dataframe that contains shapely data for each RD
628 | 
629 |         Args:
630 |             geometry_column (str): The name of a column containing shapely data
631 | 
632 |         Returns:
633 |             RegionDetectionsSet: An updated RDS with the geometry specified by the data in `geometry_column`
634 |         """
635 |         # Convert each detection
636 |         converted_detections = [
637 |             rd.update_geometry_column(geometry_column=geometry_column)
638 |             for rd in self.region_detections
639 |         ]
640 |         # Return the new RDS
641 |         updated_geometry_rds = RegionDetectionsSet(converted_detections)
642 |         return updated_geometry_rds
643 | 
644 |     def get_bounds(
645 |         self, CRS: Optional[pyproj.CRS] = None, union_bounds: bool = True
646 |     ) -> gpd.GeoSeries:
647 |         """Get the bounds corresponding to the sub-regions.
648 | 
649 |         Args:
650 |             CRS (Optional[pyproj.CRS], optional):
651 |                 The CRS to return the bounds in. If not set, the CRS of the first
652 |                 region will be used. Defaults to None.
653 |             union_bounds (bool, optional):
654 |                 Whether to return the spatial union of all bounds or a series of per-region bounds.
655 |                 Defaults to True.
656 | 
657 |         Returns:
658 |             gpd.GeoSeries: Either a one-length series of merged bounds if union_bounds=True or a series
659 |                 of bounds per region.
660 | """ 661 | 662 | region_bounds = [rd.get_bounds(CRS=CRS) for rd in self.region_detections] 663 | # Create a geodataframe out of these region bounds 664 | all_region_bounds = gpd.GeoSeries(pd.concat(region_bounds), crs=CRS) 665 | 666 | # If the union is not requested, return the individual bounds 667 | if not union_bounds: 668 | return all_region_bounds 669 | 670 | # Compute the union of all bounds 671 | merged_bounds = gpd.GeoSeries([all_region_bounds.geometry.union_all()], crs=CRS) 672 | 673 | return merged_bounds 674 | 675 | def disjoint_bounds(self) -> bool: 676 | """Determine whether the bounds of the sub-regions are disjoint 677 | 678 | Returns: 679 | bool: Are they disjoint 680 | """ 681 | # Get the bounds for each individual region 682 | bounds = self.get_bounds(union_bounds=False) 683 | # Get the union of all bounds 684 | union_bounds = bounds.union_all() 685 | 686 | # Find the sum of areas for each region 687 | sum_individual_areas = bounds.area.sum() 688 | # And the area of the union 689 | union_area = union_bounds.area 690 | 691 | # If the two areas are the same (down to numeric errors) then there are no overlaps 692 | disjoint = np.allclose(sum_individual_areas, union_area) 693 | 694 | return disjoint 695 | 696 | def save( 697 | self, 698 | save_path: PATH_TYPE, 699 | CRS: Optional[pyproj.CRS] = None, 700 | region_ID_key: Optional[str] = "region_ID", 701 | ): 702 | """ 703 | Save the data to a geospatial file by calling get_data_frame with merge=True and then saving 704 | to the specified file. The containing folder is created if it doesn't exist. 705 | 706 | Args: 707 | save_path (PATH_TYPE): 708 | File to save the data to. The containing folder will be created if it does not exist. 709 | CRS (Optional[pyproj.CRS], optional): 710 | See get_data_frame. 711 | region_ID_key (Optional[str], optional): 712 | See get_data_frame. 713 | """ 714 | # Get the concatenated dataframes 715 | concatenated_geodataframes = self.get_data_frame( 716 | CRS=CRS, region_ID_key=region_ID_key, merge=True 717 | ) 718 | 719 | # Ensure that the folder to save them to exists 720 | save_path = Path(save_path) 721 | save_path.parent.mkdir(exist_ok=True, parents=True) 722 | 723 | # Save the data to the geofile 724 | concatenated_geodataframes.to_file(save_path) 725 | 726 | def plot( 727 | self, 728 | CRS: Optional[pyproj.CRS] = None, 729 | plt_ax: Optional[plt.axes] = None, 730 | plt_show: bool = True, 731 | visualization_column: Optional[str] = None, 732 | bounds_color: Optional[Union[str, np.array, pd.Series]] = None, 733 | detection_kwargs: dict = {}, 734 | bounds_kwargs: dict = {}, 735 | raster_file: Optional[PATH_TYPE] = None, 736 | raster_vis_downsample: float = 10.0, 737 | ) -> plt.axes: 738 | """Plot each of the region detections using their .plot method 739 | 740 | Args: 741 | CRS (Optional[pyproj.CRS], optional): 742 | The CRS to use for plotting all regions. If unset, the default one for this object 743 | will be selected. Defaults to None. 744 | plt_ax (Optional[plt.axes], optional): 745 | The axes to plot on. Will be created if not provided. Defaults to None. 746 | plt_show (bool, optional): 747 | See RegionDetections.plot. Defaults to True. 748 | visualization_column (Optional[str], optional): 749 | See regiondetections.plot. Defaults to None. 750 | bounds_color (Optional[Union[str, np.array, pd.Series]], optional): 751 | See regiondetections.plot. Defaults to None. 752 | detection_kwargs (dict, optional): 753 | See regiondetections.plot. Defaults to {}. 
754 |             bounds_kwargs (dict, optional):
755 |                 See RegionDetections.plot. Defaults to {}.
756 |             raster_file (Optional[PATH_TYPE], optional):
757 |                 See RegionDetections.plot. Defaults to None.
758 |             raster_vis_downsample (float, optional):
759 |                 See RegionDetections.plot. Defaults to 10.0.
760 | 
761 |         Returns:
762 |             plt.axes: The axes that have been plotted on.
763 |         """
764 |         # Extract the bounds for each of the sub-regions
765 |         bounds = self.get_bounds(CRS=CRS, union_bounds=False)
766 |         data_frame = self.get_data_frame(CRS=CRS, merge=True)
767 | 
768 |         # Perform plotting and return the axes
769 |         return plot_detections(
770 |             data_frame=data_frame,
771 |             bounds=bounds,
772 |             CRS=data_frame.crs,
773 |             plt_ax=plt_ax,
774 |             plt_show=plt_show,
775 |             visualization_column=visualization_column,
776 |             bounds_color=bounds_color,
777 |             detection_kwargs=detection_kwargs,
778 |             bounds_kwargs=bounds_kwargs,
779 |             raster_file=raster_file,
780 |             raster_vis_downsample=raster_vis_downsample,
781 |         )
--------------------------------------------------------------------------------
/tree_detection_framework/entrypoints/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-forest-observatory/tree-detection-framework/661fded8c3cd3ac7585ec9bb5d4606778bca7494/tree_detection_framework/entrypoints/__init__.py
--------------------------------------------------------------------------------
/tree_detection_framework/entrypoints/generate_predictions.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | from typing import Optional
4 | 
5 | import pyproj
6 | import torch
7 | 
8 | from tree_detection_framework.constants import BOUNDARY_TYPE, PATH_TYPE
9 | from tree_detection_framework.detection.detector import (
10 |     DeepForestDetector,
11 |     Detectree2Detector,
12 | )
13 | from tree_detection_framework.detection.models import DeepForestModule, Detectree2Module
14 | from tree_detection_framework.postprocessing.postprocessing import multi_region_NMS
15 | from tree_detection_framework.preprocessing.preprocessing import create_dataloader
16 | 
17 | 
18 | def generate_predictions(
19 |     raster_folder_path: PATH_TYPE,
20 |     chip_size: float,
21 |     tree_detection_model: str,
22 |     chip_stride: Optional[float] = None,
23 |     chip_overlap_percentage: Optional[float] = None,
24 |     use_units_meters: bool = False,
25 |     region_of_interest: Optional[BOUNDARY_TYPE] = None,
26 |     output_resolution: Optional[float] = None,
27 |     output_CRS: Optional[pyproj.CRS] = None,
28 |     predictions_save_path: Optional[PATH_TYPE] = None,
29 |     view_predictions_plot: bool = False,
30 |     run_nms: bool = True,
31 |     iou_threshold: Optional[float] = 0.3,
32 |     min_confidence: Optional[float] = 0.3,
33 |     batch_size: int = 1,
34 | ):
35 |     """
36 |     Entrypoint script to generate tree detections for a raster dataset input. Supports visualizing and saving predictions.
37 | 
38 |     Args:
39 |         raster_folder_path (PATH_TYPE): Path to the folder of raster files.
40 |         chip_size (float):
41 |             Dimension of the chip. May be pixels or meters, based on `use_units_meters`.
42 |         tree_detection_model (str):
43 |             Selected model for detecting trees.
44 |         chip_stride (Optional[float], optional):
45 |             Stride of the chip. May be pixels or meters, based on `use_units_meters`. If used,
46 |             `chip_overlap_percentage` should not be set. Defaults to None.
47 |         chip_overlap_percentage (Optional[float], optional):
48 |             Percent overlap of the chip from 0-100. If used, `chip_stride` should not be set.
49 |             Defaults to None.
50 |         use_units_meters (bool, optional):
51 |             Use units of meters rather than pixels when interpreting the `chip_size` and `chip_stride`.
52 |             Defaults to False.
53 |         region_of_interest (Optional[BOUNDARY_TYPE], optional):
54 |             Only data from this spatial region will be included in the dataloader. Defaults to None.
55 |         output_resolution (Optional[float], optional):
56 |             Spatial resolution of the data in meters/pixel. If un-set, will be the resolution of the
57 |             first raster data that is read. Defaults to None.
58 |         output_CRS (Optional[pyproj.CRS], optional):
59 |             The coordinate reference system to use for the output data. If un-set, will be the CRS
60 |             of the first tile found. Defaults to None.
61 |         predictions_save_path (Optional[PATH_TYPE], optional):
62 |             Path to a geofile to save the prediction outputs. Defaults to None.
63 |         view_predictions_plot (bool, optional):
64 |             Set to True if visualization of the detected regions is needed. Defaults to False.
65 |         run_nms (bool, optional):
66 |             Set to True if non-max suppression needs to be run on predictions from multiple regions.
67 |         iou_threshold (float, optional):
68 |             What intersection over union value to consider an overlapping detection. Defaults to 0.3.
69 |         min_confidence (float, optional):
70 |             Prediction score threshold for detections to be included. Defaults to 0.3.
71 |         batch_size (int, optional):
72 |             Number of images to load in a batch. Defaults to 1.
73 |     """
74 | 
75 |     # Create the dataloader by passing folder path to raster data.
76 |     dataloader = create_dataloader(
77 |         raster_folder_path=raster_folder_path,
78 |         chip_size=chip_size,
79 |         chip_stride=chip_stride,
80 |         chip_overlap_percentage=chip_overlap_percentage,
81 |         use_units_meters=use_units_meters,
82 |         region_of_interest=region_of_interest,
83 |         output_resolution=output_resolution,
84 |         output_CRS=output_CRS,
85 |         batch_size=batch_size,
86 |     )
87 | 
88 |     # Setup the specified tree detection model
89 |     if tree_detection_model == "deepforest":
90 | 
91 |         # Setup the parameters dictionary
92 |         param_dict = {
93 |             "backbone": "retinanet",
94 |             "num_classes": 1,
95 |         }
96 | 
97 |         df_module = DeepForestModule(param_dict)
98 |         # Move the module to the GPU if available
99 |         df_module.to(
100 |             torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
101 |         )
102 |         lightning_detector = DeepForestDetector(df_module)
103 | 
104 |     elif tree_detection_model == "detectree2":
105 | 
106 |         # Load detectree2 pretrained weights
107 |         # TODO: download pretrained weights when called, instead of providing a local path
108 |         trained_model = (
109 |             "/ofo-share/repos-amritha/detectree2-code/230103_randresize_full.pth"
110 |         )
111 |         param_dict = {"update_model": trained_model}
112 | 
113 |         dtree2_module = Detectree2Module(param_dict)
114 |         lightning_detector = Detectree2Detector(dtree2_module)
115 | 
116 |     else:
117 |         raise ValueError(
118 |             """Please enter a valid tree detection model. Currently supported models are:
119 |             1. deepforest
120 |             2. detectree2"""
121 |         )
122 | 
123 |     # Get predictions by invoking the tree_detection_model
124 |     logging.info("Getting tree detections")
125 |     outputs = lightning_detector.predict(dataloader)
126 | 
127 |     if run_nms:
128 |         logging.info("Running non-max suppression")
129 |         # Run non-max suppression on the detected regions
130 |         outputs = multi_region_NMS(
131 |             outputs, threshold=iou_threshold, min_confidence=min_confidence
132 |         )
133 | 
134 |     if predictions_save_path:
135 |         # Save predictions to disk
136 |         outputs.save(predictions_save_path)
137 | 
138 |     if view_predictions_plot:
139 |         logging.info("Displaying plot. Close the plot window to exit.")
140 |         # Plot the detections and the bounds of the region
141 |         outputs.plot()
142 | 
143 | 
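# Illustrative sketch: the entrypoint above can also be called programmatically;
# the paths and chip geometry below are placeholders.
#
#   generate_predictions(
#       raster_folder_path="data/orthomosaics",
#       chip_size=1500,
#       chip_stride=1200,
#       tree_detection_model="deepforest",
#       predictions_save_path="predictions/detections.gpkg",
#   )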
detectree2""" 121 | ) 122 | 123 | # Get predictions by invoking the tree_detection_model 124 | logging.info("Getting tree detections") 125 | outputs = lightning_detector.predict(dataloader) 126 | 127 | if run_nms is True: 128 | logging.info("Running non-max suppression") 129 | # Run non-max suppression on the detected regions 130 | outputs = multi_region_NMS( 131 | outputs, iou_theshold=iou_threshold, min_confidence=min_confidence 132 | ) 133 | 134 | if predictions_save_path: 135 | # Save predictions to disk 136 | outputs.save(predictions_save_path) 137 | 138 | if view_predictions_plot is True: 139 | logging.info("View plot. Kill the plot window to exit.") 140 | # Plot the detections and the bounds of the region 141 | outputs.plot() 142 | 143 | 144 | def parse_args() -> argparse.Namespace: 145 | description = ( 146 | "This script generates tree detections for a given raster image. First, it creates a dataloader " 147 | + "with the tiled raster dataset and provides the images as input to the selected tree detection model. " 148 | + "All of the arguments are passed to " 149 | + "tree_detection_framework.entrypoints.generate_predictions " 150 | + "which has the following documentation:\n\n" 151 | + generate_predictions.__doc__ 152 | ) 153 | parser = argparse.ArgumentParser( 154 | description=description, formatter_class=argparse.RawDescriptionHelpFormatter 155 | ) 156 | 157 | parser.add_argument("--raster-folder-path", required=True) 158 | parser.add_argument("--chip-size", type=float, required=True) 159 | parser.add_argument("--tree-detection-model", type=str, required=True) 160 | parser.add_argument("--chip-stride", type=float) 161 | parser.add_argument("--chip-overlap-percentage", type=float) 162 | parser.add_argument("--use-units-meters", action="store_true") 163 | parser.add_argument("--region-of-interest") 164 | parser.add_argument("--output-resolution", type=float) 165 | parser.add_argument("--output-CRS") 166 | parser.add_argument("--predictions-save-path") 167 | parser.add_argument("--view-predictions-plot", action="store_true") 168 | parser.add_argument("--run-nms", action="store_true") 169 | parser.add_argument("--iou-threshold", type=float, default=0.3) 170 | parser.add_argument("--min-confidence", type=float, default=0.3) 171 | parser.add_argument("--batch-size", type=int, default=1) 172 | 173 | try: 174 | args = parser.parse_args() 175 | 176 | except SystemExit as e: 177 | print("\nError: Missing required arguments.") 178 | parser.print_help() 179 | raise e 180 | 181 | return args 182 | 183 | 184 | if __name__ == "__main__": 185 | args = parse_args() 186 | generate_predictions(**args.__dict__) 187 | -------------------------------------------------------------------------------- /tree_detection_framework/entrypoints/tile_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Optional 3 | 4 | import pyproj 5 | 6 | from tree_detection_framework.constants import ARRAY_TYPE, BOUNDARY_TYPE, PATH_TYPE 7 | from tree_detection_framework.preprocessing.preprocessing import ( 8 | create_dataloader, 9 | save_dataloader_contents, 10 | visualize_dataloader, 11 | ) 12 | 13 | 14 | def tile_data( 15 | raster_folder_path: PATH_TYPE, 16 | chip_size: float, 17 | chip_stride: Optional[float] = None, 18 | chip_overlap_percentage: float = None, 19 | use_units_meters: bool = False, 20 | region_of_interest: Optional[BOUNDARY_TYPE] = None, 21 | output_resolution: Optional[float] = None, 22 | output_CRS: Optional[pyproj.CRS] = 
23 |     vector_label_folder_path: Optional[PATH_TYPE] = None,
24 |     vector_label_attribute: Optional[str] = None,
25 |     visualize_n_tiles: Optional[int] = None,
26 |     save_folder: Optional[PATH_TYPE] = None,
27 |     save_n_tiles: Optional[int] = None,
28 |     random_sample: bool = False,
29 |     batch_size: int = 1,
30 | ):
31 |     """
32 |     Entrypoint script for testing preprocessing functions.
33 |     It enables creating a dataloader, visualizing sample tiles from the dataloader, and saving the contents of the dataloader to disk.
34 | 
35 |     Args:
36 |         raster_folder_path (PATH_TYPE): Path to the folder of raster files.
37 |         chip_size (float):
38 |             Dimension of the chip. May be pixels or meters, based on `use_units_meters`.
39 |         chip_stride (Optional[float], optional):
40 |             Stride of the chip. May be pixels or meters, based on `use_units_meters`. If used,
41 |             `chip_overlap_percentage` should not be set. Defaults to None.
42 |         chip_overlap_percentage (Optional[float], optional):
43 |             Percent overlap of the chip from 0-100. If used, `chip_stride` should not be set.
44 |             Defaults to None.
45 |         use_units_meters (bool, optional):
46 |             Use units of meters rather than pixels when interpreting the `chip_size` and `chip_stride`.
47 |             Defaults to False.
48 |         region_of_interest (Optional[BOUNDARY_TYPE], optional):
49 |             Only data from this spatial region will be included in the dataloader. Defaults to None.
50 |         output_resolution (Optional[float], optional):
51 |             Spatial resolution of the data in meters/pixel. If un-set, will be the resolution of the
52 |             first raster data that is read. Defaults to None.
53 |         output_CRS (Optional[pyproj.CRS], optional):
54 |             The coordinate reference system to use for the output data. If un-set, will be the CRS
55 |             of the first tile found. Defaults to None.
56 |         vector_label_folder_path (Optional[PATH_TYPE], optional):
57 |             A folder of geospatial vector files that will be used for the label. If un-set, the
58 |             dataloader will not be labeled. Defaults to None.
59 |         vector_label_attribute (Optional[str], optional):
60 |             Attribute to read from the vector data, such as the class or instance ID. Defaults to None.
61 |         visualize_n_tiles (Optional[int], optional):
62 |             The number of randomly-sampled tiles to display. Defaults to None.
63 |         save_folder (Optional[PATH_TYPE], optional):
64 |             Folder to save data to. Will be created if it doesn't exist.
65 |         save_n_tiles (Optional[int], optional):
66 |             Number of tiles to save. Whether they are the first tiles or random is controlled by
67 |             `random_sample`. If unset, all tiles will be saved. Defaults to None.
68 |         random_sample (bool, optional):
69 |             If `save_n_tiles` is set, should the tiles be randomly sampled rather than taken from the
70 |             beginning of the dataloader. Defaults to False.
71 |         batch_size (int, optional):
72 |             Number of images to load in a batch. Defaults to 1.
73 |     """
74 |     # Create the dataloader by passing folder path to raster data and optionally a path to the vector data folder.
75 |     dataloader = create_dataloader(
76 |         raster_folder_path=raster_folder_path,
77 |         chip_size=chip_size,
78 |         chip_stride=chip_stride,
79 |         chip_overlap_percentage=chip_overlap_percentage,
80 |         use_units_meters=use_units_meters,
81 |         region_of_interest=region_of_interest,
82 |         output_resolution=output_resolution,
83 |         output_CRS=output_CRS,
84 |         vector_label_folder_path=vector_label_folder_path,
85 |         vector_label_attribute=vector_label_attribute,
86 |         batch_size=batch_size,
87 |     )
88 | 
89 |     # If `visualize_n_tiles` is specified, display that many randomly-sampled tiles.
90 | if visualize_n_tiles is not None: 91 | visualize_dataloader(dataloader=dataloader, n_tiles=visualize_n_tiles) 92 | 93 | # If path to save tiles is given, save all the tiles (or `n_tiles`) from the dataloader to disk. Tiles can be randomly sampled or ordered. 94 | if save_folder is not None: 95 | save_dataloader_contents( 96 | dataloader=dataloader, 97 | save_folder=save_folder, 98 | n_tiles=save_n_tiles, 99 | random_sample=random_sample, 100 | ) 101 | 102 | 103 | def parse_args() -> argparse.Namespace: 104 | parser = argparse.ArgumentParser(description="Chipping orthomosaic images") 105 | 106 | parser.add_argument( 107 | "--raster-folder-path", 108 | type=str, 109 | required=True, 110 | help="Path to the folder or raster files.", 111 | ) 112 | 113 | parser.add_argument( 114 | "--chip-size", 115 | type=float, 116 | required=True, 117 | help="Dimension of the chip. May be pixels or meters, based on --use-units-meters.", 118 | ) 119 | 120 | parser.add_argument( 121 | "--chip-stride", 122 | type=float, 123 | required=False, 124 | help="Stride of the chip. May be pixels or meters, based on --use-units-meters. If used, --chip-overlap-percentage should not be set.", 125 | ) 126 | 127 | parser.add_argument( 128 | "--chip-overlap-percentage", 129 | type=float, 130 | required=False, 131 | help="Percent overlap of the chip from 0-100. If used, --chip-stride should not be set.", 132 | ) 133 | 134 | parser.add_argument( 135 | "--use-units-meters", 136 | action="store_true", 137 | help="Use units of meters rather than pixels when interpreting the --chip-size and --chip-stride.", 138 | ) 139 | 140 | parser.add_argument( 141 | "--region-of-interest", 142 | type=str, 143 | required=False, 144 | help="Only data from this spatial region will be included in the dataloader. Should be specified as minx,miny,maxx,maxy.", 145 | ) 146 | 147 | parser.add_argument( 148 | "--output-resolution", 149 | type=float, 150 | required=False, 151 | help="Spatial resolution of the data in meters/pixel. If un-set, will be the resolution of the first raster data that is read.", 152 | ) 153 | 154 | parser.add_argument( 155 | "--output-CRS", 156 | type=str, 157 | required=False, 158 | help="The coordinate reference system to use for the output data. If un-set, will be the CRS of the first tile found.", 159 | ) 160 | 161 | parser.add_argument( 162 | "--vector-label-folder-path", 163 | type=str, 164 | required=False, 165 | help="A folder of geospatial vector files that will be used for the label. If un-set, the dataloader will not be labeled.", 166 | ) 167 | 168 | parser.add_argument( 169 | "--vector-label-attribute", 170 | type=str, 171 | default="treeID", 172 | help="Attribute to read from the vector data, such as the class or instance ID. Defaults to 'treeID'.", 173 | ) 174 | 175 | parser.add_argument( 176 | "--visualize-n-tiles", 177 | type=int, 178 | required=False, 179 | help="The number of randomly-sampled tiles to display.", 180 | ) 181 | 182 | parser.add_argument( 183 | "--save-folder", 184 | type=str, 185 | required=False, 186 | help="Folder to save data to. Will be created if it doesn't exist.", 187 | ) 188 | 189 | parser.add_argument( 190 | "--save-n-tiles", 191 | type=int, 192 | required=False, 193 | help="Number of tiles to save. Whether they are the first tiles or random is controlled by --random-sample. 
If unset, all tiles will be saved.",
194 |     )
195 | 
196 |     parser.add_argument(
197 |         "--random-sample",
198 |         action="store_true",
199 |         help="If --save-n-tiles is set, should the tiles be randomly sampled rather than taken from the beginning of the dataloader.",
200 |     )
201 | 
202 |     parser.add_argument(
203 |         "--batch-size",
204 |         type=int,
205 |         default=1,
206 |         help="Number of images to load in a batch. Defaults to 1.",
207 |     )
208 | 
209 |     args = parser.parse_args()
210 |     return args
211 | 
212 | 
213 | if __name__ == "__main__":
214 |     args = parse_args()
215 |     tile_data(**args.__dict__)
216 | 
--------------------------------------------------------------------------------
/tree_detection_framework/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-forest-observatory/tree-detection-framework/661fded8c3cd3ac7585ec9bb5d4606778bca7494/tree_detection_framework/evaluation/__init__.py
--------------------------------------------------------------------------------
/tree_detection_framework/evaluation/evaluate.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | 
3 | import geopandas as gpd
4 | import numpy as np
5 | from scipy.optimize import linear_sum_assignment
6 | from shapely.geometry import Polygon
7 | 
8 | 
9 | def compute_matched_ious(
10 |     ground_truth_boxes: List[Polygon], predicted_boxes: List[Polygon]
11 | ) -> List:
12 |     """Compute IoUs for matched pairs of ground truth and predicted boxes.
13 |     This uses the Hungarian algorithm to find the optimal assignment.
14 |     Args:
15 |         ground_truth_boxes (list): List of ground truth polygons.
16 |         predicted_boxes (list): List of predicted polygons.
17 |     Returns:
18 |         list: List of IoUs for matched pairs.
19 |     """
20 |     if not ground_truth_boxes or not predicted_boxes:
21 |         return []  # Return an empty list if either input is empty
22 | 
23 |     # Create GeoDataFrames for ground truth and predicted boxes
24 |     gt_gdf = gpd.GeoDataFrame(geometry=ground_truth_boxes)
25 |     gt_gdf["area_gt"] = gt_gdf.geometry.area
26 |     gt_gdf["id_gt"] = gt_gdf.index
27 | 
28 |     pred_gdf = gpd.GeoDataFrame(geometry=predicted_boxes)
29 |     pred_gdf["area_pred"] = pred_gdf.geometry.area
30 |     pred_gdf["id_pred"] = pred_gdf.index
31 | 
32 |     # Get the intersection between the two sets of polygons
33 |     intersection = gpd.overlay(gt_gdf, pred_gdf, how="intersection")
34 |     intersection["iou"] = intersection.area / (
35 |         intersection["area_gt"] + (intersection["area_pred"] - intersection.area)
36 |     )
37 | 
38 |     # Create a cost matrix to store IoUs
39 |     cost_matrix = np.zeros((len(gt_gdf), len(pred_gdf)))
40 |     cost_matrix[intersection["id_gt"], intersection["id_pred"]] = -intersection["iou"]
41 | 
42 |     # Solve optimal assignment using the Hungarian algorithm
43 |     gt_indices, pred_indices = linear_sum_assignment(cost_matrix)
44 | 
45 |     # Extract IoUs of matched pairs
46 |     matched_ious = [-cost_matrix[i, j] for i, j in zip(gt_indices, pred_indices)]
47 |     return matched_ious
48 | 
49 | 
50 | def compute_precision_recall(
51 |     ious: List, num_gt: int, num_pd: int, threshold: float = 0.4
52 | ) -> tuple:
53 |     """Compute precision and recall based on IoUs.
54 |     Args:
55 |         ious (list): List of IoUs for matched pairs.
56 |         num_gt (int): Number of ground truth boxes.
57 |         num_pd (int): Number of predicted boxes.
58 |         threshold (float): IoU threshold for considering a match.
59 |     Returns:
60 |         tuple: Precision and recall values.
61 |     """
62 |     true_positives = (np.array(ious) > threshold).astype(np.uint8)
63 |     tp = np.sum(true_positives)
64 |     recall = tp / num_gt if num_gt > 0 else 0.0
65 |     precision = tp / num_pd if num_pd > 0 else 0.0
66 |     return precision, recall
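# Minimal usage sketch of the two functions above on toy boxes; the 0.4 IoU
# threshold mirrors the default of compute_precision_recall.
if __name__ == "__main__":
    from shapely.geometry import box

    ground_truth = [box(0, 0, 10, 10), box(20, 20, 30, 30)]
    predicted = [box(1, 1, 10, 10), box(40, 40, 50, 50)]

    matched = compute_matched_ious(ground_truth, predicted)
    precision, recall = compute_precision_recall(
        matched, num_gt=len(ground_truth), num_pd=len(predicted)
    )
    print(f"precision={precision:.2f}, recall={recall:.2f}")  # 0.50, 0.50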
61 | """ 62 | true_positives = (np.array(ious) > threshold).astype(np.uint8) 63 | tp = np.sum(true_positives) 64 | recall = tp / num_gt if num_gt > 0 else 0.0 65 | precision = tp / num_pd if num_pd > 0 else 0.0 66 | return precision, recall 67 | -------------------------------------------------------------------------------- /tree_detection_framework/postprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-forest-observatory/tree-detection-framework/661fded8c3cd3ac7585ec9bb5d4606778bca7494/tree_detection_framework/postprocessing/__init__.py -------------------------------------------------------------------------------- /tree_detection_framework/postprocessing/postprocessing.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Optional 3 | 4 | import numpy as np 5 | import pyproj 6 | from polygone_nms import nms 7 | from shapely import box 8 | from shapely.geometry import MultiPolygon, Polygon 9 | from shapely.ops import unary_union 10 | 11 | from tree_detection_framework.detection.region_detections import ( 12 | RegionDetections, 13 | RegionDetectionsSet, 14 | ) 15 | 16 | 17 | def single_region_NMS( 18 | detections: RegionDetections, 19 | threshold: float = 0.5, 20 | confidence_column: str = "score", 21 | min_confidence: float = 0.3, 22 | intersection_method: str = "IOU", 23 | ) -> RegionDetections: 24 | """Run non-max suppresion on predictions from a single region. 25 | 26 | Args: 27 | detections (RegionDetections): 28 | Detections from a single region to run NMS on. 29 | threshold (float, optional): 30 | The threshold for the NMS(intersection) method. Defaults to 0.5. 31 | confidence_column (str, optional): 32 | Which column in the dataframe to use as a confidence for NMS. Defaults to "score" 33 | min_confidence (float, optional): 34 | Prediction score threshold for detections to be included. 35 | intersection_method (str, optional): 36 | The method to compute intersections, one of ("IOU", "IOS", "Dice", "IOT"). Defaults to "IOU". 
85 | 
86 | 
87 | def multi_region_NMS(
88 |     detections: RegionDetectionsSet,
89 |     run_per_region_NMS: bool = True,
90 |     threshold: float = 0.5,
91 |     confidence_column: str = "score",
92 |     min_confidence: float = 0.3,
93 |     intersection_method: str = "IOU",
94 | ) -> RegionDetections:
95 |     """Run non-max suppression on predictions from multiple regions.
96 | 
97 |     Args:
98 |         detections (RegionDetectionsSet):
99 |             Detections from multiple regions to run NMS on.
100 |         run_per_region_NMS (bool):
101 |             Whether non-max suppression should be run on each region before the regions are merged. This may
102 |             lead to a speedup if there is a large amount of within-region overlap. Defaults to True.
103 |         threshold (float, optional):
104 |             The threshold used by the NMS intersection method. Defaults to 0.5.
105 |         confidence_column (str, optional):
106 |             Which column in the dataframe to use as a confidence for NMS. Defaults to "score".
107 | 
108 |         min_confidence (float, optional):
109 |             Prediction score threshold for detections to be included. Defaults to 0.3.
110 |         intersection_method (str, optional):
111 |             The method to compute intersections, one of ("IOU", "IOS", "Dice", "IOT"). Defaults to "IOU".
112 |     Returns:
113 |         RegionDetections:
114 |             NMS-suppressed set of detections, merged together for the set of regions.
115 |     """
116 |     # Determine whether to run NMS individually on each region.
117 |     if run_per_region_NMS:
118 |         # Run NMS on each sub-region and then wrap this in a region detection set
119 |         detections = RegionDetectionsSet(
120 |             [
121 |                 single_region_NMS(
122 |                     region_detections,
123 |                     threshold=threshold,
124 |                     confidence_column=confidence_column,
125 |                     min_confidence=min_confidence,
126 |                     intersection_method=intersection_method,
127 |                 )
128 |                 for region_detections in detections.region_detections
129 |             ]
130 |         )
131 | 
132 |     # Merge the detections into a single RegionDetections
133 |     merged_detections = detections.merge()
134 | 
135 |     # If the bounds of the individual regions were disjoint, then no NMS needs to be applied across
136 |     # the different regions
137 |     if detections.disjoint_bounds():
138 |         logging.info("Bounds are disjoint, skipping across-region NMS")
139 |         return merged_detections
140 |     logging.info("Bounds overlap, running across-region NMS")
141 | 
142 |     # Run NMS on this merged RegionDetections
143 |     NMS_suppressed_merged_detections = single_region_NMS(
144 |         merged_detections,
145 |         threshold=threshold,
146 |         confidence_column=confidence_column,
147 |         min_confidence=min_confidence,
148 |         intersection_method=intersection_method,
149 |     )
150 | 
151 |     return NMS_suppressed_merged_detections
152 | 
153 | 
154 | def polygon_hole_suppression(polygon: Polygon, min_area_threshold: float = 20.0) -> Polygon:
155 |     """Remove small holes within a polygon
156 | 
157 |     Args:
158 |         polygon (shapely.Polygon):
159 |             A shapely polygon object
160 |         min_area_threshold (float):
161 |             Remove holes within the polygons that have area smaller than this value
162 | 
163 |     Returns:
164 |         shapely.Polygon:
165 |             The equivalent polygon created after suppressing the holes
166 |     """
167 |     list_interiors = []
168 |     # Iterate through the interiors list, which includes the holes
169 |     for interior in polygon.interiors:
170 |         interior_polygon = Polygon(interior)
171 |         # If the area of the hole is greater than the threshold, include it in the final output
172 |         if interior_polygon.area > min_area_threshold:
173 |             list_interiors.append(interior)
174 | 
175 |     # Return a new polygon with holes suppressed
176 |     return Polygon(polygon.exterior.coords, holes=list_interiors)
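# Minimal usage sketch: a 100x100 square crown with a 2x2 hole (area 4); the
# hole falls below the default 20.0 area threshold and is removed.
if __name__ == "__main__":
    crown = Polygon(
        [(0, 0), (100, 0), (100, 100), (0, 100)],
        holes=[[(50, 50), (52, 50), (52, 52), (50, 52)]],
    )
    cleaned = polygon_hole_suppression(crown, min_area_threshold=20.0)
    print(len(crown.interiors), "->", len(cleaned.interiors))  # 1 -> 0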
193 |     """
194 |     detections_df = detections.get_data_frame()
195 |     modified_geometries = []
196 | 
197 |     for tree_crown in detections_df.geometry.to_list():
198 |         # If tree_crown is a Polygon, directly do polygon hole suppression
199 |         if isinstance(tree_crown, Polygon):
200 |             clean_tree_crown = polygon_hole_suppression(tree_crown, min_area_threshold)
201 |         # If it is a MultiPolygon, do polygon hole suppression for each polygon within it
202 |         elif isinstance(tree_crown, MultiPolygon):
203 |             clean_polygons = []
204 |             for polygon in tree_crown.geoms:
205 |                 clean_polygon = polygon_hole_suppression(polygon, min_area_threshold)
206 |                 clean_polygons.append(clean_polygon)
207 |             # Create a new MultiPolygon with the suppressed polygons
208 |             clean_tree_crown = MultiPolygon(clean_polygons)
209 |         # For any other case, create an empty polygon (just to be safe)
210 |         else:
211 |             clean_tree_crown = Polygon()
212 | 
213 |         # Add the cleaned polygon/multipolygon to a list
214 |         modified_geometries.append(clean_tree_crown)
215 | 
216 |     # Set this list as the geometry column in the dataframe
217 |     detections_df.geometry = modified_geometries
218 |     # Return a new RegionDetections object created using the updated dataframe
219 |     # TODO: Handle cases where the data is in pixels with no transform to geospatial
220 |     return RegionDetections(
221 |         detection_geometries=None,
222 |         data=detections_df,
223 |         input_in_pixels=False,
224 |         CRS=detections.get_CRS(),
225 |     )
226 | 
227 | 
228 | def multi_region_hole_suppression(
229 |     detections: RegionDetectionsSet, min_area_threshold: float = 20.0
230 | ):
231 |     """Suppress polygon holes in a RegionDetectionsSet object.
232 | 
233 |     Args:
234 |         detections (RegionDetectionsSet):
235 |             Set of detections from multiple regions that need suppression of polygon holes.
236 |         min_area_threshold (float):
237 |             Remove holes within the polygons that have an area smaller than this value.
238 | 
239 |     Returns:
240 |         RegionDetectionsSet:
241 |             Set of detections after suppressing polygon holes.
242 |     """
243 |     # Perform single_region_hole_suppression for every region within the RegionDetectionsSet
244 |     return RegionDetectionsSet(
245 |         [
246 |             single_region_hole_suppression(region_detections, min_area_threshold)
247 |             for region_detections in detections.region_detections
248 |         ]
249 |     )
250 | 
251 | 
252 | def merge_and_postprocess_detections(
253 |     detections: RegionDetectionsSet,
254 |     tolerance: Optional[float] = 0.2,
255 |     min_area_threshold: Optional[float] = 20.0,
256 | ) -> RegionDetections:
257 |     """Apply postprocessing steps that:
258 |     1. Take a union of polygons that have been split across tiles
259 |     2. Simplify the edges of polygons using the `tolerance` value
260 |     3. Remove holes within the polygons that are smaller than the `min_area_threshold` value
261 |     Finally, merges the regions into a single RegionDetections.
262 | 
263 |     Args:
264 |         detections (RegionDetectionsSet):
265 |             Detections from multiple regions to postprocess.
266 |         tolerance (Optional[float], optional):
267 |             A value that controls the simplification of the detection polygons.
268 |             The higher this value, the smaller the number of vertices in the resulting geometry.
269 |         min_area_threshold (Optional[float], optional):
270 |             Holes within polygons with an area smaller than this value are removed.
271 | 
272 |     Returns:
273 |         RegionDetections:
274 |             Postprocessed set of detections, merged together for the set of regions.
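
    Example (editor's illustrative sketch; assumes `rds` is a RegionDetectionsSet produced by a detector):
        >>> merged = merge_and_postprocess_detections(rds, tolerance=0.2, min_area_threshold=20.0)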
275 |     """
276 |     # Get the detections as a merged GeoDataFrame
277 |     all_detections_gdf = detections.get_data_frame(merge=True)
278 | 
279 |     # Apply a small negative buffer to shrink polygons slightly
280 |     buffered_geoms = [geom.buffer(-0.001) for geom in all_detections_gdf.geometry]
281 | 
282 |     # Compute the union of the set of polygons. This step removes any straight seams caused by the tile edges
283 |     # and re-joins a polygon that might have been split into multiple pieces. It also removes any overlaps.
284 |     union_detections = unary_union(buffered_geoms)
285 | 
286 |     # Simplify the polygons by the tolerance value and extract only Polygons and MultiPolygons,
287 |     # since `union_detections` can contain Point objects as well
288 |     filtered_geoms = [
289 |         geom.simplify(tolerance)
290 |         for geom in list(union_detections.geoms)
291 |         if isinstance(geom, (Polygon, MultiPolygon))
292 |     ]
293 | 
294 |     # Remove small holes within the polygons
295 |     new_polygons = []
296 |     for polygon in filtered_geoms:
297 |         new_polygon = polygon_hole_suppression(polygon, min_area_threshold)
298 |         new_polygons.append(new_polygon)
299 | 
300 |     # Create a RegionDetections for the merged and postprocessed detections
301 |     # TODO: Handle cases when input is in pixels
302 |     postprocessed_detections = RegionDetections(
303 |         new_polygons, input_in_pixels=False, CRS=all_detections_gdf.crs
304 |     )
305 | 
306 |     return postprocessed_detections
307 | 
308 | 
309 | def suppress_tile_boundary_with_NMS(
310 |     predictions: RegionDetectionsSet,
311 |     iou_threshold: float = 0.5,
312 |     ios_threshold: float = 0.5,
313 |     min_confidence: float = 0.3,
314 | ) -> RegionDetections:
315 |     """
316 |     Used as a post-processing step with the `GeometricDetector` class to suppress detections that are split across tiles.
317 |     This is done by applying NMS twice, first using IOU and then using IOS.
318 | 
319 |     Args:
320 |         predictions (RegionDetectionsSet):
321 |             Detections from multiple regions.
322 |         iou_threshold (float, optional):
323 |             The threshold for the NMS method that uses the IoU metric. Defaults to 0.5.
324 |         ios_threshold (float, optional):
325 |             The threshold for the NMS method that uses the IoS metric. Defaults to 0.5.
326 |         min_confidence (float, optional):
327 |             Prediction score threshold for detections to be included. Defaults to 0.3.
328 | 
329 |     Returns:
330 |         RegionDetections:
331 |             NMS postprocessed set of detections, merged together.
332 |     """
333 | 
334 |     iou_nms = multi_region_NMS(
335 |         predictions,
336 |         intersection_method="IOU",
337 |         threshold=iou_threshold,
338 |         min_confidence=min_confidence,
339 |     )
340 | 
341 |     iou_ios_nms = single_region_NMS(
342 |         iou_nms,
343 |         intersection_method="IOS",
344 |         threshold=ios_threshold,
345 |         min_confidence=min_confidence,
346 |     )
347 | 
348 |     return iou_ios_nms
349 | 
350 | 
351 | def remove_out_of_bounds_detections(
352 |     region_detection_sets: List[RegionDetectionsSet], image_bounds: List
353 | ) -> List[RegionDetectionsSet]:
354 |     """
355 |     Filters out detections that are outside the bounds of a defined region.
356 |     Used as a post-processing step after `predict_raw_drone_images()`.
357 | 
358 |     Args:
359 |         region_detection_sets (List[RegionDetectionsSet]):
360 |             Each element is a RegionDetectionsSet derived from a specific drone image.
361 |             Length is the number of raw drone images given to the dataloader.
362 |         image_bounds (List[bounding_box]):
363 |             Each element is a `bounding_box` object derived from the dataloader.
364 |             Length: number of regions in a set * number of RegionDetectionsSet objects
365 | 
366 |     Returns:
367 |         List of RegionDetectionsSet objects with out-of-bounds predictions filtered out.
368 |     """
369 | 
370 |     region_idx = 0  # To index elements in image_bounds
371 |     list_of_filtered_region_sets = []
372 | 
373 |     for rds in region_detection_sets:
374 | 
375 |         # Find the number of regions in every set
376 |         num_of_regions_in_a_set = len(rds.get_data_frame())
377 |         # To save filtered regions in a particular set
378 |         list_of_filtered_regions = []
379 | 
380 |         # Get the region image bounds for the RegionDetectionsSet
381 |         region_image_bounds = image_bounds[region_idx]
382 | 
383 |         for idx in range(num_of_regions_in_a_set):
384 | 
385 |             # Get the RegionDetections object from the set
386 |             rd = rds.get_region_detections(idx)
387 |             rd_gdf = rd.get_data_frame()
388 | 
389 |             # Create a polygon of size equal to the image dimensions. Since `miny` stores the bottom image edge (y increases downward in pixel coordinates), passing (minx, maxy, maxx, miny) matches shapely's expected (minx, miny, maxx, maxy) order
390 |             region_set_polygon = box(
391 |                 region_image_bounds.minx,
392 |                 region_image_bounds.maxy,
393 |                 region_image_bounds.maxx,
394 |                 region_image_bounds.miny,
395 |             )
396 | 
397 |             # TODO: Instead of removing detections partially extending out of the boundary,
398 |             # try cropping it using gpd.clip()
399 |             # Remove detections that extend beyond the image bounds
400 |             within_bounds_indices = rd_gdf.within(region_set_polygon)
401 |             within_bounds_indices_true = within_bounds_indices[
402 |                 within_bounds_indices
403 |             ].index
404 | 
405 |             # Subset the RegionDetections object keeping only the valid indices calculated before
406 |             filtered_rd = rd.subset_detections(within_bounds_indices_true)
407 |             list_of_filtered_regions.append(filtered_rd)
408 | 
409 |         list_of_filtered_region_sets.append(
410 |             RegionDetectionsSet(list_of_filtered_regions)
411 |         )
412 | 
413 |         # Update region_idx to point to the image dims of the next rds
414 |         region_idx += num_of_regions_in_a_set
415 | 
416 |     return list_of_filtered_region_sets
417 | 
--------------------------------------------------------------------------------
/tree_detection_framework/preprocessing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-forest-observatory/tree-detection-framework/661fded8c3cd3ac7585ec9bb5d4606778bca7494/tree_detection_framework/preprocessing/__init__.py
--------------------------------------------------------------------------------
/tree_detection_framework/preprocessing/derived_geodatasets.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict, namedtuple
2 | from pathlib import Path
3 | from typing import Any, List, Optional, Union
4 | 
5 | import fiona
6 | import fiona.transform
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 | import rasterio
10 | import shapely.geometry
11 | import torch
12 | from PIL import Image
13 | from shapely.affinity import affine_transform
14 | from torch.utils.data import DataLoader, Dataset
15 | from torchgeo.datamodules import GeoDataModule
16 | from torchgeo.datasets import (
17 |     IntersectionDataset,
18 |     RasterDataset,
19 |     VectorDataset,
20 |     stack_samples,
21 | )
22 | from torchgeo.datasets.utils import BoundingBox, array_to_tensor
23 | from torchgeo.samplers import GridGeoSampler, Units
24 | from torchvision import transforms
25 | 
26 | from tree_detection_framework.constants import PATH_TYPE
27 | 
28 | # Define a namedtuple to store bounds of tiled images from the `CustomImageDataset`
29 | bounding_box = 
namedtuple("bounding_box", ["minx", "maxx", "miny", "maxy"]) 30 | 31 | 32 | class CustomRasterDataset(RasterDataset): 33 | """ 34 | Custom dataset class for orthomosaic raster images. This class extends the `RasterDataset` from `torchgeo`. 35 | 36 | Attributes: 37 | filename_glob (str): Glob pattern to match files in the directory. 38 | is_image (bool): Indicates that the data being loaded is image data. 39 | separate_files (bool): True if data is stored in a separate file for each band, else False. 40 | """ 41 | 42 | filename_glob: str = "*.tif" # To match all TIFF files 43 | is_image: bool = True 44 | separate_files: bool = False 45 | 46 | 47 | class CustomVectorDataset(VectorDataset): 48 | """ 49 | Custom dataset class for vector data which act as labels for the raster data. This class extends the `VectorDataset` from `torchgeo`. 50 | """ 51 | 52 | def __getitem__(self, query: BoundingBox) -> dict[str, Any]: 53 | """Retrieve image/mask and metadata indexed by query. 54 | This function is largely based on the `__getitem__` method from TorchGeo's `VectorDataset`. 55 | Modifications have been made to include the following keys within the returned dictionary: 56 | 1. 'shapes' as polygons per tile represented in pixel coordinates. 57 | 2. 'bounding_boxes' as bounding box of every detected polygon per tile in pixel coordinates. 58 | 59 | Args: 60 | query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index 61 | 62 | Returns: 63 | sample of image/mask and metadata at that index 64 | 65 | Raises: 66 | IndexError: if query is not found in the index 67 | """ 68 | hits = self.index.intersection(tuple(query), objects=True) 69 | filepaths = [hit.object for hit in hits] 70 | 71 | if not filepaths: 72 | raise IndexError( 73 | f"query: {query} not found in index with bounds: {self.bounds}" 74 | ) 75 | 76 | shapes = [] 77 | for filepath in filepaths: 78 | with fiona.open(filepath) as src: 79 | # We need to know the bounding box of the query in the source CRS 80 | (minx, maxx), (miny, maxy) = fiona.transform.transform( 81 | self.crs.to_dict(), 82 | src.crs, 83 | [query.minx, query.maxx], 84 | [query.miny, query.maxy], 85 | ) 86 | 87 | # Filter geometries to those that intersect with the bounding box 88 | for feature in src.filter(bbox=(minx, miny, maxx, maxy)): 89 | # Warp geometries to requested CRS 90 | shape = fiona.transform.transform_geom( 91 | src.crs, self.crs.to_dict(), feature["geometry"] 92 | ) 93 | label = self.get_label(feature) 94 | shapes.append((shape, label)) 95 | 96 | # Rasterize geometries 97 | width = (query.maxx - query.minx) / self.res 98 | height = (query.maxy - query.miny) / self.res 99 | transform = rasterio.transform.from_bounds( 100 | query.minx, query.miny, query.maxx, query.maxy, width, height 101 | ) 102 | if shapes: 103 | masks = rasterio.features.rasterize( 104 | shapes, out_shape=(round(height), round(width)), transform=transform 105 | ) 106 | else: 107 | # If no features are found in this query, return an empty mask 108 | # with the default fill value and dtype used by rasterize 109 | masks = np.zeros((round(height), round(width)), dtype=np.uint8) 110 | 111 | # Use array_to_tensor since rasterize may return uint16/uint32 arrays. 
112 |         masks = array_to_tensor(masks)
113 | 
114 |         masks = masks.to(self.dtype)
115 | 
116 |         # Beginning of additions made to this function
117 | 
118 |         # Invert the transform to convert geo coordinates to pixel values
119 |         inverse_transform = ~transform
120 | 
121 |         # Convert `fiona` type shapes to `shapely` shape objects for easier manipulation
122 |         shapely_shapes = [(shapely.geometry.shape(sh), i) for sh, i in shapes]
123 | 
124 |         # Apply the inverse transform to each shapely shape, converting geo coordinates to pixel coordinates
125 |         pixel_transformed_shapes = [
126 |             (affine_transform(sh, inverse_transform.to_shapely()), i)
127 |             for sh, i in shapely_shapes
128 |         ]
129 | 
130 |         # Convert each polygon to an axis-aligned bounding box of format (x_min, y_min, x_max, y_max) in pixel coordinates
131 |         bounding_boxes = []
132 |         for polygon, _ in pixel_transformed_shapes:
133 |             x_min, y_min, x_max, y_max = polygon.bounds
134 |             bounding_boxes.append([x_min, y_min, x_max, y_max])
135 | 
136 |         # Add `shapes` and `bounding_boxes` to the dictionary.
137 |         sample = {
138 |             "mask": masks,
139 |             "crs": self.crs,
140 |             "bounds": query,
141 |             "shapes": pixel_transformed_shapes,
142 |             "bounding_boxes": bounding_boxes,
143 |         }
144 | 
145 |         if self.transforms is not None:
146 |             sample = self.transforms(sample)
147 | 
148 |         return sample
149 | 
150 | 
151 | class CustomImageDataset(Dataset):
152 |     def __init__(
153 |         self,
154 |         images_dir: Union[PATH_TYPE, List[str]],
155 |         chip_size: int,
156 |         chip_stride: int,
157 |         labels_dir: Optional[List[str]] = None,
158 |     ):
159 |         """
160 |         Dataset for creating a dataloader from a folder of individual images, with an option to create tiles.
161 | 
162 |         Args:
163 |             images_dir (Union[Path, List[str]]): Path to the folder containing image files, or list of paths to image files.
164 |             chip_size (int): Edge length of each square image chip, in pixels.
165 |             chip_stride (int): Stride to take while chipping the images, applied both horizontally and vertically, in pixels.
166 |             labels_dir (Optional[List[str]]): List of paths to annotation .geojson files corresponding to the images.
167 |                 Each should have the same file name as its corresponding image.
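
        Example (editor's illustrative sketch; "drone_images/" is a hypothetical folder):
            >>> dataset = CustomImageDataset("drone_images/", chip_size=1000, chip_stride=800)
            >>> len(dataset)  # total number of tiles across all images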
168 |         """
169 |         self.chip_size = chip_size
170 |         self.chip_stride = chip_stride
171 | 
172 |         if not isinstance(images_dir, list):
173 |             self.images_dir = Path(images_dir)
174 |             # Get a list of all image paths
175 |             image_extensions = [".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif"]
176 |             self.image_paths = sorted(
177 |                 [
178 |                     path
179 |                     for path in self.images_dir.glob("*")
180 |                     if path.suffix.lower() in image_extensions
181 |                 ]
182 |             )
183 | 
184 |             if len(self.image_paths) == 0:
185 |                 raise ValueError(f"No image files found in {self.images_dir}")
186 |         else:
187 |             self.image_paths = images_dir
188 | 
189 |         self.labels_paths = labels_dir
190 |         self.tile_metadata = self._get_metadata()
191 | 
192 |     def _get_metadata(self):
193 |         metadata = []
194 |         for img_path in self.image_paths:
195 |             # Ensure the label path corresponds to the same file name as the image
196 |             label_path = None
197 |             if self.labels_paths:
198 |                 # Match the label file by replacing the image extension with `.geojson`
199 |                 expected_label_name = img_path.stem + ".geojson"
200 |                 matching_labels = filter(
201 |                     lambda label: Path(label).name == expected_label_name,
202 |                     self.labels_paths,
203 |                 )
204 |                 label_path = next(matching_labels, None)
205 |                 if label_path is None:
206 |                     raise ValueError(f"Label file not found for image: {img_path}")
207 | 
208 |             tile_idx = 0  # A unique tile index value within this image, resets for every new image
209 |             with Image.open(img_path) as img:
210 |                 img_width, img_height = img.size
211 | 
212 |             # Generate tile coordinates
213 |             for y in range(0, img_height, self.chip_stride):
214 |                 for x in range(0, img_width, self.chip_stride):
215 |                     # Add metadata for the current tile
216 |                     metadata.append((tile_idx, img_path, label_path, x, y))
217 |                     tile_idx += 1
218 |         return metadata
219 | 
220 |     def __len__(self):
221 |         return len(self.tile_metadata)
222 | 
223 |     def __getitem__(self, idx):
224 | 
225 |         tile_idx, img_path, label_path, x, y = self.tile_metadata[idx]  # tile_idx is the tile's index within its source image
226 | 
227 |         with Image.open(img_path) as img:
228 |             img = img.convert("RGB")
229 | 
230 |             # Check if the tile extends beyond the image boundary
231 |             tile_width = min(self.chip_size, img.width - x)
232 |             tile_height = min(self.chip_size, img.height - y)
233 | 
234 |             # If the tile fits within the image, return the cropped image
235 |             if tile_width == self.chip_size and tile_height == self.chip_size:
236 |                 tile = img.crop((x, y, x + self.chip_size, y + self.chip_size))
237 |             else:
238 |                 # Create a white square tile of shape 'chip_size'
239 |                 tile = Image.new(
240 |                     "RGB", (self.chip_size, self.chip_size), (255, 255, 255)
241 |                 )
242 | 
243 |                 # Crop the image section and paste onto the white image
244 |                 img_section = img.crop((x, y, x + tile_width, y + tile_height))
245 |                 tile.paste(img_section, (0, 0))
246 | 
247 |             # Convert to tensor
248 |             if not isinstance(tile, torch.Tensor):
249 |                 tile = transforms.ToTensor()(tile)
250 | 
251 |             metadata = {
252 |                 "image_index": tile_idx,  # NOTE: stores the per-image tile index
253 |                 "source_image": str(img_path),
254 |                 "image_bounds": bounding_box(  # (minx, maxx, miny, maxy); miny is the bottom edge since y increases downward
255 |                     0,
256 |                     float(img.width),
257 |                     float(img.height),
258 |                     0,
259 |                 ),
260 |             }
261 |             if self.labels_paths is not None:
262 |                 metadata["annotations"] = str(label_path)
263 | 
264 |             return {
265 |                 "image": tile,
266 |                 "metadata": metadata,
267 |                 # Bounds includes bounding box values for the whole tile, including any white padded region
268 |                 "bounds": bounding_box(
269 |                     float(x),
270 |                     float(x + self.chip_size),
271 |                     float(y + self.chip_size),
272 |                     float(y),
273 |                 ),
274 |                 "crs": None,
275 |             }
276 | 
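    # Custom collate function: the default torch collate cannot batch the per-sample
    # metadata dicts and `bounding_box` namedtuples, so images are stacked into a single
    # tensor and the remaining fields are gathered into plain lists.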
277 |     @staticmethod
278 |     def collate_as_defaultdict(batch):
279 |         # Stack images from batch into a single tensor
280 |         images = torch.stack([item["image"] for item in batch])
281 |         # Collect metadata as a list
282 |         metadata = [item["metadata"] for item in batch]
283 |         bounds = [item["bounds"] for item in batch]
284 |         crs = [item["crs"] for item in batch]
285 |         return defaultdict(
286 |             lambda: None,
287 |             {"image": images, "metadata": metadata, "bounds": bounds, "crs": crs},
288 |         )
289 | 
290 | 
291 | class CustomDataModule(GeoDataModule):
292 |     """GeoDataModule that pairs raster tiles with vector labels (as IntersectionDatasets) and serves train/val/test dataloaders sampled on a regular grid."""
293 |     def __init__(
294 |         self,
295 |         output_res: float,
296 |         train_raster_path: str,
297 |         vector_label_name: str,
298 |         train_vector_path: str,
299 |         size: int,
300 |         stride: int,
301 |         batch_size: int = 2,
302 |         val_raster_path: Optional[str] = None,
303 |         val_vector_path: Optional[str] = None,
304 |         test_raster_path: Optional[str] = None,
305 |         test_vector_path: Optional[str] = None,
306 |     ) -> None:
307 |         super().__init__(dataset_class=IntersectionDataset)
308 |         self.output_res = output_res
309 |         self.vector_label_name = vector_label_name
310 |         self.size = size
311 |         self.stride = stride
312 |         self.batch_size = batch_size
313 | 
314 |         # Paths for train, val and test dataset
315 |         self.train_raster_path = train_raster_path
316 |         self.val_raster_path = val_raster_path
317 |         self.test_raster_path = test_raster_path
318 |         self.train_vector_path = train_vector_path
319 |         self.val_vector_path = val_vector_path
320 |         self.test_vector_path = test_vector_path
321 | 
322 |     def create_intersection_dataset(
323 |         self, raster_path: str, vector_path: str
324 |     ) -> IntersectionDataset:
325 |         raster_data = CustomRasterDataset(paths=raster_path, res=self.output_res)
326 |         vector_data = CustomVectorDataset(
327 |             paths=vector_path, res=self.output_res, label_name=self.vector_label_name
328 |         )
329 |         return raster_data & vector_data  # IntersectionDataset
330 | 
331 |     def setup(self, stage=None):
332 |         # Create the data based on the stage the Trainer is in
333 |         if stage == "fit":
334 |             self.train_data = self.create_intersection_dataset(
335 |                 self.train_raster_path, self.train_vector_path
336 |             )
337 |         if stage == "validate" or stage == "fit":
338 |             self.val_data = self.create_intersection_dataset(
339 |                 self.val_raster_path, self.val_vector_path
340 |             )
341 |         if stage == "test":
342 |             self.test_data = self.create_intersection_dataset(
343 |                 self.test_raster_path, self.test_vector_path
344 |             )
345 | 
346 |     def train_dataloader(self) -> DataLoader:
347 |         sampler = GridGeoSampler(self.train_data, size=self.size, stride=self.stride)
348 |         return DataLoader(
349 |             self.train_data,
350 |             sampler=sampler,
351 |             collate_fn=stack_samples,
352 |             batch_size=self.batch_size,
353 |         )
354 | 
355 |     def val_dataloader(self) -> DataLoader:
356 |         sampler = GridGeoSampler(
357 |             self.val_data, size=self.size, stride=self.stride, units=Units.CRS
358 |         )
359 |         return DataLoader(
360 |             self.val_data,
361 |             sampler=sampler,
362 |             collate_fn=stack_samples,
363 |             batch_size=self.batch_size,
364 |         )
365 | 
366 |     def test_dataloader(self) -> DataLoader:
367 |         sampler = GridGeoSampler(
368 |             self.test_data, size=self.size, stride=self.stride, units=Units.CRS
369 |         )
370 |         return DataLoader(
371 |             self.test_data,
372 |             sampler=sampler,
373 |             collate_fn=stack_samples,
374 |             batch_size=self.batch_size,
375 |         )
376 | 
377 |     def on_after_batch_transfer(self, batch, dataloader_idx: int):
378 |         return batch
379 | 
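A minimal end-to-end sketch of how the pieces in this module fit together (editor's illustrative example, not part of the repository; the "drone_images/" folder is hypothetical):

    from torch.utils.data import DataLoader

    # Tile a folder of plain drone images into 1000 px chips with a 200 px overlap
    dataset = CustomImageDataset("drone_images/", chip_size=1000, chip_stride=800)
    loader = DataLoader(
        dataset,
        batch_size=2,
        collate_fn=CustomImageDataset.collate_as_defaultdict,
    )
    batch = next(iter(loader))
    print(batch["image"].shape)  # torch.Size([2, 3, 1000, 1000])
    print(batch["bounds"][0])    # bounding_box(minx=0.0, maxx=1000.0, miny=1000.0, maxy=0.0)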
--------------------------------------------------------------------------------
/tree_detection_framework/preprocessing/preprocessing.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import random
4 | from pathlib import Path
5 | from typing import List, Optional, Union
6 | 
7 | import matplotlib.pyplot as plt
8 | import pyproj
9 | import rasterio
10 | import shapely
11 | import torch
12 | from torch.utils.data import DataLoader
13 | from torchgeo.datasets import IntersectionDataset, stack_samples, unbind_samples
14 | from torchgeo.samplers import GridGeoSampler, Units
15 | from torchvision.transforms import ToPILImage
16 | 
17 | from tree_detection_framework.constants import ARRAY_TYPE, BOUNDARY_TYPE, PATH_TYPE
18 | from tree_detection_framework.preprocessing.derived_geodatasets import (
19 |     CustomImageDataset,
20 |     CustomRasterDataset,
21 |     CustomVectorDataset,
22 | )
23 | from tree_detection_framework.utils.geospatial import get_projected_CRS
24 | from tree_detection_framework.utils.raster import plot_from_dataloader
25 | 
26 | logging.basicConfig(
27 |     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
28 | )
29 | 
30 | 
31 | def create_spatial_split(
32 |     region_to_be_split: BOUNDARY_TYPE, split_fractions: ARRAY_TYPE
33 | ) -> List[shapely.MultiPolygon]:
34 |     """Creates non-overlapping spatial splits.
35 | 
36 |     Args:
37 |         region_to_be_split (BOUNDARY_TYPE):
38 |             A spatial region to be split up. May be defined as a shapely object, geopandas object,
39 |             or a path to a geospatial file. In any case, the union of all the elements will be taken.
40 |         split_fractions (ARRAY_TYPE):
41 |             A sequence of fractions to split the input region into. If they don't sum to 1, the total
42 |             will be normalized.
43 | 
44 |     Returns:
45 |         List[shapely.MultiPolygon]:
46 |             A list of regions representing spatial splits of the input. The area of each one is
47 |             controlled by the corresponding element in split_fractions.
48 | 
49 |     """
50 |     raise NotImplementedError()
51 | 
52 | 
53 | def create_dataloader(
54 |     raster_folder_path: PATH_TYPE,
55 |     chip_size: float,
56 |     chip_stride: Optional[float] = None,
57 |     chip_overlap_percentage: Optional[float] = None,
58 |     use_units_meters: bool = False,
59 |     region_of_interest: Optional[BOUNDARY_TYPE] = None,
60 |     output_resolution: Optional[float] = None,
61 |     output_CRS: Optional[pyproj.CRS] = None,
62 |     vector_label_folder_path: Optional[PATH_TYPE] = None,
63 |     vector_label_attribute: Optional[str] = None,
64 |     batch_size: int = 1,
65 | ) -> DataLoader:
66 |     """
67 |     Create a tiled dataloader using torchgeo. Contains raster data and, optionally, vector labels.
68 | 
69 |     Args:
70 |         raster_folder_path (PATH_TYPE): Path to the folder of raster files
71 |         chip_size (float):
72 |             Dimension of the chip. May be pixels or meters, based on `use_units_meters`.
73 |         chip_stride (Optional[float], optional):
74 |             Stride of the chip. May be pixels or meters, based on `use_units_meters`. If used,
75 |             `chip_overlap_percentage` should not be set. Defaults to None.
76 |         chip_overlap_percentage (Optional[float], optional):
77 |             Percent overlap of the chip from 0-100. If used, `chip_stride` should not be set.
78 |             Defaults to None.
79 |         use_units_meters (bool, optional):
80 |             Use units of meters rather than pixels when interpreting the `chip_size` and `chip_stride`.
81 |             Defaults to False.
82 |         region_of_interest (Optional[BOUNDARY_TYPE], optional):
83 |             Only data from this spatial region will be included in the dataloader. Defaults to None.
84 |         output_resolution (Optional[float], optional):
85 |             Spatial resolution of the data in meters/pixel. If un-set, will be the resolution of the
86 |             first raster data that is read. Defaults to None.
87 |         output_CRS (Optional[pyproj.CRS], optional):
88 |             The coordinate reference system to use for the output data. If un-set, will be the CRS
89 |             of the first tile found. Defaults to None.
90 |         vector_label_folder_path (Optional[PATH_TYPE], optional):
91 |             A folder of geospatial vector files that will be used for the label. If un-set, the
92 |             dataloader will not be labeled. Defaults to None.
93 |         vector_label_attribute (Optional[str], optional):
94 |             Attribute to read from the vector data, such as the class or instance ID. Defaults to None.
95 |         batch_size (int, optional):
96 |             Number of images to load in a batch. Defaults to 1.
97 | 
98 |     Returns:
99 |         DataLoader:
100 |             A dataloader containing tiles from the raster data and optionally corresponding labels
101 |             from the vector data.
102 |     """
103 | 
104 |     # Changes: 1. bounding box included in every sample as a df / np array
105 |     # 2. TODO: float or uint8 images;
106 |     #    match with the param dict from the model, else error out
107 |     # Stores image data
108 |     raster_dataset = CustomRasterDataset(
109 |         paths=raster_folder_path, res=output_resolution
110 |     )
111 | 
112 |     # Stores label data
113 |     vector_dataset = (
114 |         CustomVectorDataset(
115 |             paths=vector_label_folder_path,
116 |             res=output_resolution,
117 |             label_name=vector_label_attribute,
118 |         )
119 |         if vector_label_folder_path is not None
120 |         else None
121 |     )
122 | 
123 |     units = Units.CRS if use_units_meters else Units.PIXELS
124 |     logging.info(f"Units = {units}")
125 | 
126 |     if use_units_meters and raster_dataset.crs.is_geographic:
127 |         # Reproject the dataset to a meters-based CRS
128 |         logging.info("Projecting to meters-based CRS...")
129 |         lat, lon = raster_dataset.bounds[2], raster_dataset.bounds[0]
130 | 
131 |         # Return a new projected CRS value with meters units
132 |         projected_crs = get_projected_CRS(lat, lon)
133 | 
134 |         # Type conversion to rasterio.crs
135 |         projected_crs = rasterio.crs.CRS.from_wkt(projected_crs.to_wkt())
136 | 
137 |         # Recreating the raster and vector dataset objects with the new CRS value. NOTE: output_resolution is not forwarded here, so the reprojected datasets fall back to the rasters' native resolution.
138 |         raster_dataset = CustomRasterDataset(
139 |             paths=raster_folder_path, crs=projected_crs
140 |         )
141 |         vector_dataset = (
142 |             CustomVectorDataset(
143 |                 paths=vector_label_folder_path,
144 |                 crs=projected_crs,
145 |                 label_name=vector_label_attribute,
146 |             )
147 |             if vector_label_folder_path is not None
148 |             else None
149 |         )
150 | 
151 |     # Create an intersection dataset that combines raster and label data if given. Otherwise, proceed with just raster_dataset.
152 |     final_dataset = (
153 |         IntersectionDataset(raster_dataset, vector_dataset)
154 |         if vector_label_folder_path is not None
155 |         else raster_dataset
156 |     )
157 | 
158 |     if chip_overlap_percentage:
159 |         # Calculate `chip_stride` if `chip_overlap_percentage` is provided
160 |         chip_stride = chip_size * (1 - chip_overlap_percentage / 100.0)
161 |         logging.info(f"Calculated stride based on overlap: {chip_stride}")
162 | 
163 |     elif chip_stride is None:
164 |         raise ValueError(
165 |             "Either 'chip_stride' or 'chip_overlap_percentage' must be provided."
166 |         )
167 | 
168 |     logging.info(f"Stride = {chip_stride}")
169 | 
170 |     # GridGeoSampler to get contiguous tiles
171 |     sampler = GridGeoSampler(
172 |         final_dataset, size=chip_size, stride=chip_stride, units=units
173 |     )
174 |     dataloader = DataLoader(
175 |         final_dataset, batch_size=batch_size, sampler=sampler, collate_fn=stack_samples
176 |     )
177 | 
178 |     return dataloader
179 | 
180 | 
181 | def create_image_dataloader(
182 |     images_dir: Union[PATH_TYPE, List[str]],
183 |     chip_size: int,
184 |     chip_stride: Optional[int] = None,
185 |     chip_overlap_percentage: Optional[float] = None,
186 |     labels_dir: Optional[List[str]] = None,
187 |     batch_size: int = 1,
188 | ) -> DataLoader:
189 |     """
190 |     Create a dataloader for a folder of normal images (e.g., JPGs), tiling them into smaller patches.
191 | 
192 |     Args:
193 |         images_dir (Union[Path, List[str]]):
194 |             Path to the folder containing image files, or list of paths to image files.
195 |         chip_size (int):
196 |             Edge length of the square tiles, in pixels.
197 |         chip_stride (Optional[int], optional):
198 |             Stride of the tiling in pixels, applied both horizontally and vertically.
199 |         chip_overlap_percentage (Optional[float], optional):
200 |             Percent overlap of the chip from 0-100. If used, `chip_stride` should not be set.
201 |         labels_dir (Optional[List[str]], optional):
202 |             List of paths to tree crown label files corresponding to the images.
203 |             These will be used as ground truth during evaluation.
204 |         batch_size (int, optional):
205 |             Number of tiles in a batch. Defaults to 1.
206 | 
207 |     Returns:
208 |         DataLoader: A dataloader containing the tiles and associated metadata.
209 |     """
210 | 
211 |     logging.info("Units set to PIXELS")
212 | 
213 |     if chip_overlap_percentage:
214 |         # Calculate `chip_stride` if `chip_overlap_percentage` is provided
215 |         chip_stride = chip_size * (1 - chip_overlap_percentage / 100.0)
216 |         chip_stride = int(chip_stride)
217 |         logging.info(f"Calculated stride based on overlap: {chip_stride}")
218 | 
219 |     elif chip_stride is None:
220 |         raise ValueError(
221 |             "Either 'chip_stride' or 'chip_overlap_percentage' must be provided."
222 |         )
223 | 
224 |     dataset = CustomImageDataset(
225 |         images_dir=images_dir,
226 |         chip_size=chip_size,
227 |         chip_stride=chip_stride,
228 |         labels_dir=labels_dir,
229 |     )
230 |     dataloader = DataLoader(
231 |         dataset,
232 |         batch_size=batch_size,
233 |         shuffle=False,
234 |         collate_fn=CustomImageDataset.collate_as_defaultdict,
235 |     )
236 |     return dataloader
237 | 
238 | 
239 | def visualize_dataloader(dataloader: DataLoader, n_tiles: int):
240 |     """Show samples from the dataloader.
241 | 
242 |     Args:
243 |         dataloader (DataLoader): The dataloader to visualize.
244 |         n_tiles (int): The number of randomly-sampled tiles to show.
245 |     """
246 |     # Get a random sample of `n_tiles` index values to visualize
247 |     tile_indices = random.sample(range(len(dataloader.sampler)), n_tiles)
248 | 
249 |     # Get a list of all tile bounds from the sampler
250 |     list_of_bboxes = list(dataloader.sampler)
251 | 
252 |     for i in tile_indices:
253 |         sample_bbox = list_of_bboxes[i]
254 | 
255 |         # Get the referenced sample from the dataloader
256 |         sample = dataloader.dataset[sample_bbox]
257 | 
258 |         # Plot the sample image.
259 |         plot_from_dataloader(sample)
260 |         plt.axis("off")
261 |         plt.show()
262 | 
263 | 
264 | def save_dataloader_contents(
265 |     dataloader: DataLoader,
266 |     save_folder: PATH_TYPE,
267 |     n_tiles: Optional[int] = None,
268 |     random_sample: bool = False,
269 | ):
270 |     """Save contents of the dataloader to a folder.
271 | 272 | Args: 273 | dataloader (DataLoader): 274 | Dataloader to save the contents of. 275 | save_folder (PATH_TYPE): 276 | Folder to save data to. Will be created if it doesn't exist. 277 | n_tiles (Optional[int], optional): 278 | Number of tiles to save. Whether they are the first tiles or random is controlled by 279 | `random_sample`. If unset, all tiles will be saved. Defaults to None. 280 | random_sample (bool, optional): 281 | If `n_tiles` is set, should the tiles be randomly sampled rather than taken from the 282 | beginning of the dataloader. Defaults to False. 283 | """ 284 | # Create save directory if it doesn't exist 285 | destination_folder = Path(save_folder) 286 | destination_folder.mkdir(parents=True, exist_ok=True) 287 | 288 | transform_to_pil = ToPILImage() 289 | 290 | # TODO: handle batch_size > 1 291 | # Collect all batches from the dataloader 292 | all_batches = list(dataloader) 293 | 294 | # Flatten the list of batches into individual samples 295 | all_samples = [sample for batch in all_batches for sample in unbind_samples(batch)] 296 | 297 | # If `n_tiles` is set, limit the number of tiles to save 298 | if n_tiles is not None: 299 | if random_sample: 300 | # Randomly sample `n_tiles`. If `n_tiles` is greater than available samples, include all samples. 301 | selected_samples = random.sample( 302 | all_samples, min(n_tiles, len(all_samples)) 303 | ) 304 | else: 305 | # Take first `n_tiles` 306 | selected_samples = all_samples[:n_tiles] 307 | else: 308 | selected_samples = all_samples 309 | 310 | # Counter for saved tiles 311 | saved_tiles_count = 0 312 | 313 | # Iterate over the selected samples 314 | for sample in selected_samples: 315 | image = sample["image"] 316 | image_tensor = torch.clamp(image / 255.0, min=0, max=1) 317 | pil_image = transform_to_pil(image_tensor) 318 | 319 | # Save the image tile 320 | pil_image.save(destination_folder / f"tile_{saved_tiles_count}.png") 321 | 322 | # Prepare tile metadata 323 | metadata = { 324 | "crs": sample["crs"].to_string(), 325 | "bounds": list(sample["bounds"]), 326 | } 327 | 328 | # If dataset includes labels, save crown metadata 329 | if isinstance(dataloader.dataset, IntersectionDataset): 330 | shapes = sample["shapes"] 331 | crowns = [ 332 | {"ID": tree_id, "crown": polygon.wkt} for polygon, tree_id in shapes 333 | ] 334 | metadata["crowns"] = crowns 335 | 336 | # Save metadata to a JSON file 337 | with open(destination_folder / f"tile_{saved_tiles_count}.json", "w") as f: 338 | json.dump(metadata, f, indent=4) 339 | 340 | # Increment the saved tile count 341 | saved_tiles_count += 1 342 | 343 | # Stop once the desired number of tiles is saved 344 | if n_tiles is not None and saved_tiles_count >= n_tiles: 345 | break 346 | 347 | print(f"Saved {saved_tiles_count} tiles to {save_folder}") 348 | -------------------------------------------------------------------------------- /tree_detection_framework/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-forest-observatory/tree-detection-framework/661fded8c3cd3ac7585ec9bb5d4606778bca7494/tree_detection_framework/utils/__init__.py -------------------------------------------------------------------------------- /tree_detection_framework/utils/benchmarking.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import xml.etree.ElementTree as ET 3 | from glob import glob 4 | from pathlib import Path 5 | from typing import List, Union 6 
| 7 | import geopandas as gpd 8 | import numpy as np 9 | from shapely.geometry import box 10 | from torch.utils.data import DataLoader 11 | 12 | from tree_detection_framework.constants import PATH_TYPE 13 | from tree_detection_framework.detection.detector import Detector 14 | from tree_detection_framework.evaluation.evaluate import ( 15 | compute_matched_ious, 16 | compute_precision_recall, 17 | ) 18 | from tree_detection_framework.postprocessing.postprocessing import single_region_NMS 19 | from tree_detection_framework.preprocessing.preprocessing import create_image_dataloader 20 | 21 | logging.basicConfig(level=logging.INFO) 22 | 23 | 24 | def get_neon_gt( 25 | images_dir: PATH_TYPE, annotations_dir: PATH_TYPE 26 | ) -> dict[str, dict[str, List[box]]]: 27 | """ 28 | Extract ground truth bounding boxes from NEON XML annotations. 29 | Args: 30 | images_dir (PATH_TYPE): Directory containing image tiles. 31 | annotations_dir (PATH_TYPE): Directory containing XML annotation files. 32 | Returns: 33 | dict: A dictionary mapping image paths to a dictionary with "gt" key containing ground truth boxes. 34 | """ 35 | tiles_to_predict = list(Path(images_dir).glob("*.tif")) 36 | mappings = {} 37 | 38 | for path in tiles_to_predict: 39 | plot_name = path.stem # Get filename without extension 40 | annot_fname = Path(annotations_dir) / f"{plot_name}.xml" 41 | 42 | if not annot_fname.exists(): 43 | continue 44 | 45 | # Load XML file 46 | tree = ET.parse(annot_fname) 47 | root = tree.getroot() 48 | 49 | # Extract bounding boxes 50 | gt_boxes = [] 51 | for obj in root.findall(".//object"): 52 | bndbox = obj.find("bndbox") 53 | if bndbox is not None: 54 | xmin = int(bndbox.find("xmin").text) 55 | ymin = int(bndbox.find("ymin").text) 56 | xmax = int(bndbox.find("xmax").text) 57 | ymax = int(bndbox.find("ymax").text) 58 | gt_boxes.append(box(xmin, ymin, xmax, ymax)) 59 | 60 | # Add the ground truth boxes to the mappings 61 | mappings[str(path)] = {"gt": gt_boxes} 62 | return mappings 63 | 64 | 65 | def get_detectree2_gt(dataloader, tile_size=1000) -> dict[str, dict[str, List[box]]]: 66 | """Extract ground truth bounding boxes from Detectree2 annotations.""" 67 | mappings = {} 68 | for i in dataloader: 69 | img_path = i["metadata"][0]["source_image"] 70 | gt_gdf = gpd.read_file(i["metadata"][0]["annotations"]) 71 | # Perform a vertical flip of the annotations to match the expected format. This is due to 72 | # differences between the i,j and x,y indexing conventions. 73 | gt_gdf.geometry = gt_gdf.transform( 74 | lambda x: np.concatenate((x[:, 0:1], tile_size - x[:, 1:2]), axis=1) 75 | ) 76 | 77 | # Convert each geometry to its axis-aligned bounding box polygon 78 | bounding_boxes = [box(*geom.bounds) for geom in gt_gdf.geometry] 79 | 80 | mappings[img_path] = {"gt": bounding_boxes} 81 | return mappings 82 | 83 | 84 | def get_neon_dataloader(image_paths: List[str]) -> DataLoader: 85 | """Create a dataloader for the NEON dataset.""" 86 | # Create dataloader setting image size as 400x400, the standard size of the NEON dataset. 
87 | dataloader = create_image_dataloader( 88 | image_paths, 89 | chip_size=400, 90 | chip_stride=400, 91 | ) 92 | return dataloader 93 | 94 | 95 | def get_detectree2_dataloader( 96 | images_dir: Union[PATH_TYPE, List[PATH_TYPE]], 97 | annotations_dir: Union[PATH_TYPE, List[PATH_TYPE]], 98 | ) -> "DataLoader": 99 | """Create a Detectree2 dataloader from one or more image/annotation directories.""" 100 | 101 | # Extract the list of paths to image and annotation directories 102 | images_dirs = [ 103 | Path(p) for p in (images_dir if isinstance(images_dir, list) else [images_dir]) 104 | ] 105 | annotations_dirs = [ 106 | Path(p) 107 | for p in ( 108 | annotations_dir if isinstance(annotations_dir, list) else [annotations_dir] 109 | ) 110 | ] 111 | 112 | # Collect all image and annotation paths 113 | img_paths = [] 114 | for img_dir in images_dirs: 115 | img_paths.extend(list(img_dir.glob("*"))) 116 | 117 | ann_paths = [] 118 | for ann_dir in annotations_dirs: 119 | ann_paths.extend(list(ann_dir.glob("*"))) 120 | 121 | # Create dataloader using all images 122 | dataloader = create_image_dataloader( 123 | images_dir=img_paths, chip_size=1000, chip_stride=1000, labels_dir=ann_paths 124 | ) 125 | return dataloader 126 | 127 | 128 | def get_benchmark_detections( 129 | dataset_name: str, 130 | images_dir: Union[PATH_TYPE, List[PATH_TYPE]], 131 | annotations_dir: Union[PATH_TYPE, List[PATH_TYPE]], 132 | detectors: dict[str, Detector], 133 | nms_threshold: float = None, 134 | min_confidence: float = 0.5, 135 | nms_on_polygons: bool = False, 136 | ) -> dict[str, dict[str, List[box]]]: 137 | """ 138 | Load ground truth, create dataloader, and run detectors on the images from the benchmark dataset. 139 | Args: 140 | dataset_name (str): Name of the dataset ("neon" or "detectree2"). 141 | images_dir (PATH_TYPE, List[PATH_TYPE]): Directory or list of directories to image files. 142 | annotations_dir (PATH_TYPE, List[PATH_TYPE]): Directory or list of directories to annotation files. 143 | detectors (dict): Dictionary of detector instances to be evaluated. 144 | nms_threshold (float): Non-maximum suppression threshold. 145 | min_confidence (float): Minimum confidence threshold for detections. 146 | nms_on_polygons (bool): Should NMS be run on polygons before converting them to bounding boxes. Defaults to False. 147 | Returns: 148 | dict: A dictionary mapping image paths to a dictionary with detector names and the corresponding output boxes. 
149 |     """
150 |     if dataset_name == "neon":
151 |         mappings = get_neon_gt(images_dir, annotations_dir)
152 |         dataloader = get_neon_dataloader(list(mappings.keys()))
153 | 
154 |     elif dataset_name == "detectree2":
155 |         dataloader = get_detectree2_dataloader(images_dir, annotations_dir)
156 |         mappings = get_detectree2_gt(dataloader)
157 | 
158 |     else:
159 |         raise ValueError(f"Unknown dataset: {dataset_name}")
160 | 
161 |     for name, detector in detectors.items():
162 |         # Get predictions from every detector
163 |         logging.info(f"Running detector: {name}")
164 |         region_detection_sets, filenames, _ = detector.predict_raw_drone_images(
165 |             dataloader
166 |         )
167 | 
168 |         # Add predictions to the mappings so that it looks like:
169 |         # {"image_path_1": {"gt": gt_boxes, "detector_name_1": [boxes], ...},
170 |         #  "image_path_2": {"gt": gt_boxes, "detector_name_1": [boxes], ...}, ...}
171 |         for filename, rds in zip(filenames, region_detection_sets):
172 |             if nms_threshold is not None:
173 | 
174 |                 # If we don't want to do polygon NMS and if the detection was originally a polygon,
175 |                 # instead use the data in the "bbox" column which specifies an axis-aligned bounding box.
176 |                 if not nms_on_polygons and name in ["detectree2", "sam2"]:
177 |                     rds = rds.update_geometry_column("bbox")
178 | 
179 |                 rds = single_region_NMS(
180 |                     rds.get_region_detections(0),
181 |                     threshold=nms_threshold,
182 |                     min_confidence=min_confidence,
183 |                 )
184 |             gdf = rds.get_data_frame()
185 | 
186 |             # Add the detections to the mappings dictionary
187 |             if name == "deepforest":
188 |                 mappings[filename][name] = list(gdf.geometry)
189 |             elif name in ["sam2", "detectree2"]:
190 |                 mappings[filename][name] = list(gdf["bbox"])
191 |             else:
192 |                 raise ValueError(f"Unknown detector: {name}")
193 | 
194 |     return mappings
195 | 
196 | 
197 | def evaluate_detections(detections_dict: dict[str, dict[str, List[box]]]):
198 |     """Compute precision, recall, and F1-score for each detector.
199 |     Args:
200 |         detections_dict (dict): Dictionary mapping image paths to a dictionary
201 |             with detector names and the corresponding output boxes. Output of `get_benchmark_detections`.
202 |     """
203 |     img_paths = list(detections_dict.keys())
204 |     # Get the list of detectors, which are keys of the sub-dictionary.
205 |     detector_names = [
206 |         key for key in detections_dict[img_paths[0]].keys() if key != "gt"
207 |     ]
208 |     logging.info(f"Detectors to be evaluated: {detector_names}")
209 |     for detector in detector_names:
210 |         all_predictions_P = []
211 |         all_predictions_R = []
212 |         for img in img_paths:
213 |             gt_boxes = detections_dict[img]["gt"]
214 |             pred_boxes = detections_dict[img][detector]
215 |             iou_output = compute_matched_ious(gt_boxes, pred_boxes)
216 |             P, R = compute_precision_recall(iou_output, len(gt_boxes), len(pred_boxes))
217 |             all_predictions_P.append(P)
218 |             all_predictions_R.append(R)
219 | 
220 |         P = np.mean(all_predictions_P)
221 |         R = np.mean(all_predictions_R)
222 |         F1 = (2 * P * R) / (P + R)
223 |         print(f"'{detector}': Precision={P}, Recall={R}, F1-Score={F1}")
--------------------------------------------------------------------------------
/tree_detection_framework/utils/detection.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import urllib.request
4 | 
5 | import pandas as pd
6 | from deepforest.utilities import DownloadProgressBar
7 | 
8 | 
9 | def use_release_df(
10 |     save_dir=os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../../data/"),
11 |     prebuilt_model="NEON",
12 |     check_release=True,
13 | ):
14 |     """
15 |     Check for the existence of, or download, the latest model release from GitHub
16 |     Args:
17 |         save_dir: Directory to save the model to; defaults to the "data" directory in this repo
18 |         prebuilt_model: Currently only accepts "NEON", but could be expanded to include other prebuilt models. The local model will be called {prebuilt_model}.pt on disk.
19 |         check_release (bool): whether to check GitHub for a recent model release. In cases where you are hitting the GitHub API rate limit, set to False and any previously downloaded model will be used. If no model has been downloaded, an error will be raised.
20 | 
21 |     Returns: release_tag, output_path (str): path to downloaded model
22 | 
23 |     """
24 |     os.makedirs(save_dir, exist_ok=True)
25 | 
26 |     # Naming based on pre-built model
27 |     output_path = os.path.join(save_dir, prebuilt_model + ".pt")
28 | 
29 |     if check_release:
30 |         # Find latest github tag release from the DeepLidar repo
31 |         _json = json.loads(
32 |             urllib.request.urlopen(
33 |                 urllib.request.Request(
34 |                     "https://api.github.com/repos/Weecology/DeepForest/releases/latest",
35 |                     headers={"Accept": "application/vnd.github.v3+json"},
36 |                 )
37 |             ).read()
38 |         )
39 |         asset = _json["assets"][0]
40 |         url = asset["browser_download_url"]
41 | 
42 |         # Check the release tagged locally
43 |         try:
44 |             release_txt = pd.read_csv(os.path.join(save_dir, "current_release.csv"))
45 |         except BaseException:
46 |             release_txt = pd.DataFrame({"current_release": [None]})
47 | 
48 |         # Download the current release if it doesn't exist
49 |         if release_txt.current_release[0] != _json["html_url"]:
50 | 
51 |             print(
52 |                 "Downloading model from DeepForest release {}, see {} "
53 |                 "for details".format(_json["tag_name"], _json["html_url"])
54 |             )
55 | 
56 |             with DownloadProgressBar(
57 |                 unit="B", unit_scale=True, miniters=1, desc=url.split("/")[-1]
58 |             ) as t:
59 |                 urllib.request.urlretrieve(
60 |                     url, filename=output_path, reporthook=t.update_to
61 |                 )
62 | 
63 |             print("Model was downloaded and saved to {}".format(output_path))
64 | 
65 |             # Record the release tag locally
66 |             release_txt = pd.DataFrame({"current_release": [_json["html_url"]]})
67 |             release_txt.to_csv(os.path.join(save_dir, "current_release.csv"))
68 |         else:
69 |             print(
70 |                 "Model from DeepForest release {} was already downloaded. "
71 |                 "Loading model from file.".format(_json["html_url"])
72 |             )
73 | 
74 |         return _json["html_url"], output_path
75 |     else:
76 |         try:
77 |             release_txt = pd.read_csv(os.path.join(save_dir, "current_release.csv"))
78 |         except BaseException:
79 |             raise ValueError(
80 |                 "Check release argument is {}, but no release "
81 |                 "has been previously downloaded".format(check_release)
82 |             )
83 | 
84 |         return release_txt.current_release[0], output_path
85 | 
--------------------------------------------------------------------------------
/tree_detection_framework/utils/geometric.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import shapely
4 | from contourpy import contour_generator
5 | 
6 | 
7 | def get_shapely_transform_from_matrix(matrix_transform: np.ndarray):
8 |     """
9 |     Take a matrix transform and convert it into the format expected by shapely: [a, b, d, e, xoff, yoff]
10 | 
11 |     Args:
12 |         matrix_transform (np.ndarray):
13 |             (2, 3) or (3, 3) 2D transformation matrix such that the matrix-vector product produces
14 |             the transformed value.
15 |     """
16 |     shapely_transform = [
17 |         matrix_transform[0, 0],
18 |         matrix_transform[0, 1],
19 |         matrix_transform[1, 0],
20 |         matrix_transform[1, 1],
21 |         matrix_transform[0, 2],
22 |         matrix_transform[1, 2],
23 |     ]
24 |     return shapely_transform
25 | 
26 | 
27 | def mask_to_shapely(
28 |     mask: np.ndarray, simplify_tolerance: float = 0, backend: str = "contourpy"
29 | ) -> shapely.MultiPolygon:
30 |     """
31 |     Convert a binary mask to a Shapely MultiPolygon representing positive regions,
32 |     with optional simplification.
33 | 
34 |     Args:
35 |         mask (np.ndarray): A (n, m) mask with boolean values.
36 |         simplify_tolerance (float): Tolerance for simplifying polygons. A value of 0 means no simplification.
37 | backend (str): The backend to use for contour extraction. Choose from "cv2" and "contourpy". Defaults to contourpy. 38 | 39 | Returns: 40 | shapely.MultiPolygon: A MultiPolygon representing the positive regions. 41 | """ 42 | if not np.any(mask): 43 | return shapely.Polygon() # Return an empty Polygon if the mask is empty. 44 | 45 | if backend == "cv2": 46 | # CV2-based approach 47 | contours, _ = cv2.findContours( 48 | mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE 49 | ) 50 | 51 | polygons = [] 52 | for contour in contours: 53 | contour = np.squeeze(contour) 54 | # Skip invalid contours 55 | if (contour.ndim != 2) or (contour.shape[0] < 3): 56 | continue 57 | 58 | # Convert the contour to a shapely geometry 59 | shape = shapely.Polygon(contour) 60 | 61 | if isinstance(shape, shapely.MultiPolygon): 62 | # Append all individual polygons 63 | polygons.extend(shape.geoms) 64 | elif isinstance(shape, shapely.Polygon): 65 | # Append the polygon 66 | polygons.append(shape) 67 | 68 | # Combine all polygons into a MultiPolygon 69 | multipolygon = shapely.MultiPolygon(polygons) 70 | 71 | if simplify_tolerance > 0: 72 | multipolygon = multipolygon.simplify(simplify_tolerance) 73 | 74 | return multipolygon 75 | 76 | elif backend == "contourpy": 77 | # ContourPy-based approach 78 | filled = contour_generator( 79 | z=mask, fill_type="ChunkCombinedOffsetOffset" 80 | ).filled(0.5, np.inf) 81 | chunk_polygons = [ 82 | shapely.from_ragged_array( 83 | shapely.GeometryType.POLYGON, points, (offsets, outer_offsets) 84 | ) 85 | for points, offsets, outer_offsets in zip(*filled) 86 | ] 87 | 88 | multipolygon = shapely.unary_union(chunk_polygons) 89 | 90 | # Simplify the resulting MultiPolygon if needed 91 | if simplify_tolerance > 0: 92 | multipolygon = multipolygon.simplify(simplify_tolerance) 93 | 94 | return multipolygon 95 | 96 | else: 97 | raise ValueError( 98 | f"Unsupported backend: {backend}. Choose 'cv2' or 'contourpy'." 99 | ) 100 | -------------------------------------------------------------------------------- /tree_detection_framework/utils/geospatial.py: -------------------------------------------------------------------------------- 1 | import pyproj 2 | 3 | 4 | def get_projected_CRS( 5 | lat: float, lon: float, assume_western_hem: bool = True 6 | ) -> pyproj.CRS: 7 | """ 8 | Returns a projected Coordinate Reference System (CRS) based on latitude and longitude. 9 | 10 | Args: 11 | lat (float): Latitude in degrees. 12 | lon (float): Longitude in degrees. 13 | assume_western_hem (bool): Assumes the longitude is in the Western Hemisphere. Defaults to True. 14 | 15 | Returns: 16 | pyproj.CRS: The projected CRS corresponding to the UTM zone for the given latitude and longitude. 
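
    Example (editor's illustrative check; coordinates near San Francisco fall in UTM zone 10N):
        >>> get_projected_CRS(37.8, -122.4).to_epsg()
        32610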
17 |     """
18 | 
19 |     if assume_western_hem and lon > 0:
20 |         lon = -lon
21 |     epsg_code = 32700 - round((45 + lat) / 90) * 100 + round((183 + lon) / 6)
22 |     crs = pyproj.CRS.from_epsg(epsg_code)
23 |     return crs
24 | 
--------------------------------------------------------------------------------
/tree_detection_framework/utils/raster.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 | from typing import Optional
3 | 
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import pyproj
7 | import rasterio
8 | import rasterio as rio
9 | import rasterio.plot
10 | from rasterio.warp import Resampling, calculate_default_transform, reproject
11 | 
12 | from tree_detection_framework.constants import PATH_TYPE
13 | 
14 | 
15 | # Copied from https://github.com/open-forest-observatory/geograypher/blob/2900ede9a00ac8bdce22c43e4abb6d74876390f6/geograypher/utils/geospatial.py#L333
16 | def load_downsampled_raster_data(dataset_filename: PATH_TYPE, downsample_factor: float):
17 |     """Load a raster file spatially downsampled
18 | 
19 |     Args:
20 |         dataset_filename (PATH_TYPE): Path to the raster
21 |         downsample_factor (float): A downsample factor of 10 means that pixels are 10 times larger
22 | 
23 |     Returns:
24 |         np.array: The downsampled array in the rasterio (c, h, w) convention
25 |         rio.DatasetReader: The reader with the transform updated
26 |         rio.Transform: The updated transform
27 |     """
28 |     # Open the dataset handler. Note that this doesn't read into memory.
29 |     dataset = rio.open(dataset_filename)
30 | 
31 |     # Resample data to the target shape
32 |     data = dataset.read(
33 |         out_shape=(
34 |             dataset.count,
35 |             int(dataset.height / downsample_factor),
36 |             int(dataset.width / downsample_factor),
37 |         ),
38 |         resampling=rio.enums.Resampling.bilinear,
39 |     )
40 | 
41 |     # Scale the image transform
42 |     updated_transform = dataset.transform * dataset.transform.scale(
43 |         (dataset.width / data.shape[-1]), (dataset.height / data.shape[-2])
44 |     )
45 |     # Return the data and the transform
46 |     return data, dataset, updated_transform
47 | 
48 | 
49 | def reproject_raster(
50 |     input_file: PATH_TYPE, output_file: PATH_TYPE, dst_crs: PATH_TYPE
51 | ) -> PATH_TYPE:
52 |     """Reproject a raster to a destination CRS, writing the result to a new file.
53 | 
54 |     Args:
55 |         input_file (PATH_TYPE): Path to the source raster.
56 |         output_file (PATH_TYPE): Path to write the reprojected raster to.
57 |         dst_crs (PATH_TYPE): The destination coordinate reference system.
58 | 
59 |     Returns:
60 |         PATH_TYPE: `input_file` if the source was already in the destination CRS (a no-op),
61 |             otherwise `output_file`, which the reprojected raster is written to.
62 |     """
63 |     # Taken from here: https://rasterio.readthedocs.io/en/latest/topics/reproject.html
64 |     # Open the source raster
65 |     with rasterio.open(input_file, "r") as src:
66 |         # If it is in the desired CRS, then return the input file since this is a no-op
67 |         if dst_crs == src.crs:
68 |             return input_file
69 | 
70 |         # Calculate the parameters of the transform for the data in the new CRS
71 |         transform, width, height = calculate_default_transform(
72 |             src.crs, dst_crs, src.width, src.height, *src.bounds
73 |         )
74 |         kwargs = src.meta.copy()
75 |         # Create updated metadata for this new file
76 |         kwargs.update(
77 |             {"crs": dst_crs, "transform": transform, "width": width, "height": height}
78 |         )
79 | 
80 |         # Open the output file
81 |         with rasterio.open(output_file, "w", **kwargs) as dst:
82 |             # Perform reprojection per band
83 |             for i in range(1, src.count + 1):
84 |                 reproject(
85 |                     source=rasterio.band(src, i),
86 |                     destination=rasterio.band(dst, i),
87 |                     src_transform=src.transform,
88 |                     src_crs=src.crs,
89 |                     dst_transform=transform,
90 |                     dst_crs=dst_crs,
91 |                     resampling=Resampling.nearest,
92 |                 )
    # 
Return the output file path that the reprojected raster was written to. 93 | return output_file 94 | 95 | 96 | def show_raster( 97 | raster_file_path: PATH_TYPE, 98 | downsample_factor: float = 10.0, 99 | plt_ax: Optional[plt.axes] = None, 100 | CRS: Optional[pyproj.CRS] = None, 101 | ): 102 | """Show a raster, optionally downsampling or reprojecting it 103 | 104 | Args: 105 | raster_file_path (PATH_TYPE): 106 | Path to the raster file 107 | downsample_factor (float): 108 | How much to downsample the raster before visualization, this makes it faster and consume 109 | less memory 110 | plt_ax (Optional[plt.axes], optional): 111 | Axes to plot on, otherwise the current ones are used. Defaults to None. 112 | CRS (Optional[pyproj.CRS], optional): 113 | The CRS to reproject the data to if set. Defaults to None. 114 | """ 115 | # Check if the file is georeferenced 116 | with rasterio.open(raster_file_path) as src: 117 | if src.crs is None: 118 | crs_available = False 119 | else: 120 | crs_available = True 121 | 122 | # Handle cases where no CRS is available 123 | if not crs_available and CRS is not None: 124 | print(f"Warning: No CRS found in the raster. Proceeding without reprojection.") 125 | CRS = None 126 | 127 | # If the CRS is set, ensure the data matches it 128 | if CRS is not None: 129 | # Create a temporary file to write to 130 | temp_output_filename = tempfile.NamedTemporaryFile(suffix=".tif") 131 | # Get the name of this file 132 | temp_name = temp_output_filename.name 133 | # Reproject the raster. If the CRS was the same as requested, the original raster path will 134 | # be returned. Otherwise, the reprojected raster will be written to the temp file and that 135 | # path will be returned. 136 | raster_file_path = reproject_raster( 137 | input_file=raster_file_path, output_file=temp_name, dst_crs=CRS 138 | ) 139 | # Load the downsampled image 140 | img, _, transform = load_downsampled_raster_data( 141 | raster_file_path, downsample_factor=downsample_factor 142 | ) 143 | # Plot the image 144 | rio.plot.show(source=img, transform=transform, ax=plt_ax) 145 | 146 | 147 | def plot_from_dataloader(sample): 148 | """ 149 | Plots an image from the dataset. 150 | 151 | Args: 152 | sample (dict): A dictionary containing the tile to plot. The 'image' key should have a tensor of shape (C, H, W). 153 | 154 | Returns: 155 | matplotlib.figure.Figure: A figure containing the plotted image. 156 | """ 157 | # Reorder and rescale the image 158 | image = sample["image"].permute(1, 2, 0).numpy() 159 | 160 | # Create the figure to plot on and return 161 | fig, ax = plt.subplots() 162 | # Plot differently based on the number of channels 163 | n_channels = image.shape[2] 164 | 165 | # Show with a colorbar and default matplotlib mapping for scalar data 166 | if n_channels == 1: 167 | cbar = ax.imshow(image) 168 | plt.colorbar(cbar, ax=ax) 169 | # Plot as RGB(A) 170 | elif n_channels in (3, 4): 171 | if image.dtype != np.uint8: 172 | # See if this should be interpreted as data 0-255, even if it's float data 173 | max_val = np.max(image) 174 | # If the values are greater than 1, assume it's supposed to be unsigned int8 175 | # and cast to that so it's properly shown 176 | if max_val > 1: 177 | image = image.astype(np.uint8) 178 | # Plot the image 179 | ax.imshow(image) 180 | else: 181 | raise ValueError(f"Cannot plot image with {n_channels} channels") 182 | 183 | return fig 184 | --------------------------------------------------------------------------------
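A short usage sketch tying the raster utilities together (editor's illustrative example, not part of the repository; the file paths and EPSG code are hypothetical):

    import pyproj
    from rasterio.crs import CRS

    from tree_detection_framework.utils.raster import (
        load_downsampled_raster_data,
        reproject_raster,
        show_raster,
    )

    # Visualize a downsampled orthomosaic, reprojected to a metric UTM CRS
    show_raster("data/ortho.tif", downsample_factor=10.0, CRS=pyproj.CRS.from_epsg(32610))

    # Or reproject explicitly and inspect the downsampled array
    out_path = reproject_raster("data/ortho.tif", "data/ortho_utm.tif", CRS.from_epsg(32610))
    data, dataset, transform = load_downsampled_raster_data(out_path, downsample_factor=4)
    print(data.shape)  # (bands, height / 4, width / 4)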