├── .github
└── workflows
│ ├── black.yml
│ └── isort.yml
├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── checkpoints
└── .gitignore
├── data
└── .gitignore
├── examples
├── predict_detections.ipynb
└── predict_detections_sam2.ipynb
├── poetry.lock
├── pyproject.toml
├── sandbox
├── .gitignore
├── detection
│ ├── detect_raw_drone_images.ipynb
│ ├── detectree2_detections.ipynb
│ ├── geometric_detections.ipynb
│ ├── plot_detections.ipynb
│ └── random_detections.ipynb
├── evaluation
│ ├── README.md
│ ├── dtree2_benchmark.ipynb
│ ├── neon_benchmark.ipynb
│ ├── sam2_dtree2_benchmark.ipynb
│ └── sam2_neon_benchmark.ipynb
├── postprocessing
│ ├── nms_experiments.ipynb
│ └── specify_postprocessing_in_detector.ipynb
└── preprocessing
│ ├── chip_ortho.py
│ ├── deepforest_train_and_predict.ipynb
│ ├── load_raw_drone_images.ipynb
│ └── test-dataloader-with-dtree2.ipynb
└── tree_detection_framework
├── __init__.py
├── constants.py
├── detection
├── SAM2_detector.py
├── __init__.py
├── detector.py
├── models.py
└── region_detections.py
├── entrypoints
├── __init__.py
├── generate_predictions.py
└── tile_data.py
├── evaluation
├── __init__.py
└── evaluate.py
├── postprocessing
├── __init__.py
└── postprocessing.py
├── preprocessing
├── __init__.py
├── derived_geodatasets.py
└── preprocessing.py
└── utils
├── __init__.py
├── benchmarking.py
├── detection.py
├── geometric.py
├── geospatial.py
└── raster.py
/.github/workflows/black.yml:
--------------------------------------------------------------------------------
name: Black Linter

# Reusable workflow: `workflow_call` allows Black to be called from the
# isort workflow.
on:
  workflow_call:

jobs:
  black-linter:
    name: Run Black
    runs-on: ubuntu-latest

    steps:
      # Resolve the branch to format: the PR source branch for pull requests,
      # otherwise the ref with its `refs/heads/` prefix removed. The context
      # values are handed to the script through `env` instead of being
      # interpolated into it, so a crafted branch name cannot inject shell
      # commands.
      - name: Determine Branch Name
        env:
          EVENT_NAME: ${{ github.event_name }}
          HEAD_REF: ${{ github.head_ref }}
        run: |
          if [ "$EVENT_NAME" = "pull_request" ]; then
            echo "BRANCH_NAME=$HEAD_REF" >> "$GITHUB_ENV"
          else
            echo "BRANCH_NAME=${GITHUB_REF#refs/heads/}" >> "$GITHUB_ENV"
          fi

      - name: Checkout code
        uses: actions/checkout@v3
        with:
          ref: ${{ env.BRANCH_NAME }}

      # BRANCH_NAME was exported via GITHUB_ENV above, so it is read here as a
      # quoted shell variable rather than expanded into the command text.
      # NOTE(review): presumably this picks up a commit pushed by the calling
      # isort job - confirm.
      - name: Pull latest changes
        run: git pull origin "$BRANCH_NAME"

      - name: Run Black formatter
        uses: psf/black@stable
        with:
          options: "--verbose"
          src: "."
          jupyter: true

      # Commit and push whatever Black changed.
      - name: Push changes
        uses: stefanzweifel/git-auto-commit-action@v4
        with:
          commit_message: Apply black formatting changes
--------------------------------------------------------------------------------
/.github/workflows/isort.yml:
--------------------------------------------------------------------------------
name: Code Formatter

on:
  push:
    branches:
      - main
  pull_request:

jobs:
  isort:
    name: Run Isort
    runs-on: ubuntu-latest

    steps:
      # Resolve the branch to format: the PR source branch for pull requests,
      # otherwise the ref with its `refs/heads/` prefix removed. The context
      # values are handed to the script through `env` instead of being
      # interpolated into it, so a crafted branch name cannot inject shell
      # commands.
      - name: Determine Branch Name
        env:
          EVENT_NAME: ${{ github.event_name }}
          HEAD_REF: ${{ github.head_ref }}
        run: |
          if [ "$EVENT_NAME" = "pull_request" ]; then
            echo "BRANCH_NAME=$HEAD_REF" >> "$GITHUB_ENV"
          else
            echo "BRANCH_NAME=${GITHUB_REF#refs/heads/}" >> "$GITHUB_ENV"
          fi

      - name: Checkout code
        uses: actions/checkout@v3
        with:
          ref: ${{ env.BRANCH_NAME }}

      # The version is quoted so YAML keeps it a string instead of a float.
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: "3.9"

      - name: Run Isort
        run: pip install isort==5.13.2 && isort --profile black .

      # Commit and push whatever isort changed.
      - name: Push changes
        uses: stefanzweifel/git-auto-commit-action@v5
        with:
          commit_message: Apply isort formatting changes

  # Ensure Black runs after Isort has completed
  black:
    needs: isort
    uses: ./.github/workflows/black.yml
46 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Formatting and style
2 | This repository uses `isort` for import management and `black` for general formatting. Both can be installed using `pip` or `conda` and on a Linux system, they can also be installed at the system level using `apt`. For example, using `pip`, you'd simply do
3 | ```
4 | pip install isort==5.13.2
5 | pip install black==24.1.0
6 | ```
7 | We use the default arguments so the tools can be executed by simply pointing them at all files in the repository. From the root directory of the project, you should execute the following:
8 | ```
9 | isort .
10 | black .
11 | ```
12 | Note that it's important that `isort` is run first, because it doesn't produce a format that's consistent with `black`.
13 |
14 | If you push changes to main or create a pull request, please be aware that Github Actions will trigger a workflow that runs `isort` and `black` on the code. This will take a few seconds to run and the workflow may automatically push formatting changes to the repository. To ensure your local repository is up to date with the remote repository, wait for a few seconds and pull the latest changes.
15 |
16 | # Branch naming
17 | If you are adding a branch to this repository, please use the following convention: `{feature, bugfix, hotfix, release, docs}/{developer initials}/{short-hyphenated-description}`. For example, `docs/DR/add-branch-naming-convention` for this change. For a description of the prefixes, please see [here](https://medium.com/@abhay.pixolo/naming-conventions-for-git-branches-a-cheatsheet-8549feca2534).
18 |
19 | # Docstrings
20 | For documentation, we use the [Google](https://github.com/NilsJPWerner/autoDocstring/blob/HEAD/docs/google.md) format. I personally use [VSCode autoDocstring](https://marketplace.visualstudio.com/items?itemName=njpwerner.autodocstring) plugin for templating. Keeping the docstrings up-to-date is essential because we automatically integrate them into our documentation. If the docstrings are outdated, the docstrings shown in the documentation will also be outdated.
21 |
22 | # Type hints
23 | Typing hints, as introduced by [PEP 484](https://peps.python.org/pep-0484/), are strongly encouraged. This helps provide additional documentation and allows some code editors to make additional autocompletes.
24 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2024, Open Forest Observatory
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # tree-detection-framework
2 | This project has three main goals:
3 | * Enable tree detection on realistic-scale, geospatial raster data with minimal boilerplate, using existing (external) tree detection/segmentation models
4 | * Facilitate direct comparison of multiple algorithms
5 | * Rely on modern libraries and software best practice for a robust, performant, and modular tool
6 |
7 | This project does not, itself, provide tree detection/segmentation algorithms (with the exception of a geometric algorithm). Instead, it provides a standardized interface for performing training, inference, and evaluation using existing tree detection models and algorithms. The project currently supports the external computer vision models DeepForest, Detectree2, and SAM2, as well as a geometric canopy height model segmentor implemented within TDF. Support for other external models can be added by implementing a new `Detector` class.
8 |
9 | We use the `torchgeo` package to perform data loading and standardization using standard geospatial input formats. This library allows us to generate chips on the fly of a given size, stride, and spatial resolution. Training and inference is done with modular detectors that can be based on existing models and algorithms. We have preliminary support for using `PyTorch Lightning` to minimize boilerplate around model training and prediction. Region-level nonmax-suppression (NMS) is done using the `PolyGoneNMS` library which is efficient for large images. Visualization and saving of the predictions is done using `geopandas`, a common library for geospatial data.
10 |
11 | This project is under active development by the [Open Forest Observatory](https://openforestobservatory.org/). We welcome contributions and suggestions for improvement.
12 |
13 | ## Tree detection models supported
14 | TDF currently supports the following tree detection/segmentation algorithms.
15 |
16 | ### DeepForest
17 | - [Github](https://github.com/weecology/DeepForest)
18 | - Uses RGB input data. Predicts tree crowns with rectangular bounding boxes.
19 | - Provides a RetinaNet model trained on a large number of semi-supervised tree crown annotations and a smaller set of manual annotations.
20 | - Trained using data from the US only but representing diverse regions. The model has been applied on data from outside the US successfully.
21 |
22 | ### Detectree2
23 | - [Github](https://github.com/PatBall1/detectree2)
24 | - Uses RGB input data. Predicts tree crowns with polygon boundaries.
25 | - Provides a Mask R-CNN model trained on manually labeled tree crowns from four sites.
26 | - Trained using data from tropical forests.
27 |
28 | ### Segment Anything Model 2 (SAM2)
29 | - [Github](https://github.com/facebookresearch/sam2)
30 | - Uses RGB input data. Predicts objects with polygon boundaries.
31 | - Utilizes the Segment Anything Model (SAM 2.1 Hiera Large) checkpoint with tuned parameters for mask generation optimized for tree crown delineation.
32 | - Does not rely on supervised training for tree-specific data but generalizes well due to SAM's zero-shot nature; however, non-tree objects are also detected and included in predictions.
33 |
34 | ### Geometric Detector
35 | - Implementation of the variable window filter algorithm of [Popescu and Wynne
36 | (2004)](https://www.ingentaconnect.com/content/asprs/pers/2004/00000070/00000005/art00003) for
37 | tree top detection, combined with the algorithm of [Silva et al.
38 | (2016)](https://www.tandfonline.com/doi/full/10.1080/07038992.2016.1196582#abstract) for crown
39 | segmentation.
40 | - Uses canopy height model (CHM) input data. Predicts tree crowns with polygon boundaries.
41 | - This is a learning-free tree detection algorithm. It is the one algorithm that is implemented within TDF as opposed to relying on an existing external model/algorithm.
42 |
43 | ## Software architecture
44 | The `tree-detection-framework` is organized into modular components to facilitate extension including integration of additional detection models. The main components are:
45 |
46 | 1. **`preprocessing.py`**
47 | The `create_dataloader()` method accepts single/multiple orthomosaic inputs. Alternatively,
48 | `create_image_dataloader()` accepts a folder containing raw drone imagery. The methods tile the
49 | input images based on user-specified parameters such as tile size, stride, and resolution and
50 | return a PyTorch-compatible dataloader for inference.
51 | 2. **`Detector` Base Class**
52 | All detectors in the framework (e.g., DeepForestDetector, Detectree2Detector) inherit from the
53 | `Detector` base class. The base class defines the core logic for generating predictions and
54 | geospatially referencing image tiles, while model-specific detectors translate the inputs to the
55 | format expected by the respective model. This design allows all detectors to plug into the same
56 | pipeline with minimal code changes.
57 | 3. **`RegionDetectionsSet` and `RegionDetections`**
58 | These classes standardize model outputs. A `RegionDetectionsSet` is a collection of `RegionDetections`, where each `RegionDetections` object represents the detections in a single image tile. This abstraction allows postprocessing components to operate uniformly across different detectors. These outputs can be saved out as `.gpkg` or `.geojson` files.
59 | 4. **`postprocessing.py`**
60 | Implements a set of postprocessing functions for cleaning the detections by Non-Maximum Suppression (NMS), polygon hole suppression, tile boundary suppression, and removing out-of-bounds detections. Most of these methods operate on standardized output types (`RegionDetections` / `RegionDetectionsSet`).
61 |
62 | ## Install
63 | Some of the dependencies are managed by a tool called [Poetry](https://python-poetry.org/). I've found it
64 | easiest to install this using the "official installer" option as follows. Note that this should be run
65 | in the base conda environment or with no environment active.
66 | ```
67 | curl -sSL https://install.python-poetry.org | python3 -
68 | ```
69 | Now create and activate a conda environment for the dependencies of this project.
70 | ```
71 | conda create -n tree-detection-framework python=3.10 -y
72 | conda activate tree-detection-framework
73 | ```
74 |
75 | Now, from the root directory of the project, run the following command. Note that on Jetstream2, you
76 | may need to run this in a graphical session and respond to a keyring popup menu.
77 | ```
78 | poetry install
79 | ```
80 | Finally, choose to either install the Detectron2 or SAM2 detection framework.
81 |
82 | **Detectron2:**
83 | The Detectron2 library is not compatible with `poetry` so must be installed directly with pip
84 | ```
85 | # https://detectron2.readthedocs.io/en/latest/tutorials/install.html#build-detectron2-from-source
86 | pip install git+https://github.com/facebookresearch/detectron2.git
87 | ```
88 | Download the detectree2 checkpoint weights.
89 | ```
90 | cd checkpoints
91 | mkdir detectree2
92 | cd detectree2
93 | wget https://zenodo.org/records/10522461/files/230103_randresize_full.pth
94 | ```
95 | **SAM2:**
96 | Clone the SAM2 repository and install the necessary config files
97 | ```
98 | git clone https://github.com/facebookresearch/sam2.git && cd sam2
99 |
100 | pip install -e .
101 | ```
102 | And download the associated checkpoints
103 | ```
104 | cd checkpoints && \
105 | ./download_ckpts.sh && \
106 | cd ..
107 | ```
108 | And move into this repo
109 | ```
110 | mv checkpoints ../tree-detection-framework
111 | ```
112 |
113 |
114 | ## Use
115 | The module code is in the `tree_detection_framework` folder. Once installed using the `poetry`
116 | command above, this code can be imported into scripts or notebooks under the name
117 | `tree_detection_framework` the same as you would for any other library.
118 |
119 | ## Examples
120 | To begin with, you can access example geospatial data
121 | [here](https://ucdavis.box.com/v/tdf-example-data), which should be downloaded and placed in the `data` folder at the top level of this project. Our goal is to have high-quality,
122 | up-to-date examples in the `examples` folder. We also have work-in-progress or one-off code in
123 | `sandbox`, which still may provide some insight but is not guaranteed to be current or generalizable.
124 | Finally, the `tree_detection_framework/entrypoints` folder has command line scripts that can be run
125 | to complete tasks.
126 |
127 | ## Evaluation and benchmark with NEON
128 | Download the NEON dataset files and save the annotations and RGB folders under a new directory in the `data` folder.
129 | ```
130 | wget -O annotations.zip "https://zenodo.org/records/5914554/files/annotations.zip?download=1"
131 | unzip annotations.zip
132 | wget -O evaluation.zip "https://zenodo.org/records/5914554/files/evaluation.zip?download=1"
133 | unzip -j evaluation.zip "evaluation/RGB/*" -d RGB
134 | rm annotations.zip
135 | rm evaluation.zip
136 | ```
137 | Follow the steps in `tree-detection-framework/sandbox/evaluation/neon_benchmark.ipynb` for detectors `DeepForest` & `Detectree2`, and `tree-detection-framework/sandbox/evaluation/sam2_neon_benchmark.ipynb` to use `SAM2`.
138 |
139 | ## Evaluation and benchmark with Detectree2 datasets
140 | 1. Download the dataset. There are two ways to get the dataset:
141 | Download the site-specific .tif (for orthomosaic) and .gpkg (for ground truth polygons) files from https://zenodo.org/records/8136161. Then, follow steps in https://github.com/PatBall1/detectree2/blob/master/notebooks/colab/tilingJB.ipynb to do the tiling.
142 | (OR)
143 | Download our pre-tiled dataset from https://ucdavis.box.com/s/thjmaane9d38opw1bhnyxrsrtt90j37m
144 | 2. Add the tiled dataset folder to the `data` folder in this repo.
145 | 3. For benchmark and evaluation see steps in `tree-detection-framework/sandbox/evaluation/dtree2_benchmark.ipynb` and `tree-detection-framework/sandbox/evaluation/sam2_dtree2_benchmark.ipynb`
146 |
--------------------------------------------------------------------------------
/checkpoints/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 |
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "tree-detection-framework"
3 | version = "0.1.0"
4 | description = "Framework for detecting trees in remote sensing data"
5 | authors = ["David Russell "]
6 | license = "BSD-3"
7 | readme = "README.md"
8 | packages = [{include = "tree_detection_framework"}]
9 |
10 | [tool.poetry.dependencies]
11 | python = ">=3.10,<3.12"
12 | torchgeo = "^0.6.0"
13 | deepforest = "^1.3.3"
14 | ipykernel = "^6.29.5"
15 | polygonenms-ofo-fork = "^0.1.0"
16 |
17 |
18 | [build-system]
19 | requires = ["poetry-core"]
20 | build-backend = "poetry.core.masonry.api"
21 |
--------------------------------------------------------------------------------
/sandbox/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-forest-observatory/tree-detection-framework/661fded8c3cd3ac7585ec9bb5d4606778bca7494/sandbox/.gitignore
--------------------------------------------------------------------------------
/sandbox/evaluation/README.md:
--------------------------------------------------------------------------------
1 | # Zero-Shot Tree Detection and Segmentation from Aerial Forest Imagery Experiments
2 | These notebooks provide a reimplementation of the experiments from this [paper](https://ml-for-rs.github.io/iclr2025/camera_ready/papers/3.pdf), presented at the [ML4RS25](https://ml-for-rs.github.io/iclr2025/) workshop. These experiments differ from the initial experiments in several notable ways, leading to different results.
3 | - A subtlety in the original implementation of Detectree2 meant that in the original experiments, inference on the NEON datasets occurred with a flipped channel ordering (see the issue [here](https://github.com/PatBall1/detectree2/issues/197)). Fixing this issue led to better performance.
4 | - The original experiments first converted the polygon predictions to axis-aligned bounding boxes prior to running non-max suppression. In the current experiments, this was changed to running NMS on the polygon representation prior to converting to boxes. This change keeps more detections, leading to higher precision and lower recall than the original experiments.
5 | - Aside from the polygon vs. box NMS consideration, SAM2 results differ for reasons that are not understood.
6 |
7 | The data for these experiments can be downloaded following the instructions for the [NEON](https://github.com/open-forest-observatory/tree-detection-framework?tab=readme-ov-file#evaluation-and-benchmark-with-neon) and [Detectree2](https://github.com/open-forest-observatory/tree-detection-framework?tab=readme-ov-file#evaluation-and-benchmark-with-detectree2-datasets) datasets. You will need to create two separate environments, one with the Detectree2 dependencies installed and the other with SAM2 dependencies installed following the instructions [here](https://github.com/open-forest-observatory/tree-detection-framework?tab=readme-ov-file#install). Then run the notebooks using the appropriate environment for the prediction model being used.
8 |
--------------------------------------------------------------------------------
/sandbox/preprocessing/chip_ortho.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import logging
4 | import random
5 | from pathlib import Path
6 | from typing import Any, Dict, Optional
7 |
8 | import fiona
9 | import fiona.transform
10 | import geopandas as gpd
11 | import matplotlib.pyplot as plt
12 | import numpy as np
13 | import pyproj
14 | import rasterio
15 | import shapely.geometry
16 | import torch
17 | from shapely.affinity import affine_transform
18 | from shapely.geometry import box
19 | from torch.utils.data import DataLoader
20 | from torchgeo.datasets import (
21 | IntersectionDataset,
22 | RasterDataset,
23 | VectorDataset,
24 | stack_samples,
25 | unbind_samples,
26 | )
27 | from torchgeo.datasets.utils import BoundingBox, array_to_tensor
28 | from torchgeo.samplers import GridGeoSampler, Units
29 | from torchvision.transforms import ToPILImage
30 |
31 | # Set up logging configuration
32 | logging.basicConfig(
33 | level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
34 | )
35 |
36 |
class CustomRasterDataset(RasterDataset):
    """
    Dataset for orthomosaic raster imagery, built on `RasterDataset` from `torchgeo`.

    Attributes:
        filename_glob (str): Glob pattern selecting which files in the directory are loaded.
        is_image (bool): Marks the loaded data as image data.
        separate_files (bool): True when each band is stored in its own file, False otherwise.
    """

    # Pick up every TIFF file
    filename_glob: str = "*.tif"
    separate_files: bool = False
    is_image: bool = True
50 |
51 |
class CustomVectorDataset(VectorDataset):
    """
    Vector dataset providing the labels for the raster data. Built on `VectorDataset` from `torchgeo`.
    """

    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
        """Look up the mask, metadata and pixel-space shapes for a query box.

        Mostly follows `__getitem__` of torchgeo's `VectorDataset`; the extra step
        at the end also returns the geometries in pixel coordinates.

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            sample dict with keys "mask", "crs", "bounds" and "shapes", where
            "shapes" is a list of (pixel-space geometry, label) pairs

        Raises:
            IndexError: if query is not found in the index
        """
        matched_files = [
            entry.object
            for entry in self.index.intersection(tuple(query), objects=True)
        ]
        if len(matched_files) == 0:
            raise IndexError(
                f"query: {query} not found in index with bounds: {self.bounds}"
            )

        labeled_geoms = []
        for path in matched_files:
            with fiona.open(path) as source:
                # Express the query's x and y extents in the file's own CRS
                xs, ys = fiona.transform.transform(
                    self.crs.to_dict(),
                    source.crs,
                    [query.minx, query.maxx],
                    [query.miny, query.maxy],
                )
                src_minx, src_maxx = xs
                src_miny, src_maxy = ys

                # Only visit features that intersect the query box
                source_bbox = (src_minx, src_miny, src_maxx, src_maxy)
                for feat in source.filter(bbox=source_bbox):
                    # Bring the geometry back into the dataset CRS
                    warped = fiona.transform.transform_geom(
                        source.crs, self.crs.to_dict(), feat["geometry"]
                    )
                    labeled_geoms.append((warped, self.get_label(feat)))

        # Build the raster grid covering the query at the dataset resolution
        n_cols = (query.maxx - query.minx) / self.res
        n_rows = (query.maxy - query.miny) / self.res
        pixel_to_geo = rasterio.transform.from_bounds(
            query.minx, query.miny, query.maxx, query.maxy, n_cols, n_rows
        )
        grid_shape = (round(n_rows), round(n_cols))

        if not labeled_geoms:
            # Nothing falls inside this query: hand back an all-zero mask using
            # the default fill value and dtype of rasterize
            raster = np.zeros(grid_shape, dtype=np.uint8)
        else:
            raster = rasterio.features.rasterize(
                labeled_geoms, out_shape=grid_shape, transform=pixel_to_geo
            )

        # array_to_tensor is used because rasterize may return uint16/uint32 arrays
        mask_tensor = array_to_tensor(raster).to(self.dtype)

        # Additions relative to the torchgeo implementation start here

        # The inverted affine maps geo coordinates onto pixel coordinates
        geo_to_pixel = ~pixel_to_geo

        # Turn each `fiona` geometry into a `shapely` object and move it into
        # pixel space, keeping its ID alongside
        pixel_shapes = []
        for geom, label in labeled_geoms:
            as_shapely = shapely.geometry.shape(geom)
            pixel_shapes.append(
                (affine_transform(as_shapely, geo_to_pixel.to_shapely()), label)
            )

        # 'shapes' carries the polygons together with their ID values
        sample = {
            "mask": mask_tensor,
            "crs": self.crs,
            "bounds": query,
            "shapes": pixel_shapes,
        }

        if self.transforms is None:
            return sample
        return self.transforms(sample)
144 |
145 |
def chip_orthomosaics(
    raster_path: str,
    size: float,
    vector_path: Optional[str] = None,
    id_column_name: str = "treeID",
    stride: Optional[float] = None,
    overlap_percent: Optional[float] = None,
    res: Optional[float] = None,
    use_units_meters: bool = False,
    save_dir: Optional[str] = None,
    visualize_n: Optional[int] = None,
) -> DataLoader:
    """
    Splits an orthomosaic image into smaller tiles with optional reprojection to a meters-based CRS. Tiles can be saved to a directory and visualized.

    Args:
        raster_path (str): Path to the folder containing the orthomosaic files.
        size (float): Tile size in units of pixels or meters, depending on `use_units_meters`.
        vector_path (str, optional): Path to the folder containing the vector data files.
        id_column_name (str): Column name in the vector dataframe containing IDs for the tree polygons. Defaults to 'treeID'.
        stride (float, optional): The distance between the start of one tile and the next in pixels or meters.
        overlap_percent (float, optional): Percentage overlap between consecutive tiles (0 <= value < 100).
            When non-zero it takes precedence over `stride`. A value of 0 with no `stride` gives
            edge-to-edge tiles (stride == size).
        res (float, optional): Resolution of the dataset in units of the CRS (if not specified, defaults to the resolution of the first image).
        use_units_meters (bool, optional): Whether to use meters instead of pixels for tile size and stride.
        save_dir (str, optional): Directory where the tiles and metadata should be saved.
        visualize_n (int, optional): Number of randomly selected tiles to visualize. Capped at the
            number of tiles produced by the sampler.

    Returns:
        A dataloader with chipped orthomosaic tiles.

    Raises:
        ValueError: If neither `stride` nor `overlap_percent` are provided, or if
            `overlap_percent` is 100 or more (which would give a non-positive stride).
    """
    # Use one consistent definition of "labels were supplied" throughout the function
    has_vector_data = vector_path is not None

    # Stores image data
    raster_dataset = CustomRasterDataset(paths=raster_path, res=res)

    # Stores label data
    vector_dataset = (
        CustomVectorDataset(paths=vector_path, res=res, label_name=id_column_name)
        if has_vector_data
        else None
    )

    units = Units.CRS if use_units_meters else Units.PIXELS
    logging.info(f"Units = {units}")

    if use_units_meters and raster_dataset.crs.is_geographic:
        # Reproject the dataset to a meters-based CRS
        logging.info("Projecting to meters-based CRS...")
        # NOTE(review): assumes bounds are ordered (minx, maxx, miny, ...), so index 2 is a
        # latitude and index 0 a longitude - confirm against the dataset's bounds type
        lat, lon = raster_dataset.bounds[2], raster_dataset.bounds[0]

        # Return a new projected CRS value with meters units
        projected_crs = get_projected_CRS(lat, lon)

        # Type conversion to rasterio.crs
        projected_crs = rasterio.crs.CRS.from_wkt(projected_crs.to_wkt())

        # Recreating the raster and vector dataset objects with the new CRS value
        raster_dataset = CustomRasterDataset(paths=raster_path, crs=projected_crs)
        vector_dataset = (
            CustomVectorDataset(
                paths=vector_path, crs=projected_crs, label_name=id_column_name
            )
            if has_vector_data
            else None
        )

    # Create an intersection dataset that combines raster and label data if given. Otherwise, proceed with just raster_dataset.
    final_dataset = (
        IntersectionDataset(raster_dataset, vector_dataset)
        if has_vector_data
        else raster_dataset
    )

    # An overlap of 100% or more would make the stride zero or negative
    if overlap_percent is not None and overlap_percent >= 100:
        raise ValueError(
            f"'overlap_percent' must be less than 100, got {overlap_percent}."
        )

    # Calculate stride if overlap is provided
    if overlap_percent:
        stride = size * (1 - overlap_percent / 100.0)
        logging.info(f"Calculated stride based on overlap: {stride}")
    elif stride is None:
        if overlap_percent is None:
            raise ValueError("Either 'stride' or 'overlap' must be provided.")
        # An explicit overlap of 0% is a valid request for edge-to-edge tiles
        stride = size
        logging.info(f"Calculated stride based on overlap: {stride}")
    logging.info(f"Stride = {stride}")

    # GridGeoSampler to get contiguous tiles
    sampler = GridGeoSampler(final_dataset, size=size, stride=stride, units=units)
    dataloader = DataLoader(final_dataset, sampler=sampler, collate_fn=stack_samples)

    if visualize_n:
        # Randomly pick indices for visualizing tiles; never ask for more tiles than exist
        n_tiles = len(sampler)
        visualize_indices = random.sample(range(n_tiles), min(visualize_n, n_tiles))

        for i in visualize_indices:
            plot(get_sample_from_index(raster_dataset, sampler, i))
            plt.axis("off")
            plt.show()

    if save_dir:
        # Creates save directory if it doesn't exist
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

        transform_to_pil = ToPILImage()
        # Counted separately from the loop variable so an empty dataloader is reported as 0
        n_saved = 0
        for i, batch in enumerate(dataloader):
            sample = unbind_samples(batch)[0]

            # Scale to [0, 1] before converting to a PIL image
            image = sample["image"]
            image_tensor = torch.clamp(image / 255.0, min=0, max=1)
            pil_image = transform_to_pil(image_tensor)
            pil_image.save(save_path / f"tile_{i}.png")

            # Prepare to save tile metadata
            metadata = {
                "crs": sample["crs"].to_string(),
                "bounds": list(sample["bounds"]),
            }

            if has_vector_data:
                # Extract shapes (polygons and tree IDs)
                shapes = sample["shapes"]

                crowns = [
                    {"ID": tree_id, "crown": polygon.wkt} for polygon, tree_id in shapes
                ]

                # Add crowns to the metadata
                metadata["crowns"] = crowns

            # Save tile metadata to a json file
            with open(save_path / f"tile_{i}.json", "w") as f:
                json.dump(metadata, f, indent=4)

            n_saved += 1

        logging.info(f"Saved {n_saved} tiles to {save_dir}")

    return dataloader
280 |
281 |
282 | # Helper functions
283 |
284 |
def get_sample_from_index(
    dataset: CustomRasterDataset, sampler: GridGeoSampler, index: int
) -> Dict:
    """Return the dataset sample for the `index`-th query yielded by `sampler`.

    Args:
        dataset: Object indexed with the queries the sampler yields.
        sampler: Iterable of queries (bounding boxes, per the original comment).
        index: Position of the wanted query in the sampler's iteration order.

    Returns:
        Whatever `dataset[query]` returns for the selected query.
    """
    # The sampler is fully materialised so it can be indexed by position
    queries = [query for query in sampler]
    return dataset[queries[index]]
295 |
296 |
def plot(sample: Dict) -> plt.Figure:
    """Draw the `"image"` entry of a sample on a new matplotlib figure.

    Args:
        sample: Mapping with an `"image"` tensor. Its first axis is moved to the end before
            display (presumably channels-first to channels-last - confirm with the dataset).

    Returns:
        The newly created figure.
    """
    # Move axis 0 last and cast to uint8 so imshow receives an integer image array
    pixels = sample["image"].permute(1, 2, 0).byte().numpy()
    figure, axes = plt.subplots()
    axes.imshow(pixels)
    return figure
303 |
304 |
def get_projected_CRS(
    lat: float, lon: float, assume_western_hem: bool = True
) -> pyproj.CRS:
    """Return the WGS84 / UTM CRS for the zone containing a latitude/longitude.

    Args:
        lat: Latitude in degrees. Non-negative values select a northern-hemisphere
            code (EPSG 326xx), negative values a southern one (EPSG 327xx).
        lon: Longitude in degrees.
        assume_western_hem: If True, a positive longitude is negated before the zone
            is computed. Defaults to True.

    Returns:
        pyproj.CRS: The CRS built from the computed EPSG code.
    """
    if assume_western_hem and lon > 0:
        lon = -lon
    # UTM zones are 6 degrees wide and zone 1 starts at 180W. Floor division is used
    # rather than round(), because round() rounds halves to the nearest even number and
    # so mis-assigns longitudes lying exactly on some zone boundaries. The clamp keeps
    # lon == 180 in zone 60.
    utm_zone = min(max(int((lon + 180) // 6) + 1, 1), 60)
    hemisphere_base = 32600 if lat >= 0 else 32700
    epsg_code = hemisphere_base + utm_zone
    return pyproj.CRS.from_epsg(epsg_code)
313 |
314 |
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for chipping orthomosaics.

    The option names map one-to-one onto the keyword arguments of `chip_orthomosaics`
    (dashes become underscores).

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(description="Chipping orthomosaic images")
    parser.add_argument(
        "--raster-path",
        type=str,
        required=True,
        help="Path to folder containing single or multiple orthomosaic images.",
    )
    parser.add_argument(
        "--vector-path",
        type=str,
        required=False,
        help="Path to folder containing single or multiple vector datafiles.",
    )
    parser.add_argument(
        "--id-column-name",
        type=str,
        default="treeID",
        help="Column name in the vector dataframe containing IDs for the tree polygons. Defaults to 'treeID'.",
    )
    parser.add_argument(
        "--res",
        type=float,
        required=False,
        help="Resolution of the dataset in units of CRS (defaults to the resolution of the first file found)",
    )
    parser.add_argument(
        "--size",
        type=float,
        required=True,
        help="Single value used for height and width dim",
    )
    parser.add_argument(
        "--stride",
        type=float,
        required=False,
        help="Distance to skip between each patch",
    )
    parser.add_argument(
        "--overlap-percent",
        type=float,
        required=False,
        # argparse %-formats help strings, so a literal percent sign must be written as %%
        help="Percentage overlap between the tiles (0-100%%)",
    )
    parser.add_argument(
        "--use-units-meters",
        action="store_true",
        help="Whether to set units for tile size and stride as meters",
    )
    parser.add_argument(
        "--save-dir", type=str, required=False, help="Directory to save chips"
    )
    parser.add_argument(
        "--visualize-n", type=int, required=False, help="Number of tiles to visualize"
    )

    return parser.parse_args()
373 |
374 |
if __name__ == "__main__":
    # Forward every parsed command-line option to chip_orthomosaics as a keyword argument
    cli_options = vars(parse_args())
    chip_orthomosaics(**cli_options)
378 |
--------------------------------------------------------------------------------
/tree_detection_framework/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-forest-observatory/tree-detection-framework/661fded8c3cd3ac7585ec9bb5d4606778bca7494/tree_detection_framework/__init__.py
--------------------------------------------------------------------------------
/tree_detection_framework/constants.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Union
3 |
4 | import geopandas as gpd
5 | import numpy.typing
6 | import shapely
7 | import torch
8 |
# Anything accepted as a filesystem path
PATH_TYPE = Union[str, Path]
# Anything accepted as a region boundary: a path, a shapely (multi)polygon or a geopandas object
BOUNDARY_TYPE = Union[
    PATH_TYPE,
    shapely.Polygon,
    shapely.MultiPolygon,
    gpd.GeoDataFrame,
    gpd.GeoSeries,
]
ARRAY_TYPE = numpy.typing.ArrayLike

# Both folders sit one level above the directory containing this file
DATA_FOLDER = (Path(__file__).parent / ".." / "data").resolve()
CHECKPOINTS_FOLDER = (Path(__file__).parent / ".." / "checkpoints").resolve()

# Use CUDA when torch reports it as available, otherwise the CPU
DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21 |
--------------------------------------------------------------------------------
/tree_detection_framework/detection/SAM2_detector.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 | from shapely.geometry import box
5 |
6 | from tree_detection_framework.constants import CHECKPOINTS_FOLDER, DEFAULT_DEVICE
7 | from tree_detection_framework.detection.detector import Detector
8 | from tree_detection_framework.utils.geometric import mask_to_shapely
9 |
try:
    from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
    from sam2.build_sam import build_sam2

    SAM2_AVAILABLE = True
except ImportError as import_error:
    # Nothing in this module works without SAM2, so fail at import time with an
    # actionable message while keeping the original error as the explicit cause.
    raise ImportError(
        "SAM2 is not installed. Please install it using the instructions in the README."
    ) from import_error
20 |
21 |
22 | # follow README for download instructions
class SAMV2Detector(Detector):
    """Detector that produces polygons using the SAM2 automatic mask generator."""

    def __init__(
        self,
        device=DEFAULT_DEVICE,
        sam2_checkpoint=Path(CHECKPOINTS_FOLDER, "sam2.1_hiera_large.pt"),
        model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml",
        postprocessors=None,
    ):
        """
        Create a SAM2 detector.
        Args:
            device (torch.device): Device to run the model on.
            sam2_checkpoint (Path): Path to the SAM2 checkpoint.
            model_cfg (str): Path to the SAM2 model config.
            postprocessors (list, optional): See docstring for Detector class. Defaults to None.
        """
        super().__init__(postprocessors=postprocessors)

        self.device = device

        self.sam2 = build_sam2(
            model_cfg, sam2_checkpoint, device=self.device, apply_postprocessing=False
        )
        # Create the mask generator with the optimal set of parameters found by
        # Michelle Chen & Jane Wu based on qualitative experiments
        self.mask_generator = SAM2AutomaticMaskGenerator(
            model=self.sam2,
            points_per_side=64,
            points_per_batch=128,
            pred_iou_thresh=0.7,
            stability_score_thresh=0.92,
            stability_score_offset=0.7,
            crop_n_layers=1,
            box_nms_thresh=0.7,
            crop_n_points_downscale_factor=2,
            min_mask_region_area=25.0,
            use_m2m=True,
        )

    def call_predict(self, batch):
        """
        Run the mask generator on every image of a batch, one image at a time.

        Args:
            batch (Tensor): 4 dims Tensor with the first dimension having number of images in the batch

        Returns:
            masks List[List[Dict]]: list of dictionaries for each mask in the batch

        Raises:
            ValueError: If an image has fewer than 3 channels.
        """

        with torch.no_grad():
            masks = []
            for original_image in batch:
                if original_image.shape[0] < 3:
                    raise ValueError("Original image has less than 3 channels")

                # Move the first axis (channels, per the check above) to the end and make sure
                # the data is on the CPU, since `.numpy()` fails for tensors on other devices
                original_image = original_image.permute(1, 2, 0).cpu()
                # Values above 1 are taken to mean a 0-255 image, which is cast to uint8.
                # Otherwise the array is handed over with its existing dtype and values.
                # NOTE(review): presumably that case is a float image already in [0, 1] - confirm
                if original_image.max() > 1:
                    original_image = original_image.byte().numpy()
                else:
                    original_image = original_image.numpy()
                # Keep only the first three channels
                rgb_image = original_image[:, :, :3]
                mask = self.mask_generator.generate(rgb_image)
                # FUTURE TODO: Support batched predictions
                masks.append(mask)

        return masks

    def predict_batch(self, batch):
        """
        Get predictions for a batch of images.

        Args:
            batch (defaultDict): A batch from the dataloader

        Returns:
            all_geometries (List[List[shapely.MultiPolygon]]):
                A list of predictions one per image in the batch. The predictions for each image
                are a list of shapely objects.
            all_data_dicts (Union[None, List[dict]]):
                One dict per image with a "score" list (the stability score of each mask) and a
                "bbox" list (the axis-aligned bounding box of each polygon).
        """
        images = batch["image"]

        # computational bottleneck
        batch_preds = self.call_predict(images)

        # To store all predicted polygons
        all_geometries = []
        # To store other related information such as scores and bounding boxes
        all_data_dicts = []

        # Iterate through predictions for each tile in the batch
        for pred in batch_preds:

            # Get each predicted mask as a float array
            segmentations = [dic["segmentation"].astype(float) for dic in pred]

            # Convert each mask to a shapely multipolygon
            shapely_objects = [
                mask_to_shapely(pred_mask) for pred_mask in segmentations
            ]

            all_geometries.append(shapely_objects)

            # Compute axis-aligned bounding boxes as Polygon objects
            bounding_boxes = [box(*polygon.bounds) for polygon in shapely_objects]

            # Get prediction scores
            scores = [dic["stability_score"] for dic in pred]
            all_data_dicts.append({"score": scores, "bbox": bounding_boxes})

        return all_geometries, all_data_dicts
138 |
--------------------------------------------------------------------------------
/tree_detection_framework/detection/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-forest-observatory/tree-detection-framework/661fded8c3cd3ac7585ec9bb5d4606778bca7494/tree_detection_framework/detection/__init__.py
--------------------------------------------------------------------------------
/tree_detection_framework/detection/models.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Any, Dict, List, Optional
3 |
4 | import lightning
5 | import torch
6 | import torchvision
7 | from deepforest import main as deepforest_main
8 | from torch import Tensor, optim
9 | from torchvision.models.detection.retinanet import (
10 | AnchorGenerator,
11 | RetinaNet,
12 | RetinaNet_ResNet50_FPN_Weights,
13 | )
14 |
15 | from tree_detection_framework.utils.detection import use_release_df
16 |
17 | try:
18 | from detectron2 import model_zoo
19 | from detectron2.config import get_cfg
20 |
21 | DETECTRON2_AVAILABLE = True
22 | except ImportError:
23 | DETECTRON2_AVAILABLE = False
24 |
25 |
class RetinaNetModel:
    """Builder for the RetinaNet model used as the DeepForest backbone."""

    def __init__(self, param_dict):
        # Configuration mapping; `create_model` reads its "num_classes" entry
        self.param_dict = param_dict

    def load_backbone(self):
        """Return a torchvision RetinaNet (ResNet50-FPN) with COCO_V1 weights."""
        return torchvision.models.detection.retinanet_resnet50_fpn(
            weights=RetinaNet_ResNet50_FPN_Weights.COCO_V1
        )

    def create_anchor_generator(
        self, sizes=((8, 16, 32, 64, 128, 256, 400),), aspect_ratios=((0.5, 1.0, 2.0),)
    ):
        """Build an anchor box generator from the given sizes and aspect ratios."""
        return AnchorGenerator(sizes=sizes, aspect_ratios=aspect_ratios)

    def create_model(self):
        """Assemble a RetinaNet around the backbone of the pretrained network.

        Returns:
            model: a pytorch nn module
        """
        pretrained_retinanet = self.load_backbone()
        n_classes = self.param_dict["num_classes"]
        # TODO: do we want to set model.nms_thresh and model.score_thresh?
        return RetinaNet(backbone=pretrained_retinanet.backbone, num_classes=n_classes)
61 |
62 |
class DeepForestModule(lightning.LightningModule):
    def __init__(
        self,
        use_hugging_face_weights: bool = True,
        param_dict: Optional[Dict[str, Any]] = None,
    ):
        """Create the DeepForest detection model, loading pretrained weights.

        Args:
            use_hugging_face_weights (bool, optional):
                Should the model and weights be downloaded from the Hugging Face repository. If
                True, setting param_dict will result in an error. Defaults to True.
            param_dict (Optional[Dict[str, Any]], optional):
                Configuration parameters to provide finer control over the model. Cannot be used
                with use_hugging_face_weights=True. None is treated as an empty dictionary.
                Defaults to None.

        Raises:
            ValueError: If both use_hugging_face_weights and param_dict are set, or if the
                requested backbone is not 'retinanet'.
        """
        super().__init__()
        # A fresh dict per instance avoids sharing one mutable default between instances
        if param_dict is None:
            param_dict = {}
        self.param_dict = param_dict

        # Determine how to obtain the weights
        if use_hugging_face_weights:
            # Error if the user tried to use both options
            if len(param_dict) > 0:
                raise ValueError(
                    "Setting the `param_dict` and `use_hugging_face_weights`=True are mutually exclusive. Please choose one option."
                )
            # Create the architecture
            model = deepforest_main.deepforest()
            # Load a pretrained tree detection model from Hugging Face
            model.load_model(model_name="weecology/deepforest-tree", revision="main")
            # The definition of the model here contains additional metrics that do not match what we
            # need. Take only the prediction model portion.
            self.model = model.model

        else:
            if param_dict["backbone"] == "retinanet":
                retinanet = RetinaNetModel(param_dict)
            else:
                raise ValueError("Only 'retinanet' backbone is currently supported.")

            self.model = retinanet.create_model()
            self.use_release()

    def use_release(self, check_release=True):
        """Use the latest DeepForest model release from github and load model.
        Optionally download if release doesn't exist.
        Args:
            check_release (logical): whether to check github for a model recent release.
                In cases where you are hitting the github API rate limit, set to False and any local model will be downloaded.
                If no model has been downloaded an error will raise.
        """
        # Download latest model from github release
        release_tag, self.release_state_dict = use_release_df(
            check_release=check_release
        )
        # Load onto the CPU first so the checkpoint can be read on machines without the
        # device it was saved from; load_state_dict then copies into the model's parameters
        self.model.load_state_dict(
            torch.load(self.release_state_dict, map_location="cpu")
        )

        # load saved model and tag release
        self.__release_version__ = release_tag
        print("Loading pre-built model: {}".format(release_tag))

    def forward(
        self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
    ) -> Dict[str, Tensor]:
        """Calls the model's forward method.
        Args:
            images (list[Tensor] or Tensor): Images to be processed, either as a list of
                per-image tensors or as a single batched tensor
            targets (list[Dict[Tensor]]): Ground-truth boxes present in the image (optional)

        Returns:
            result (list[BoxList] or dict[Tensor]):
                The output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
        """
        # Move the data to the same device as the model. A list has no `.to`, so each
        # image is moved individually in that case.
        if isinstance(images, Tensor):
            images = images.to(self.device)
        else:
            images = [image.to(self.device) for image in images]
        return self.model.forward(images, targets=targets)  # Model specific forward

    def training_step(self, batch):
        """Compute the summed training loss for one batch.

        Args:
            batch: Mapping with an "image" tensor and a "bounding_boxes" entry holding one
                list of boxes per tile.

        Returns:
            The sum of all losses returned by the model.
        """
        # Ensure model is in train mode
        self.model.train()
        device = next(self.model.parameters()).device

        # Image is expected to be a list of tensors, each of shape [C, H, W] in 0-1 range.
        image_batch = (batch["image"][:, :3, :, :] / 255.0).to(device)
        image_batch_list = [image for image in image_batch]

        # To store every image's target - a dictionary containing `boxes` and `labels`
        targets = []
        for tile in batch["bounding_boxes"]:
            # Convert from list to FloatTensor[N, 4]. The reshape keeps a tile without any
            # boxes as shape [0, 4] so the per-row filtering below still works.
            boxes_tensor = (
                torch.tensor(tile, dtype=torch.float32).reshape(-1, 4).to(device)
            )
            # Need to remove boxes that go out-of-bounds. Has negative values.
            valid_mask = (boxes_tensor >= 0).all(dim=1)
            filtered_boxes_tensor = boxes_tensor[valid_mask]
            # Create a label tensor. Single class for now.
            class_labels = torch.zeros(
                filtered_boxes_tensor.shape[0], dtype=torch.int64
            ).to(device)
            # Dictionary for the tile
            targets.append({"boxes": filtered_boxes_tensor, "labels": class_labels})

        loss_dict = self.forward(image_batch_list, targets=targets)

        final_loss = sum(loss_dict.values())
        print("loss: ", final_loss)
        return final_loss

    def configure_optimizers(self):
        """Create the SGD optimizer, reading the learning rate from param_dict["train"]["lr"]."""
        # similar to the one in deepforest
        optimizer = optim.SGD(
            self.model.parameters(), lr=self.param_dict["train"]["lr"], momentum=0.9
        )

        # TODO: Setup lr_scheduler
        # TODO: Return 'optimizer', 'lr_scheduler', 'monitor' when validation data is set

        return optimizer
185 |
186 |
class Detectree2Module:
    """Holds the detectron2 configuration used by the Detectree2 model."""

    def __init__(self, param_dict: Optional[Dict[str, Any]] = None):
        if not DETECTRON2_AVAILABLE:
            raise ImportError(
                "detectron2 is not installed. Please install it to use this module."
            )
        super().__init__()
        # Fall back to an empty dictionary when no parameters are given
        self.param_dict = param_dict if param_dict else {}
        self.cfg = self.setup_cfg(**self.param_dict)

    def setup_cfg(
        self,
        base_model: str = "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml",
        trains=("trees_train",),
        tests=("trees_val",),
        update_model=None,
        workers=2,
        ims_per_batch=2,
        gamma=0.1,
        backbone_freeze=3,
        warm_iter=120,
        momentum=0.9,
        batch_size_per_im=1024,
        base_lr=0.0003389,
        weight_decay=0.001,
        max_iter=1000,
        num_classes=1,
        eval_period=100,
        out_dir="./train_outputs",
        resize=True,
    ):
        """Build the config object, letting entries of `self.param_dict` win over arguments.

        Args:
            base_model: base pre-trained model from detectron2 model_zoo
            trains: names of registered data to use for training
            tests: names of registered data to use for evaluating models
            update_model: updated pre-trained model from detectree2 model_garden
            workers: number of workers for dataloader
            ims_per_batch: number of images per batch
            gamma: gamma for learning rate scheduler
            backbone_freeze: backbone layer to freeze
            warm_iter: number of iterations for warmup
            momentum: momentum for optimizer
            batch_size_per_im: batch size per image
            base_lr: base learning rate
            weight_decay: weight decay for optimizer
            max_iter: maximum number of iterations
            num_classes: number of classes
            eval_period: number of iterations between evaluations
            out_dir: directory to save outputs
            resize: whether to resize input images

        Returns:
            The populated config object. The output directory is created as a side effect.
        """
        # Values stored on the instance take priority over the keyword arguments
        overrides = self.param_dict

        # Start from the defaults and merge in the base model's config file
        cfg = get_cfg()
        cfg.merge_from_file(model_zoo.get_config_file(base_model))

        cfg.DATASETS.TRAIN = overrides.get("trains", trains)
        cfg.DATASETS.TEST = overrides.get("tests", tests)
        cfg.DATALOADER.NUM_WORKERS = overrides.get("workers", workers)
        cfg.SOLVER.IMS_PER_BATCH = overrides.get("ims_per_batch", ims_per_batch)
        cfg.SOLVER.GAMMA = overrides.get("gamma", gamma)
        cfg.MODEL.BACKBONE.FREEZE_AT = overrides.get("backbone_freeze", backbone_freeze)
        cfg.SOLVER.WARMUP_ITERS = overrides.get("warm_iter", warm_iter)
        cfg.SOLVER.MOMENTUM = overrides.get("momentum", momentum)
        cfg.MODEL.RPN.BATCH_SIZE_PER_IMAGE = overrides.get(
            "batch_size_per_im", batch_size_per_im
        )
        cfg.SOLVER.WEIGHT_DECAY = overrides.get("weight_decay", weight_decay)
        cfg.SOLVER.BASE_LR = overrides.get("base_lr", base_lr)
        cfg.OUTPUT_DIR = overrides.get("out_dir", out_dir)
        cfg.SOLVER.MAX_ITER = overrides.get("max_iter", max_iter)
        cfg.MODEL.ROI_HEADS.NUM_CLASSES = overrides.get("num_classes", num_classes)
        cfg.TEST.EVAL_PERIOD = overrides.get("eval_period", eval_period)
        cfg.RESIZE = overrides.get("resize", resize)
        cfg.INPUT.MIN_SIZE_TRAIN = 1000

        # Make sure the output directory exists
        Path(cfg.OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

        # Use the updated model if one was given, otherwise the base model's checkpoint
        weights = overrides.get("update_model", update_model)
        if not weights:
            weights = model_zoo.get_checkpoint_url(base_model)
        cfg.MODEL.WEIGHTS = weights

        return cfg
278 |
279 |
280 | # future TODO: add module configs for sam2, currently implemented for default configs
281 |
--------------------------------------------------------------------------------
/tree_detection_framework/detection/region_detections.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from pathlib import Path
3 | from typing import List, Optional, Union
4 |
5 | import geopandas as gpd
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import pandas as pd
9 | import pyproj
10 | import rasterio.plot
11 | import rasterio.transform
12 | import shapely
13 | from shapely.affinity import affine_transform
14 |
15 | from tree_detection_framework.constants import PATH_TYPE
16 | from tree_detection_framework.utils.raster import show_raster
17 |
18 |
def plot_detections(
    data_frame: gpd.GeoDataFrame,
    bounds: gpd.GeoSeries,
    CRS: Optional[pyproj.CRS] = None,
    plt_ax: Optional[plt.axes] = None,
    plt_show: bool = True,
    visualization_column: Optional[str] = None,
    bounds_color: Optional[Union[str, np.array, pd.Series]] = None,
    detection_kwargs: dict = {},
    bounds_kwargs: dict = {},
    raster_file: Optional[PATH_TYPE] = None,
    raster_vis_downsample: float = 10.0,
) -> plt.axes:
    """Plot the detections and the bounds of the region

    Args:
        data_frame: (gpd.GeoDataFrame):
            The data representing the detections
        bounds: (gpd.GeoSeries):
            The spatial bounds of the predicted region
        CRS (Optional[pyproj.CRS], optional):
            What CRS to use. Only passed on to the raster display. Defaults to None.
        plt_ax (Optional[plt.axes], optional):
            A pyplot axes to plot on. If not provided, one will be created. Defaults to None.
        plt_show (bool, optional):
            Whether to call plt.show() or just return the axes. Defaults to True.
        visualization_column (Optional[str], optional):
            Which column to visualize from the detections dataframe. Defaults to None.
        bounds_color (Optional[Union[str, np.array, pd.Series]], optional):
            The color to plot the bounds. Must be accepted by the gpd.plot color argument.
            Defaults to None.
        detection_kwargs (dict, optional):
            Additional keyword arguments to pass to the .plot method for the detections.
            "facecolor" defaults to "none" and "legend" to True unless given here. The
            dictionary itself is never modified. Defaults to {}.
        bounds_kwargs (dict, optional):
            Additional keyword arguments to pass to the .plot method for the bounds.
            Defaults to {}.
        raster_file (Optional[PATH_TYPE], optional):
            A path to a raster file to visualize the detections over if provided. Defaults to None.
        raster_vis_downsample (float, optional):
            The raster file is downsampled by this fraction before visualization to avoid
            excessive memory use or plotting time. Defaults to 10.0.

    Returns:
        plt.axes: The axes that were plotted on
    """

    # If no axes are provided, create new ones
    if plt_ax is None:
        _, plt_ax = plt.subplots()

    # Show the raster if provided
    if raster_file is not None:
        show_raster(
            raster_file_path=raster_file,
            downsample_factor=raster_vis_downsample,
            plt_ax=plt_ax,
            CRS=CRS,
        )

    # Build the plotting options in a new dictionary so that neither the caller's dict nor
    # the shared default argument is modified. Faces are transparent unless requested, and
    # user-supplied values override both defaults.
    detection_plot_kwargs = {"facecolor": "none", "legend": True, **detection_kwargs}

    # Plot the detections dataframe and the bounds on the same axes
    data_frame.plot(ax=plt_ax, column=visualization_column, **detection_plot_kwargs)
    # Use the .boundary attribute to plot just the border. This works since it's a geoseries,
    # not a geodataframe
    bounds.boundary.plot(ax=plt_ax, color=bounds_color, **bounds_kwargs)

    # Show if requested
    if plt_show:
        plt.show()

    # Return the axes in case they need to be used later
    return plt_ax
99 |
100 |
101 | class RegionDetections:
102 | detections: gpd.GeoDataFrame
103 | pixel_to_CRS_transform: rasterio.transform.AffineTransformer
104 | prediction_bounds_in_CRS: Union[shapely.Polygon, shapely.MultiPolygon, None]
105 |
106 | def __init__(
107 | self,
108 | detection_geometries: List[shapely.Geometry] | str,
109 | data: Union[dict, pd.DataFrame] = {},
110 | input_in_pixels: bool = True,
111 | CRS: Optional[Union[pyproj.CRS, rasterio.CRS]] = None,
112 | pixel_to_CRS_transform: Optional[rasterio.transform.AffineTransformer] = None,
113 | pixel_prediction_bounds: Optional[
114 | shapely.Polygon | shapely.MultiPolygon
115 | ] = None,
116 | geospatial_prediction_bounds: Optional[
117 | shapely.Polygon | shapely.MultiPolygon
118 | ] = None,
119 | ):
120 | """Create a region detections object
121 |
122 | Args:
123 | detection_geometries (List[shapely.Geometry] | str | None):
124 | A list of shapely geometries for each detection. Alternatively, can be a string
125 | represting a key in data providing the same, or None if that key is named "geometry".
126 | The coordinates can either be provided in pixel coordinates or in the coordinates
127 | of a CRS. input_in_pixels should be set accordingly.
128 | data (Optional[dict | pd.DataFrame], optional):
129 | A dictionary mapping from str names for an attribute to a list of values for that
130 | attribute, one value per detection. Or a pandas dataframe. Passed to the data
131 | argument of gpd.GeoDataFrame. Defaults to {}.
132 | input_in_pixels (bool, optional):
133 | Whether the detection_geometries should be interpreted in pixels or geospatial
134 | coordinates.
135 | CRS (Optional[pyproj.CRS], optional):
136 | A coordinate reference system to interpret the data in. If input_in_pixels is False,
137 | then the input data will be interpreted as values in this CRS. If input_in_pixels is
138 | True and CRS is None, then the data will be interpreted as pixel coordinates with no
139 | georeferencing information. If input_in_pixels is True and CRS is not None, then
140 | the data will be attempted to be geometrically transformed into the CRS using either
141 | pixel_to_CRS_transform if set, or the relationship between the pixel and geospatial
142 | bounds of the region. Defaults to None.
143 | pixel_to_CRS_transform (Optional[rasterio.transform.AffineTransformer], optional):
144 | An affine transformation mapping from the pixel coordinates to those of the CRS.
145 | Only meaningful if `CRS` is set as well. Defaults to None.
146 | pixel_prediction_bounds (Optional[shapely.Polygon | shapely.MultiPolygon], optional):
147 | The pixel bounds of the region that predictions were generated for. For example, a
148 | square starting at (0, 0) and extending to the size in pixels of the tile.
149 | Defaults to None.
150 | geospatial_prediction_bounds (Optional[shapely.Polygon | shapely.MultiPolygon], optional):
151 | Only meaningful if CRS is set. In that case, it represents the spatial bounds of the
152 | prediction region. If pixel_to_CRS_transform is None, and both pixel_- and
153 | geospatial_prediction_bounds are not None, then the two bounds will be used to
154 | compute the transform. Defaults to None.
155 | """
156 | # Build a geopandas dataframe containing the geometries, additional attributes, and CRS
157 | self.detections = gpd.GeoDataFrame(
158 | data=data, geometry=detection_geometries, crs=CRS
159 | )
160 |
161 | # If the pixel_to_CRS_transform is None but can be computed from the two bounds, do that
162 | if (
163 | input_in_pixels
164 | and (pixel_to_CRS_transform is None)
165 | and (pixel_prediction_bounds is not None)
166 | and (geospatial_prediction_bounds is not None)
167 | ):
168 | # We assume that these two array are exactly corresponding, representing the same shape
169 | # in the two coordinate frames and also the same starting vertex.
170 | # Drop the last entry because it is a duplicate of the first one
171 | geospatial_corners_array = shapely.get_coordinates(
172 | geospatial_prediction_bounds
173 | )[:-1]
174 | pixel_corners_array = shapely.get_coordinates(pixel_prediction_bounds)[:-1]
175 |
176 | # If they don't have the same number of vertices, this can't be the case
177 | if len(geospatial_corners_array) != len(pixel_corners_array):
178 | raise ValueError("Bounds had different lengths")
179 |
180 | # Representing the correspondences as ground control points
181 | ground_control_points = [
182 | rasterio.control.GroundControlPoint(
183 | col=pixel_vertex[0],
184 | row=pixel_vertex[1],
185 | x=geospatial_vertex[0],
186 | y=geospatial_vertex[1],
187 | )
188 | for pixel_vertex, geospatial_vertex in zip(
189 | pixel_corners_array, geospatial_corners_array
190 | )
191 | ]
192 | # Solve the affine transform that best transforms from the pixel to geospatial coordinates
193 | pixel_to_CRS_transform = rasterio.transform.from_gcps(ground_control_points)
194 |
195 | # Error checking
196 | if (pixel_to_CRS_transform is None) and (CRS is not None) and input_in_pixels:
197 | raise ValueError(
198 | "The input was in pixels and a CRS was specified but no geommetric transformation was provided to transform the pixel values to that CRS"
199 | )
200 |
201 | # Set the transform
202 | self.pixel_to_CRS_transform = pixel_to_CRS_transform
203 |
204 | # If the inputs are provided in pixels, apply the transform to the predictions
205 | if input_in_pixels:
206 | # Get the transform in the format expected by shapely
207 | shapely_transform = pixel_to_CRS_transform.to_shapely()
208 | # Apply this transformation to the geometry of the dataframe
209 | self.detections.geometry = self.detections.geometry.affine_transform(
210 | matrix=shapely_transform
211 | )
212 |
213 | # Handle the bounds
214 | # If the bounds are provided as geospatial coordinates, use those directly
215 | if geospatial_prediction_bounds is not None:
216 | prediction_bounds_in_CRS = geospatial_prediction_bounds
217 | # If the bounds are provided in pixels and a transform to geospatial is provided, use that
218 | # this assumes that a CRS is set based on previous checks
219 | elif (pixel_to_CRS_transform is not None) and (
220 | pixel_prediction_bounds is not None
221 | ):
222 | prediction_bounds_in_CRS = affine_transform(
223 | geom=pixel_prediction_bounds, matrix=pixel_to_CRS_transform.to_shapely()
224 | )
225 | # If there is no CRS and pixel bounds are provided, use these directly
226 | # The None CRS implies pixels, so this still has the intended meaning
227 | elif CRS is None and pixel_prediction_bounds is not None:
228 | prediction_bounds_in_CRS = pixel_prediction_bounds
229 | # Else set the bounds to None (unknown)
230 | else:
231 | prediction_bounds_in_CRS = None
232 |
233 | # Create a one-length geoseries for the bounds
234 | self.prediction_bounds_in_CRS = gpd.GeoSeries(
235 | data=[prediction_bounds_in_CRS], crs=CRS
236 | )
237 |
def subset_detections(self, detection_indices) -> "RegionDetections":
    """Build a new RegionDetections holding only the selected detections

    Args:
        detection_indices:
            Positional selector for the rows to keep. It is handed directly to `.iloc` on the
            detections dataframe, so anything `.iloc` accepts for rows is valid.

    Returns:
        RegionDetections: A deep copy of this object restricted to the selected rows
    """
    # Work on an independent deep copy so this object is left untouched
    selected = copy.deepcopy(self)
    # Keep only the requested rows (and every column) of the copied dataframe
    kept_rows = selected.detections.iloc[detection_indices, :]
    selected.detections = kept_rows

    return selected
254 |
def save(self, save_path: PATH_TYPE):
    """Write the detections to a geospatial file on disk

    Args:
        save_path (PATH_TYPE):
            Destination geofile. Missing parent folders are created first.
    """
    # Normalize to a Path so the parent folder can be created when absent
    output_file = Path(save_path)
    output_file.parent.mkdir(exist_ok=True, parents=True)

    # Only the detections dataframe is written; the prediction bounds and the
    # pixel-to-CRS transform are not part of the saved file.
    self.detections.to_file(output_file)
270 |
def get_data_frame(
    self, CRS: Optional[pyproj.CRS] = None, as_pixels: Optional[bool] = False
) -> gpd.GeoDataFrame:
    """Get the detections, optionally specifying a CRS or pixel coordinates

    Args:
        CRS (Optional[pyproj.CRS], optional):
            Requested CRS for the output detections. If un-set, the CRS of self.detections will
            be used. Ignored when `as_pixels` is True. Defaults to None.
        as_pixels (Optional[bool], optional):
            Whether to return the values in pixel coordinates. Defaults to False.

    Raises:
        ValueError:
            If pixel coordinates are requested for data that has a CRS but no stored
            pixel-to-CRS transform.

    Returns:
        gpd.GeoDataFrame: Detections in the requested CRS or in pixel coordinates with a None .crs
    """
    # If the data is requested in pixel coordinates, transform it appropriately
    if as_pixels:
        if self.pixel_to_CRS_transform is None:
            if self.detections.crs is not None:
                raise ValueError(
                    "Pixel coordinates were requested but data is in geospatial units with no transformation to pixels"
                )
            # No CRS and no transform: the stored coordinates are not georeferenced, so they
            # are returned unchanged instead of trying to invert a missing transform.
            return self.detections.copy()

        # Compute the inverse transform using ~ to map from the CRS to pixels instead of the
        # other way around. Also get this transform in the shapely convention.
        CRS_to_pixel_transform_shapely = (~self.pixel_to_CRS_transform).to_shapely()
        # Start from a copy so the stored dataframe is not modified
        pixel_coordinate_detections = self.detections.copy()
        # Since the units are pixels, it no longer has a CRS, so set it to None
        pixel_coordinate_detections.crs = None
        # Transform the geometry to pixel coordinates
        pixel_coordinate_detections.geometry = (
            pixel_coordinate_detections.geometry.affine_transform(
                CRS_to_pixel_transform_shapely
            )
        )
        return pixel_coordinate_detections

    # Geospatial output: with no requested CRS, return a copy in the current CRS
    if CRS is None:
        return self.detections.copy()

    # Transform the data to the requested CRS. Note that if the detections have no CRS,
    # this will error out
    detections_in_new_CRS = self.detections.copy().to_crs(CRS)
    return detections_in_new_CRS
321 |
def convert_to_bboxes(self) -> "RegionDetections":
    """
    Return a copy of the RD with the geometry of all detections replaced by the minimum
    axis-aligned bounding rectangle. Also, rows with an empty geometry are removed.

    Returns:
        RegionDetections: An identical RD with all non-empty geometries represented as bounding boxes.
    """
    # Get the detections (get_data_frame returns a copy, so self.detections is not touched)
    detections_df = self.get_data_frame()
    # Keep only non-empty rows, since conversion to a bounding box only works for non-empty
    # polygons. The explicit copy makes the result an independent dataframe, so assigning the
    # geometry below is not a chained assignment on a filtered slice.
    nonempty_rows = detections_df[~detections_df.geometry.is_empty].copy()
    # Get the bounds and convert to shapely boxes
    bounds = nonempty_rows.bounds
    boxes = shapely.box(
        xmin=bounds.minx, ymin=bounds.miny, xmax=bounds.maxx, ymax=bounds.maxy
    )
    # Replace the geometry with the bounding boxes
    nonempty_rows.geometry = boxes
    # Create a new RegionDetections object and update the detections on it
    bbox_rd = copy.deepcopy(self)
    bbox_rd.detections = nonempty_rows
    return bbox_rd
347 |
def update_geometry_column(self, geometry_column: str) -> "RegionDetections":
    """Return a copy of this RD whose geometry is taken from another column

    Args:
        geometry_column (str): Name of a dataframe column that holds shapely data

    Returns:
        RegionDetections: A new RD whose geometry is the data stored in `geometry_column`
    """
    # Build the replacement dataframe from a copy of the current detections
    regeometrized = self.get_data_frame().copy()
    regeometrized.geometry = regeometrized[geometry_column]

    # Duplicate this object and attach the replacement dataframe to the duplicate
    result = copy.deepcopy(self)
    result.detections = regeometrized

    return result
368 |
def get_bounds(
    self, CRS: Optional[pyproj.CRS] = None, as_pixels: Optional[bool] = False
) -> gpd.GeoSeries:
    """Get the bounds of the prediction region, optionally in another CRS or in pixels

    Args:
        CRS (Optional[pyproj.CRS], optional):
            Requested CRS for the bounds. If un-set, the CRS the bounds are stored in is used.
            Ignored when `as_pixels` is True, because the stored transform maps between pixels
            and the stored CRS only. Defaults to None.
        as_pixels (Optional[bool], optional):
            Whether to return the bounds in pixel coordinates. Defaults to False.

    Raises:
        ValueError:
            If pixel coordinates are requested for data that has a CRS but no stored
            pixel-to-CRS transform.

    Returns:
        gpd.GeoSeries: A copy of the bounds in the requested CRS, or in pixel coordinates with a
        None .crs
    """
    if as_pixels:
        if self.pixel_to_CRS_transform is None:
            if (self.prediction_bounds_in_CRS.crs is not None) or (
                self.detections.crs is not None
            ):
                raise ValueError(
                    "Pixel coordinates were requested but data is in geospatial units with no transformation to pixels"
                )
            # No CRS and no transform: the stored bounds are not georeferenced, return as-is
            return self.prediction_bounds_in_CRS.copy()

        # Invert the transform so it maps from the stored CRS to pixels
        CRS_to_pixel_transform_shapely = (~self.pixel_to_CRS_transform).to_shapely()
        # Transform the bounds, in their stored CRS, into pixel coordinates. affine_transform
        # returns a new series, so the stored bounds are not modified.
        pixel_bounds = self.prediction_bounds_in_CRS.affine_transform(
            CRS_to_pixel_transform_shapely
        )
        # Rebuild the series from the bare geometries so the result carries no CRS, since the
        # values are now pixels. This relabels the data and does not reproject.
        return gpd.GeoSeries(list(pixel_bounds), index=pixel_bounds.index, crs=None)

    if CRS is None:
        # Get bounds in original CRS
        return self.prediction_bounds_in_CRS.copy()

    # Get bounds in requested CRS
    return self.prediction_bounds_in_CRS.to_crs(CRS)
391 |
def get_CRS(self) -> Union[pyproj.CRS, None]:
    """Look up the coordinate reference system of the detections dataframe

    Returns:
        Union[pyproj.CRS, None]: The `.crs` attribute of self.detections
    """
    detections_crs = self.detections.crs
    return detections_crs
399 |
def plot(
    self,
    CRS: Optional[pyproj.CRS] = None,
    as_pixels: bool = False,
    plt_ax: Optional[plt.axes] = None,
    plt_show: bool = True,
    visualization_column: Optional[str] = None,
    bounds_color: Optional[Union[str, np.array, pd.Series]] = None,
    detection_kwargs: dict = {},
    bounds_kwargs: dict = {},
    raster_file: Optional[PATH_TYPE] = None,
    raster_vis_downsample: float = 10.0,
) -> plt.axes:
    """Plot the detections and the bounds of the region

    Args:
        CRS (Optional[pyproj.CRS], optional):
            What CRS to use. Defaults to None.
        as_pixels (bool, optional):
            Whether to display in pixel coordinates. Defaults to False.
        plt_ax (Optional[plt.axes], optional):
            A pyplot axes to plot on. If not provided, one will be created. Defaults to None.
        plt_show (bool, optional):
            Whether to plot the result or just return it. Defaults to True.
        visualization_column (Optional[str], optional):
            Which column to visualize from the detections dataframe. Defaults to None.
        bounds_color (Optional[Union[str, np.array, pd.Series]], optional):
            The color to plot the bounds. Must be accepted by the gpd.plot color argument.
            Defaults to None.
        detection_kwargs (dict, optional):
            Additional keyword arguments to pass to the .plot method for the detections.
            Defaults to {}.
        bounds_kwargs (dict, optional):
            Additional keyword arguments to pass to the .plot method for the bounds.
            Defaults to {}.
        raster_file (Optional[PATH_TYPE], optional):
            A path to a raster file to visualize the detections over if provided. Defaults to None.
        raster_vis_downsample (float, optional):
            The raster file is downsampled by this fraction before visualization to avoid
            excessive memory use or plotting time. Defaults to 10.0.

    Returns:
        plt.axes: The axes that were plotted on, as returned by plot_detections
    """

    # Get the dataframe and the bounds in the same coordinate frame
    data_frame = self.get_data_frame(CRS=CRS, as_pixels=as_pixels)
    bounds = self.get_bounds(CRS, as_pixels=as_pixels)

    # Perform plotting and hand the result back to the caller. The original dropped this
    # value and implicitly returned None despite the documented return.
    return plot_detections(
        data_frame=data_frame,
        bounds=bounds,
        CRS=data_frame.crs,
        plt_ax=plt_ax,
        plt_show=plt_show,
        visualization_column=visualization_column,
        detection_kwargs=detection_kwargs,
        bounds_kwargs=bounds_kwargs,
        raster_file=raster_file,
        raster_vis_downsample=raster_vis_downsample,
        bounds_color=bounds_color,
    )
463 |
464 |
class RegionDetectionsSet:
    # The per-region detection objects managed by this set
    region_detections: List[RegionDetections]

    def __init__(self, region_detections: List[RegionDetections]):
        """Create a set of detections to conveniently perform operations on all of them, for example
        merging all regions into a single dataframe with an additional column indicating which
        region each detection belongs to.

        Args:
            region_detections (List[RegionDetections]): A list of individual detections
        """
        self.region_detections = region_detections

    def all_regions_have_CRS(self) -> bool:
        """Check whether all sub-regions have a non-None CRS

        Returns:
            bool: Whether all sub-regions have a valid CRS.
        """
        # Only valid if no region's CRS is None
        return all(rd.detections.crs is not None for rd in self.region_detections)

    def get_default_CRS(self, check_all_have_CRS=True) -> pyproj.CRS:
        """Find the CRS of the first sub-region to use as a default

        Args:
            check_all_have_CRS (bool, optional):
                Should an error be raised if not all regions have a CRS set.

        Raises:
            ValueError: If `check_all_have_CRS` is True and some region has no CRS.

        Returns:
            pyproj.CRS: The CRS given by the first sub-region.
        """
        if check_all_have_CRS and not self.all_regions_have_CRS():
            raise ValueError(
                "Not all regions have a CRS set and a default one was requested"
            )
        # The default is the CRS of the first region
        # TODO in the future it could be something else like the most common
        return self.region_detections[0].detections.crs

    def get_region_detections(self, index: int) -> RegionDetections:
        """Get a single region detections object by index

        Args:
            index (int): Which one to select

        Returns:
            RegionDetections: The RegionDetections object from the list of objects in the set
        """
        return self.region_detections[index]

    def merge(
        self,
        region_ID_key: Optional[str] = "region_ID",
        CRS: Optional[pyproj.CRS] = None,
    ) -> RegionDetections:
        """Get the merged detections across all regions with an additional field specifying which region
        the detection came from.

        Args:
            region_ID_key (Optional[str], optional):
                Create this column in the output dataframe identifying which region that data came
                from using a zero-indexed integer. If None, no such column is added. Defaults to
                "region_ID".
            CRS (Optional[pyproj.CRS], optional):
                Requested CRS for merged detections. If un-set, no reprojection is requested and
                each region's detections are concatenated in the CRS they are stored in.
                Defaults to None.

        Returns:
            RegionDetections:
                A single RegionDetections object containing the detections from all regions and
                the union of their bounds.
        """

        # Get the detections from each region detection object as geodataframes
        detection_geodataframes = [
            rd.get_data_frame(CRS=CRS) for rd in self.region_detections
        ]

        # Add a column to each geodataframe identifying which region detection object it came from.
        # The dataframes are copies returned by get_data_frame, so the regions are not modified.
        # Skipped when no key is given, to avoid creating a column literally named None.
        # TODO consider a more sophisticated ID
        if region_ID_key is not None:
            for ID, gdf in enumerate(detection_geodataframes):
                gdf[region_ID_key] = ID

        # Concatenate the geodataframes together
        concatenated_geodataframes = pd.concat(detection_geodataframes)

        ## Merge_bounds
        # Get the merged bounds
        merged_bounds = self.get_bounds(CRS=CRS)
        # Convert to a single shapely object (the union is a one-length series)
        merged_bounds_shapely = merged_bounds.geometry.iloc[0]

        # When no CRS was requested, label the merged object with the CRS the concatenated
        # detections actually carry. Otherwise its bounds would be stored without a CRS even
        # though its detections have one.
        merged_CRS = CRS
        if merged_CRS is None:
            merged_CRS = getattr(concatenated_geodataframes, "crs", None)

        # Use the geometry column of the concatenated dataframes
        merged_region_detections = RegionDetections(
            detection_geometries="geometry",
            data=concatenated_geodataframes,
            CRS=merged_CRS,
            input_in_pixels=False,
            geospatial_prediction_bounds=merged_bounds_shapely,
        )

        return merged_region_detections

    def get_data_frame(
        self,
        CRS: Optional[pyproj.CRS] = None,
        merge: bool = False,
        region_ID_key: str = "region_ID",
    ) -> gpd.GeoDataFrame | List[gpd.GeoDataFrame]:
        """Get the detections, optionally specifying a CRS

        Args:
            CRS (Optional[pyproj.CRS], optional):
                Requested CRS for the output detections. If un-set, each region's detections are
                returned in the CRS they are stored in. Defaults to None.
            merge (bool, optional):
                If true, return one dataframe. Else, return a list of individual dataframes.
            region_ID_key (str, optional):
                Use this column to identify which region each detection came from. Only used
                when merge=True. Defaults to "region_ID"

        Returns:
            gpd.GeoDataFrame | List[gpd.GeoDataFrame]:
                If merge=True, then one dataframe with an additional column specifying which region
                each detection came from. If merge=False, then a list of dataframes for each region.
        """
        if merge:
            # Merge all of the detections into one RegionDetection
            merged_detections = self.merge(region_ID_key=region_ID_key, CRS=CRS)
            # Get the dataframe. It is already in the requested CRS in the current implementation.
            return merged_detections.get_data_frame()

        # Get a list of dataframes from each region
        return [rd.get_data_frame(CRS=CRS) for rd in self.region_detections]

    def convert_to_bboxes(self) -> "RegionDetectionsSet":
        """Convert all the RegionDetections to bounding box representations

        Returns:
            RegionDetectionsSet:
                A new RDS where each RD has all empty geometries dropped and all remaining ones
                represented by an axis-aligned rectangle.
        """
        # Convert each detection
        converted_detections = [rd.convert_to_bboxes() for rd in self.region_detections]
        # Return the new RDS
        return RegionDetectionsSet(converted_detections)

    def update_geometry_column(self, geometry_column: str) -> "RegionDetectionsSet":
        """
        Update the geometry to another column in the dataframe that contains shapely data for each RD

        Args:
            geometry_column (str): The name of a column containing shapely data

        Returns:
            RegionDetectionsSet: An updated RDS with the geometry specified by the data in `geometry_column`
        """
        # Convert each detection
        converted_detections = [
            rd.update_geometry_column(geometry_column=geometry_column)
            for rd in self.region_detections
        ]
        # Return the new RDS
        return RegionDetectionsSet(converted_detections)

    def get_bounds(
        self, CRS: Optional[pyproj.CRS] = None, union_bounds: bool = True
    ) -> gpd.GeoSeries:
        """Get the bounds corresponding to the sub-regions.

        Args:
            CRS (Optional[pyproj.CRS], optional):
                The CRS to return the bounds in. If not set, each region's bounds are taken in
                the CRS they are stored in. Defaults to None.
            union_bounds (bool, optional):
                Whether to return the spatial union of all bounds or a series of per-region bounds.
                Defaults to True.

        Returns:
            gpd.GeoSeries: Either a one-length series of merged bounds if union_bounds=True or a
            series of bounds per region.
        """

        region_bounds = [rd.get_bounds(CRS=CRS) for rd in self.region_detections]
        # Create a geoseries out of these region bounds
        all_region_bounds = gpd.GeoSeries(pd.concat(region_bounds), crs=CRS)

        # If the union is not requested, return the individual bounds
        if not union_bounds:
            return all_region_bounds

        # Compute the union of all bounds. Take the CRS from the per-region series rather than
        # the CRS argument, so the union keeps its CRS even when no CRS was requested.
        merged_bounds = gpd.GeoSeries(
            [all_region_bounds.geometry.union_all()], crs=all_region_bounds.crs
        )

        return merged_bounds

    def disjoint_bounds(self) -> bool:
        """Determine whether the bounds of the sub-regions are disjoint

        Returns:
            bool: Are they disjoint
        """
        # Get the bounds for each individual region
        bounds = self.get_bounds(union_bounds=False)
        # Get the union of all bounds
        union_bounds = bounds.union_all()

        # Find the sum of areas for each region
        sum_individual_areas = bounds.area.sum()
        # And the area of the union
        union_area = union_bounds.area

        # If the two areas are the same (down to numeric errors) then there are no overlaps
        disjoint = np.allclose(sum_individual_areas, union_area)

        return disjoint

    def save(
        self,
        save_path: PATH_TYPE,
        CRS: Optional[pyproj.CRS] = None,
        region_ID_key: Optional[str] = "region_ID",
    ):
        """
        Save the data to a geospatial file by calling get_data_frame with merge=True and then saving
        to the specified file. The containing folder is created if it doesn't exist.

        Args:
            save_path (PATH_TYPE):
                File to save the data to. The containing folder will be created if it does not exist.
            CRS (Optional[pyproj.CRS], optional):
                See get_data_frame.
            region_ID_key (Optional[str], optional):
                See get_data_frame.
        """
        # Get the concatenated dataframes
        concatenated_geodataframes = self.get_data_frame(
            CRS=CRS, region_ID_key=region_ID_key, merge=True
        )

        # Ensure that the folder to save them to exists
        save_path = Path(save_path)
        save_path.parent.mkdir(exist_ok=True, parents=True)

        # Save the data to the geofile
        concatenated_geodataframes.to_file(save_path)

    def plot(
        self,
        CRS: Optional[pyproj.CRS] = None,
        plt_ax: Optional[plt.axes] = None,
        plt_show: bool = True,
        visualization_column: Optional[str] = None,
        bounds_color: Optional[Union[str, np.array, pd.Series]] = None,
        detection_kwargs: dict = {},
        bounds_kwargs: dict = {},
        raster_file: Optional[PATH_TYPE] = None,
        raster_vis_downsample: float = 10.0,
    ) -> plt.axes:
        """Plot the merged detections from all regions along with the per-region bounds

        Args:
            CRS (Optional[pyproj.CRS], optional):
                The CRS to use for plotting all regions. If unset, no reprojection is requested.
                Defaults to None.
            plt_ax (Optional[plt.axes], optional):
                The axes to plot on. Will be created if not provided. Defaults to None.
            plt_show (bool, optional):
                See RegionDetections.plot. Defaults to True.
            visualization_column (Optional[str], optional):
                See RegionDetections.plot. Defaults to None.
            bounds_color (Optional[Union[str, np.array, pd.Series]], optional):
                See RegionDetections.plot. Defaults to None.
            detection_kwargs (dict, optional):
                See RegionDetections.plot. Defaults to {}.
            bounds_kwargs (dict, optional):
                See RegionDetections.plot. Defaults to {}.
            raster_file (Optional[PATH_TYPE], optional):
                See RegionDetections.plot. Defaults to None.
            raster_vis_downsample (float, optional):
                See RegionDetections.plot. Defaults to 10.0.

        Returns:
            plt.axes: The axes that have been plotted on.
        """
        # Extract the bounds for each of the sub-regions and the merged detections
        bounds = self.get_bounds(CRS=CRS, union_bounds=False)
        data_frame = self.get_data_frame(CRS=CRS, merge=True)

        # Perform plotting and return the axes
        return plot_detections(
            data_frame=data_frame,
            bounds=bounds,
            CRS=data_frame.crs,
            plt_ax=plt_ax,
            plt_show=plt_show,
            visualization_column=visualization_column,
            bounds_color=bounds_color,
            detection_kwargs=detection_kwargs,
            bounds_kwargs=bounds_kwargs,
            raster_file=raster_file,
            raster_vis_downsample=raster_vis_downsample,
        )
782 |
--------------------------------------------------------------------------------
/tree_detection_framework/entrypoints/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-forest-observatory/tree-detection-framework/661fded8c3cd3ac7585ec9bb5d4606778bca7494/tree_detection_framework/entrypoints/__init__.py
--------------------------------------------------------------------------------
/tree_detection_framework/entrypoints/generate_predictions.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | from typing import Optional
4 |
5 | import pyproj
6 | import torch
7 |
8 | from tree_detection_framework.constants import BOUNDARY_TYPE, PATH_TYPE
9 | from tree_detection_framework.detection.detector import (
10 | DeepForestDetector,
11 | Detectree2Detector,
12 | )
13 | from tree_detection_framework.detection.models import DeepForestModule, Detectree2Module
14 | from tree_detection_framework.postprocessing.postprocessing import multi_region_NMS
15 | from tree_detection_framework.preprocessing.preprocessing import create_dataloader
16 |
17 |
def generate_predictions(
    raster_folder_path: PATH_TYPE,
    chip_size: float,
    tree_detection_model: str,
    chip_stride: Optional[float] = None,
    chip_overlap_percentage: Optional[float] = None,
    use_units_meters: bool = False,
    region_of_interest: Optional[BOUNDARY_TYPE] = None,
    output_resolution: Optional[float] = None,
    output_CRS: Optional[pyproj.CRS] = None,
    predictions_save_path: Optional[PATH_TYPE] = None,
    view_predictions_plot: bool = False,
    run_nms: bool = True,
    iou_threshold: Optional[float] = 0.3,
    min_confidence: Optional[float] = 0.3,
    batch_size: int = 1,
    detectree2_weights_path: Optional[PATH_TYPE] = None,
):
    """
    Entrypoint script to generate tree detections for a raster dataset input. Supports visualizing and saving predictions.

    Args:
        raster_folder_path (PATH_TYPE): Path to the folder or raster files.
        chip_size (float):
            Dimension of the chip. May be pixels or meters, based on `use_units_meters`.
        tree_detection_model (str):
            Selected model for detecting trees. Either "deepforest" or "detectree2".
        chip_stride (Optional[float], optional):
            Stride of the chip. May be pixels or meters, based on `use_units_meters`. If used,
            `chip_overlap_percentage` should not be set. Defaults to None.
        chip_overlap_percentage (Optional[float], optional):
            Percent overlap of the chip from 0-100. If used, `chip_stride` should not be set.
            Defaults to None.
        use_units_meters (bool, optional):
            Use units of meters rather than pixels when interpreting the `chip_size` and `chip_stride`.
            Defaults to False.
        region_of_interest (Optional[BOUNDARY_TYPE], optional):
            Only data from this spatial region will be included in the dataloader. Defaults to None.
        output_resolution (Optional[float], optional):
            Spatial resolution the data in meters/pixel. If un-set, will be the resolution of the
            first raster data that is read. Defaults to None.
        output_CRS: (Optional[pyproj.CRS], optional):
            The coordinate reference system to use for the output data. If un-set, will be the CRS
            of the first tile found. Defaults to None.
        predictions_save_path (Optional[PATH_TYPE], optional):
            Path to a geofile to save the prediction outputs. Nothing is saved if un-set.
            Defaults to None.
        view_predictions_plot (bool, optional):
            Set to True if visualization of the detected regions is needed. Defaults to False.
        run_nms: (bool, optional):
            Set to True if non-max suppression needs to be run on predictions from multiple regions.
            Defaults to True.
        iou_threshold (float, optional):
            What intersection over union value to consider an overlapping detection. Defaults to 0.3.
        min_confidence (float, optional):
            Prediction score threshold for detections to be included. Defaults to 0.3.
        batch_size (int, optional):
            Number of images to load in a batch. Defaults to 1.
        detectree2_weights_path (Optional[PATH_TYPE], optional):
            Path to the pretrained detectree2 weights. Only used when `tree_detection_model` is
            "detectree2". If un-set, the previously hard-coded local path is used.
            Defaults to None.

    Raises:
        ValueError: If `tree_detection_model` is not one of the supported models.
    """

    # Validate the model name up front so a typo fails before the dataloader is built
    if tree_detection_model not in ("deepforest", "detectree2"):
        raise ValueError(
            "Please enter a valid tree detection model. Currently supported models are:\n"
            "1. deepforest\n"
            "2. detectree2"
        )

    # Create the dataloader by passing folder path to raster data.
    dataloader = create_dataloader(
        raster_folder_path=raster_folder_path,
        chip_size=chip_size,
        chip_stride=chip_stride,
        chip_overlap_percentage=chip_overlap_percentage,
        use_units_meters=use_units_meters,
        region_of_interest=region_of_interest,
        output_resolution=output_resolution,
        output_CRS=output_CRS,
        batch_size=batch_size,
    )

    # Setup the specified tree detection model
    if tree_detection_model == "deepforest":

        # Setup the parameters dictionary
        param_dict = {
            "backbone": "retinanet",
            "num_classes": 1,
        }

        df_module = DeepForestModule(param_dict)
        # Move the module to the GPU if available
        df_module.to(
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        lightning_detector = DeepForestDetector(df_module)

    else:
        # Only "detectree2" can reach this branch because of the validation above.
        # Load detectree2 pretrained weights
        # TODO: download pretrained weights when called, instead of providing a local path
        if detectree2_weights_path is None:
            detectree2_weights_path = (
                "/ofo-share/repos-amritha/detectree2-code/230103_randresize_full.pth"
            )
        param_dict = {"update_model": detectree2_weights_path}

        dtree2_module = Detectree2Module(param_dict)
        lightning_detector = Detectree2Detector(dtree2_module)

    # Get predictions by invoking the tree_detection_model
    logging.info("Getting tree detections")
    outputs = lightning_detector.predict(dataloader)

    if run_nms is True:
        logging.info("Running non-max suppression")
        # Run non-max suppression on the detected regions
        # NOTE(review): the keyword is spelled `iou_theshold` exactly as in the original call;
        # confirm it matches the multi_region_NMS signature before renaming it.
        outputs = multi_region_NMS(
            outputs, iou_theshold=iou_threshold, min_confidence=min_confidence
        )

    if predictions_save_path:
        # Save predictions to disk
        outputs.save(predictions_save_path)

    if view_predictions_plot is True:
        logging.info("View plot. Kill the plot window to exit.")
        # Plot the detections and the bounds of the region
        outputs.plot()
142 |
143 |
def parse_args() -> argparse.Namespace:
    """Parse the command line arguments for the prediction entrypoint

    The option names map one-to-one onto the keyword arguments of `generate_predictions`,
    whose docstring is embedded in the help text.

    Raises:
        SystemExit: Re-raised from argparse, both for `--help` and for invalid arguments.

    Returns:
        argparse.Namespace: The parsed arguments
    """
    description = (
        "This script generates tree detections for a given raster image. First, it creates a dataloader "
        + "with the tiled raster dataset and provides the images as input to the selected tree detection model. "
        + "All of the arguments are passed to "
        + "tree_detection_framework.entrypoints.generate_predictions "
        + "which has the following documentation:\n\n"
        # __doc__ is None when docstrings are stripped (python -OO), so fall back to ""
        + (generate_predictions.__doc__ or "")
    )
    parser = argparse.ArgumentParser(
        description=description, formatter_class=argparse.RawDescriptionHelpFormatter
    )

    parser.add_argument("--raster-folder-path", required=True)
    parser.add_argument("--chip-size", type=float, required=True)
    parser.add_argument("--tree-detection-model", type=str, required=True)
    parser.add_argument("--chip-stride", type=float)
    parser.add_argument("--chip-overlap-percentage", type=float)
    parser.add_argument("--use-units-meters", action="store_true")
    parser.add_argument("--region-of-interest")
    parser.add_argument("--output-resolution", type=float)
    parser.add_argument("--output-CRS")
    parser.add_argument("--predictions-save-path")
    parser.add_argument("--view-predictions-plot", action="store_true")
    parser.add_argument("--run-nms", action="store_true")
    parser.add_argument("--iou-threshold", type=float, default=0.3)
    parser.add_argument("--min-confidence", type=float, default=0.3)
    parser.add_argument("--batch-size", type=int, default=1)

    try:
        args = parser.parse_args()

    except SystemExit as e:
        # argparse also exits, with code 0, after printing --help. Only report an error and
        # re-print the help for genuine failures, so a help request is not shown as an error.
        if e.code not in (0, None):
            print("\nError: Invalid or missing arguments.")
            parser.print_help()
        raise

    return args
182 |
183 |
if __name__ == "__main__":
    # Forward every parsed command-line option to the entrypoint as a keyword argument
    cli_args = parse_args()
    generate_predictions(**vars(cli_args))
187 |
--------------------------------------------------------------------------------
/tree_detection_framework/entrypoints/tile_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Optional
3 |
4 | import pyproj
5 |
6 | from tree_detection_framework.constants import ARRAY_TYPE, BOUNDARY_TYPE, PATH_TYPE
7 | from tree_detection_framework.preprocessing.preprocessing import (
8 | create_dataloader,
9 | save_dataloader_contents,
10 | visualize_dataloader,
11 | )
12 |
13 |
def tile_data(
    raster_folder_path: PATH_TYPE,
    chip_size: float,
    chip_stride: Optional[float] = None,
    chip_overlap_percentage: float = None,
    use_units_meters: bool = False,
    region_of_interest: Optional[BOUNDARY_TYPE] = None,
    output_resolution: Optional[float] = None,
    output_CRS: Optional[pyproj.CRS] = None,
    vector_label_folder_path: Optional[PATH_TYPE] = None,
    vector_label_attribute: Optional[str] = None,
    visualize_n_tiles: Optional[int] = None,
    save_folder: Optional[PATH_TYPE] = None,
    save_n_tiles: Optional[int] = None,
    random_sample: bool = False,
    batch_size: int = 1,
):
    """
    Entrypoint for exercising the preprocessing functions.
    Builds a dataloader and, on request, displays sample tiles from it and/or writes its
    contents to disk.

    Args:
        raster_folder_path (PATH_TYPE): Path to the folder or raster files.
        chip_size (float):
            Dimension of the chip, in pixels or meters depending on `use_units_meters`.
        chip_stride (Optional[float], optional):
            Stride of the chip, in pixels or meters depending on `use_units_meters`. Should not
            be combined with `chip_overlap_percentage`. Defaults to None.
        chip_overlap_percentage (Optional[float], optional):
            Percent overlap of the chip from 0-100. Should not be combined with `chip_stride`.
            Defaults to None.
        use_units_meters (bool, optional):
            Interpret `chip_size` and `chip_stride` in meters rather than pixels.
            Defaults to False.
        region_of_interest (Optional[BOUNDARY_TYPE], optional):
            Only data from this spatial region is included in the dataloader. Defaults to None.
        output_resolution (Optional[float], optional):
            Spatial resolution of the data in meters/pixel. If un-set, the resolution of the
            first raster data that is read is used. Defaults to None.
        output_CRS: (Optional[pyproj.CRS], optional):
            Coordinate reference system of the output data. If un-set, the CRS of the first
            tile found is used. Defaults to None.
        vector_label_folder_path (Optional[PATH_TYPE], optional):
            Folder of geospatial vector files used for the label. If un-set, the dataloader
            is not labeled. Defaults to None.
        vector_label_attribute (Optional[str], optional):
            Attribute to read from the vector data, such as the class or instance ID.
            Defaults to None.
        visualize_n_tiles (Optional[int], optional):
            Number of randomly-sampled tiles to display. Nothing is displayed when None.
            Defaults to None.
        save_folder (Optional[PATH_TYPE], optional):
            Folder to save data to. Will be created if it doesn't exist. Nothing is saved when
            None. Defaults to None.
        save_n_tiles (Optional[int], optional):
            Number of tiles to save; `random_sample` decides whether they are the first tiles
            or a random selection. If unset, all tiles are saved. Defaults to None.
        random_sample (bool, optional):
            When `save_n_tiles` is set, sample the tiles randomly instead of taking them from
            the beginning of the dataloader. Defaults to False.
        batch_size (int, optional):
            Number of images to load in a batch. Defaults to 1.
    """
    # Everything needed to build the dataloader: the raster data and, optionally, vector labels
    dataloader_options = dict(
        raster_folder_path=raster_folder_path,
        chip_size=chip_size,
        chip_stride=chip_stride,
        chip_overlap_percentage=chip_overlap_percentage,
        use_units_meters=use_units_meters,
        region_of_interest=region_of_interest,
        output_resolution=output_resolution,
        output_CRS=output_CRS,
        vector_label_folder_path=vector_label_folder_path,
        vector_label_attribute=vector_label_attribute,
        batch_size=batch_size,
    )
    dataloader = create_dataloader(**dataloader_options)

    # Display tiles only when a tile count was requested
    if visualize_n_tiles is not None:
        visualize_dataloader(dataloader=dataloader, n_tiles=visualize_n_tiles)

    # Without a destination folder there is nothing left to do
    if save_folder is None:
        return

    # Write all tiles (or `save_n_tiles` of them, in order or randomly sampled) to disk
    save_dataloader_contents(
        dataloader=dataloader,
        save_folder=save_folder,
        n_tiles=save_n_tiles,
        random_sample=random_sample,
    )
101 |
102 |
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the tiling entrypoint.

    Every option is stored under the name of the matching `tile_data` parameter, so the
    resulting namespace can be unpacked directly into `tile_data`.

    Returns:
        argparse.Namespace: The parsed options. Unset flags are False and other unset
        options are None, except `vector_label_attribute` ("treeID") and `batch_size` (1).
    """
    parser = argparse.ArgumentParser(description="Chipping orthomosaic images")

    parser.add_argument(
        "--raster-folder-path",
        type=str,
        required=True,
        help="Path to the folder or raster files.",
    )

    parser.add_argument(
        "--chip-size",
        type=float,
        required=True,
        help="Dimension of the chip. May be pixels or meters, based on --use-units-meters.",
    )

    parser.add_argument(
        "--chip-stride",
        type=float,
        required=False,
        help="Stride of the chip. May be pixels or meters, based on --use-units-meters. If used, --chip-overlap-percentage should not be set.",
    )

    parser.add_argument(
        "--chip-overlap-percentage",
        type=float,
        required=False,
        help="Percent overlap of the chip from 0-100. If used, --chip-stride should not be set.",
    )

    parser.add_argument(
        "--use-units-meters",
        action="store_true",
        help="Use units of meters rather than pixels when interpreting the --chip-size and --chip-stride.",
    )

    parser.add_argument(
        "--region-of-interest",
        type=str,
        required=False,
        help="Only data from this spatial region will be included in the dataloader. Should be specified as minx,miny,maxx,maxy.",
    )

    parser.add_argument(
        "--output-resolution",
        type=float,
        required=False,
        help="Spatial resolution of the data in meters/pixel. If un-set, will be the resolution of the first raster data that is read.",
    )

    parser.add_argument(
        "--output-CRS",
        type=str,
        required=False,
        help="The coordinate reference system to use for the output data. If un-set, will be the CRS of the first tile found.",
    )

    parser.add_argument(
        "--vector-label-folder-path",
        type=str,
        required=False,
        help="A folder of geospatial vector files that will be used for the label. If un-set, the dataloader will not be labeled.",
    )

    parser.add_argument(
        "--vector-label-attribute",
        type=str,
        default="treeID",
        help="Attribute to read from the vector data, such as the class or instance ID. Defaults to 'treeID'.",
    )

    parser.add_argument(
        "--visualize-n-tiles",
        type=int,
        required=False,
        help="The number of randomly-sampled tiles to display.",
    )

    parser.add_argument(
        "--save-folder",
        type=str,
        required=False,
        help="Folder to save data to. Will be created if it doesn't exist.",
    )

    parser.add_argument(
        "--save-n-tiles",
        type=int,
        required=False,
        help="Number of tiles to save. Whether they are the first tiles or random is controlled by --random-sample. If unset, all tiles will be saved.",
    )

    parser.add_argument(
        "--random-sample",
        action="store_true",
        help="If --save-n-tiles is set, should the tiles be randomly sampled rather than taken from the beginning of the dataloader.",
    )

    # The namespace is unpacked into `tile_data`, so an unset option would arrive there as an
    # explicit `batch_size=None` and override the function's own default. Give the option the
    # default of 1 that its help text promises.
    parser.add_argument(
        "--batch-size",
        type=int,
        required=False,
        default=1,
        help="Number of images to load in a batch. Defaults to 1.",
    )

    args = parser.parse_args()
    return args
211 |
212 |
if __name__ == "__main__":
    # Forward every parsed command-line option to the entrypoint as a keyword argument
    cli_args = parse_args()
    tile_data(**vars(cli_args))
216 |
--------------------------------------------------------------------------------
/tree_detection_framework/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-forest-observatory/tree-detection-framework/661fded8c3cd3ac7585ec9bb5d4606778bca7494/tree_detection_framework/evaluation/__init__.py
--------------------------------------------------------------------------------
/tree_detection_framework/evaluation/evaluate.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import geopandas as gpd
4 | import numpy as np
5 | from scipy.optimize import linear_sum_assignment
6 | from shapely.geometry import Polygon
7 |
8 |
def compute_matched_ious(
    ground_truth_boxes: List[Polygon], predicted_boxes: List[Polygon]
) -> List:
    """Compute IoUs for matched pairs of ground truth and predicted boxes.
    This uses the Hungarian algorithm to find the assignment that maximizes the total IoU.
    Args:
        ground_truth_boxes (list): List of ground truth polygons.
        predicted_boxes (list): List of predicted polygons.
    Returns:
        list: One IoU per matched pair. Pairs that do not overlap have an IoU of 0. The list
            is empty when either input is empty, since no pair can be formed.
    """
    # Nothing can be matched if either side is missing. Return an empty list (not a float) so
    # the result always has the documented type. `len` is used instead of truthiness so that
    # array-like inputs are accepted as well.
    if (
        ground_truth_boxes is None
        or predicted_boxes is None
        or len(ground_truth_boxes) == 0
        or len(predicted_boxes) == 0
    ):
        return []

    # Create GeoDataFrames for ground truth and predicted boxes, remembering each polygon's
    # area and its position in the input list
    gt_gdf = gpd.GeoDataFrame(geometry=ground_truth_boxes)
    gt_gdf["area_gt"] = gt_gdf.geometry.area
    gt_gdf["id_gt"] = gt_gdf.index

    pred_gdf = gpd.GeoDataFrame(geometry=predicted_boxes)
    pred_gdf["area_pred"] = pred_gdf.geometry.area
    pred_gdf["id_pred"] = pred_gdf.index

    # Get the intersection between the two sets of polygons; each row is one overlapping pair
    intersection = gpd.overlay(gt_gdf, pred_gdf, how="intersection")
    # IoU = intersection / union, where union = area_gt + area_pred - intersection
    intersection["iou"] = intersection.area / (
        intersection["area_gt"] + (intersection["area_pred"] - intersection.area)
    )

    # Create a cost matrix holding the negated IoUs, since the solver minimizes the total cost.
    # Pairs without any overlap keep a cost of zero.
    cost_matrix = np.zeros((len(gt_gdf), len(pred_gdf)))
    cost_matrix[intersection["id_gt"], intersection["id_pred"]] = -intersection["iou"]

    # Solve optimal assignment using the Hungarian algorithm
    gt_indices, pred_indices = linear_sum_assignment(cost_matrix)

    # Extract IoUs of matched pairs, undoing the negation
    matched_ious = [-cost_matrix[i, j] for i, j in zip(gt_indices, pred_indices)]
    return matched_ious
48 |
49 |
def compute_precision_recall(
    ious: List, num_gt: int, num_pd: int, threshold: float = 0.4
) -> tuple:
    """Derive precision and recall from the IoUs of matched box pairs.
    A matched pair counts as a true positive only when its IoU is strictly greater than
    `threshold`.
    Args:
        ious (list): IoU of every matched ground truth / prediction pair.
        num_gt (int): Total number of ground truth boxes.
        num_pd (int): Total number of predicted boxes.
        threshold (float): IoU a pair has to exceed to count as a match. Defaults to 0.4.
    Returns:
        tuple: (precision, recall). Each value is 0.0 when its denominator is not positive.
    """
    # Flag the pairs whose overlap is good enough, then count them
    above_threshold = np.array(ious) > threshold
    tp = np.sum(above_threshold.astype(np.uint8))

    # Precision: fraction of the predictions that are true positives
    if num_pd > 0:
        precision = tp / num_pd
    else:
        precision = 0.0

    # Recall: fraction of the ground truth boxes that were found
    if num_gt > 0:
        recall = tp / num_gt
    else:
        recall = 0.0

    return precision, recall
67 |
--------------------------------------------------------------------------------
/tree_detection_framework/postprocessing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-forest-observatory/tree-detection-framework/661fded8c3cd3ac7585ec9bb5d4606778bca7494/tree_detection_framework/postprocessing/__init__.py
--------------------------------------------------------------------------------
/tree_detection_framework/postprocessing/postprocessing.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import List, Optional
3 |
4 | import numpy as np
5 | import pyproj
6 | from polygone_nms import nms
7 | from shapely import box
8 | from shapely.geometry import MultiPolygon, Polygon
9 | from shapely.ops import unary_union
10 |
11 | from tree_detection_framework.detection.region_detections import (
12 | RegionDetections,
13 | RegionDetectionsSet,
14 | )
15 |
16 |
def single_region_NMS(
    detections: RegionDetections,
    threshold: float = 0.5,
    confidence_column: str = "score",
    min_confidence: float = 0.3,
    intersection_method: str = "IOU",
) -> RegionDetections:
    """Apply non-max suppression to the detections of one region.

    Detections that are empty or whose confidence is below `min_confidence` are discarded
    before the suppression is run.

    Args:
        detections (RegionDetections):
            Detections from a single region to run NMS on.
        threshold (float, optional):
            The threshold for the NMS(intersection) method. Defaults to 0.5.
        confidence_column (str, optional):
            Column of the detections dataframe that holds the confidence used for NMS.
            Defaults to "score"
        min_confidence (float, optional):
            Minimum prediction score a detection needs to be considered. Defaults to 0.3.
        intersection_method (str, optional):
            The method to compute intersections, one of ("IOU", "IOS", "Dice", "IOT").
            Defaults to "IOU".

    Returns:
        RegionDetections:
            The subset of `detections` that survives the suppression.
    """
    all_detections = detections.get_data_frame()

    # Positions (w.r.t. the full dataframe) of the detections that are both confident enough
    # and have a non-empty geometry
    is_confident = all_detections[confidence_column] >= min_confidence
    has_geometry = ~all_detections.geometry.is_empty
    candidate_positions = np.where((is_confident & has_geometry).to_numpy())[0]

    candidates = all_detections.iloc[candidate_positions]
    if candidates.empty:
        # Nothing passed the filter, so the result is an empty set of detections
        return detections.subset_detections([])

    # The NMS implementation expects list[(polygon, class, confidence)]
    # TODO consider adding a class, currently set to all ones
    scores = candidates[confidence_column].to_numpy()
    nms_input = list(
        zip(candidates.geometry.to_list(), np.ones_like(scores), scores)
    )

    kept_candidates = nms(
        input_data=nms_input,
        distributed=None,
        nms_method="Default",
        intersection_method=intersection_method,
        threshold=threshold,
    )

    # NMS reported indices into the filtered candidates; translate them back to positions in
    # the full set of detections. Sorting keeps the detections in their original order.
    kept_positions = sorted(candidate_positions[kept_candidates])
    return detections.subset_detections(kept_positions)
85 |
86 |
def multi_region_NMS(
    detections: RegionDetectionsSet,
    run_per_region_NMS: bool = True,
    threshold: float = 0.5,
    confidence_column: str = "score",
    min_confidence: float = 0.3,
    intersection_method: str = "IOU",
) -> RegionDetections:
    """Apply non-max suppression to detections that span several regions.

    Args:
        detections (RegionDetectionsSet):
            Detections from multiple regions to run NMS on.
        run_per_region_NMS (bool):
            Whether to run NMS inside each region before the regions are merged. This may
            lead to a speedup if there is a large amount of within-region overlap.
            Defaults to True.
        threshold (float, optional):
            The threshold for the NMS(intersection) method. Defaults to 0.5.
        confidence_column (str, optional):
            Column of the detections dataframe that holds the confidence used for NMS.
            Defaults to "score"
        min_confidence (float, optional):
            Minimum prediction score a detection needs to be considered. Defaults to 0.3.
        intersection_method (str, optional):
            The method to compute intersections, one of ("IOU", "IOS", "Dice", "IOT").
            Defaults to "IOU".
    Returns:
        RegionDetections:
            The detections of all regions merged together, after suppression.
    """
    # The same settings are used for the per-region passes and the final across-region pass
    nms_settings = {
        "threshold": threshold,
        "confidence_column": confidence_column,
        "min_confidence": min_confidence,
        "intersection_method": intersection_method,
    }

    if run_per_region_NMS:
        # Suppress within every region first and wrap the results in a new set
        per_region_results = []
        for region in detections.region_detections:
            per_region_results.append(single_region_NMS(region, **nms_settings))
        detections = RegionDetectionsSet(per_region_results)

    # Collapse the set into a single RegionDetections
    merged = detections.merge()

    # Regions with disjoint bounds do not need suppression across regions
    if detections.disjoint_bounds():
        logging.info("Bounds are disjoint, skipping across-region NMS")
        return merged

    logging.info("Bound have overlap, running across-region NMS")
    return single_region_NMS(merged, **nms_settings)
152 |
153 |
def polygon_hole_suppression(polygon: Polygon, min_area_threshold: float = 20.0):
    """Drop the small holes of a polygon.

    Args:
        polygon(shapely.Polygon):
            A shapely polygon object
        min_area_threshold(float):
            Only holes whose area is strictly greater than this value are kept

    Returns:
        shapely.Polygon:
            A polygon with the same exterior that retains only the large holes
    """
    # `interiors` holds the rings of the holes; a ring is turned into a polygon to measure it
    retained_holes = [
        ring for ring in polygon.interiors if Polygon(ring).area > min_area_threshold
    ]

    # Rebuild the polygon from its outline and the holes that were kept
    return Polygon(polygon.exterior.coords, holes=retained_holes)
177 |
178 |
def single_region_hole_suppression(
    detections: RegionDetections, min_area_threshold: float = 20.0
):
    """Remove small polygon holes from every detection of a RegionDetections object.

    Args:
        detections (RegionDetections):
            Detections from a single region whose polygon holes should be suppressed.
        min_area_threshold(float):
            Holes with an area smaller than this value are removed.

    Returns:
        RegionDetections:
            New detections built from the cleaned geometries.
    """

    def _without_small_holes(tree_crown):
        # A plain polygon is cleaned directly
        if isinstance(tree_crown, Polygon):
            return polygon_hole_suppression(tree_crown, min_area_threshold)
        # A multipolygon is cleaned part by part and then reassembled
        if isinstance(tree_crown, MultiPolygon):
            return MultiPolygon(
                [
                    polygon_hole_suppression(part, min_area_threshold)
                    for part in tree_crown.geoms
                ]
            )
        # Anything else is replaced by an empty polygon (just to be safe)
        return Polygon()

    detections_df = detections.get_data_frame()

    # Swap the geometry column for the cleaned geometries
    detections_df.geometry = [
        _without_small_holes(tree_crown)
        for tree_crown in detections_df.geometry.to_list()
    ]

    # Wrap the updated dataframe in a new RegionDetections object
    # TODO: Handle cases where the data is in pixels with no transform to geospatial
    return RegionDetections(
        detection_geometries=None,
        data=detections_df,
        input_in_pixels=False,
        CRS=detections.get_CRS(),
    )
226 |
227 |
def multi_region_hole_suppression(
    detections: RegionDetectionsSet, min_area_threshold: float = 20.0
):
    """Remove small polygon holes from every region of a RegionDetectionsSet object.

    Args:
        detections (RegionDetectionsSet):
            Set of detections from multiple regions whose polygon holes should be suppressed.
        min_area_threshold(float):
            Holes with an area smaller than this value are removed.

    Returns:
        RegionDetectionsSet:
            A new set holding the cleaned detections of every region.
    """
    # Clean the regions one at a time with `single_region_hole_suppression`
    cleaned_regions = []
    for region in detections.region_detections:
        cleaned_regions.append(
            single_region_hole_suppression(region, min_area_threshold)
        )

    return RegionDetectionsSet(cleaned_regions)
250 |
251 |
def merge_and_postprocess_detections(
    detections: RegionDetectionsSet,
    tolerance: Optional[float] = 0.2,
    min_area_threshold: Optional[float] = 20.0,
) -> RegionDetections:
    """Apply postprocessing techniques that include:
    1. Get a union of polygons that have been split across tiles
    2. Simplify the edges of polygons by `tolerance` value
    3. Remove holes within the polygons that are smaller than `min_area_threshold` value
    Merges regions into a single RegionDetections.

    Args:
        detections(RegionDetectionsSet):
            Detections from multiple regions to postprocess.
        tolerance (Optional[float], optional):
            A value that controls the simplification of the detection polygons.
            The higher this value, the smaller the number of vertices in the resulting geometry.
            Defaults to 0.2.
        min_area_threshold (Optional[float], optional):
            Holes within polygons having an area lesser than this value get removed.
            Defaults to 20.0.

    Returns:
        RegionDetections:
            Postprocessed set of detections, merged together for the set of regions. Only the
            geometries are carried over; the CRS is taken from the merged input dataframe.
    """
    # Get the detections as a merged GeoDataFrame
    all_detections_gdf = detections.get_data_frame(merge=True)

    # Apply a small negative buffer to shrink polygons slightly
    buffered_geoms = [geom.buffer(-0.001) for geom in all_detections_gdf.geometry]

    # Compute the union of the set of polygons. This step removes any vertical lines caused by
    # the tile edges and combines a single polygon that might have been split into multiple.
    # Also removes any overlaps.
    union_detections = unary_union(buffered_geoms)

    # The union is only a multi-part geometry (exposing `.geoms`) when it consists of several
    # pieces. When everything merges into one piece the result is a bare Polygon, which has no
    # `.geoms`, so it is treated as a one-element collection.
    if hasattr(union_detections, "geoms"):
        union_parts = list(union_detections.geoms)
    else:
        union_parts = [union_detections]

    # Simplify the polygons by tolerance value and extract only Polygons and MultiPolygons
    # since `union_detections` can have Point objects as well. Empty parts carry no detection
    # and are dropped.
    filtered_geoms = [
        geom.simplify(tolerance)
        for geom in union_parts
        if isinstance(geom, (Polygon, MultiPolygon)) and not geom.is_empty
    ]

    # To remove small holes within polygons
    new_polygons = [
        polygon_hole_suppression(polygon, min_area_threshold)
        for polygon in filtered_geoms
    ]

    # Create a RegionDetections for the merged and postprocessed detections
    # TODO: Handle cases when input is in pixels
    postprocessed_detections = RegionDetections(
        new_polygons, input_in_pixels=False, CRS=all_detections_gdf.crs
    )

    return postprocessed_detections
307 |
308 |
def suppress_tile_boundary_with_NMS(
    predictions: RegionDetectionsSet,
    iou_threshold: float = 0.5,
    ios_threshold: float = 0.5,
    min_confidence: float = 0.3,
) -> RegionDetections:
    """
    Post-processing step for the `GeometricDetector` class that suppresses detections which are
    split across tiles. NMS is applied twice: first with the IOU metric, then with the IOS metric.

    Args:
        predictions (RegionDetectionsSet):
            Detections from multiple regions.
        iou_threshold (float, optional):
            Threshold of the first NMS pass, which uses the IoU metric. Defaults to 0.5.
        ios_threshold (float, optional):
            Threshold of the second NMS pass, which uses the IoS metric. Defaults to 0.5.
        min_confidence (float, optional):
            Minimum prediction score a detection needs to be considered. Defaults to 0.3.

    Returns:
        RegionDetections:
            The merged detections after both NMS passes.
    """
    # First pass: IoU-based NMS over all regions, which also merges them into one object
    merged_after_iou = multi_region_NMS(
        predictions,
        threshold=iou_threshold,
        min_confidence=min_confidence,
        intersection_method="IOU",
    )

    # Second pass: IoS-based NMS on the merged detections
    return single_region_NMS(
        merged_after_iou,
        threshold=ios_threshold,
        min_confidence=min_confidence,
        intersection_method="IOS",
    )
349 |
350 |
def remove_out_of_bounds_detections(
    region_detection_sets: List[RegionDetectionsSet], image_bounds: List
) -> List[RegionDetectionsSet]:
    """
    Filters out detections that are outside the bounds of a defined region.
    Used as a post-processing step after `predict_raw_drone_images()`.

    Args:
        region_detection_sets (List[RegionDetectionSet]):
            Each element is a RegionDetectionsSet derived from a specific drone image.
            Length is the number of raw drone images given to the dataloader.
        image_bounds (List[bounding_box]):
            Each element is a `bounding_box` object derived from the dataloader.
            Length: number of regions in a set * number of RegionDetectionSet objects

    Returns:
        List of RegionDetectionsSet objects with out-of-bounds predictions filtered out. A
        detection is kept only if it lies entirely within the bounds.
    """

    region_idx = 0  # Index into `image_bounds` of the first region of the current set
    list_of_filtered_region_sets = []

    for rds in region_detection_sets:

        # Find the number of regions in this set
        num_of_regions_in_a_set = len(rds.get_data_frame())

        # All regions of a set are checked against the bounds stored for its first region
        region_image_bounds = image_bounds[region_idx]

        # Create a polygon of size equal to the image dimensions. It does not depend on the
        # individual region, so it is built once per set. Arguments follow the documented
        # `box(xmin, ymin, xmax, ymax)` order; the rectangle covers the same area either way.
        region_set_polygon = box(
            region_image_bounds.minx,
            region_image_bounds.miny,
            region_image_bounds.maxx,
            region_image_bounds.maxy,
        )

        # To save filtered regions in a particular set
        list_of_filtered_regions = []

        for idx in range(num_of_regions_in_a_set):

            # Get RegionsDetections object from the set
            rd = rds.get_region_detections(idx)
            rd_gdf = rd.get_data_frame()

            # TODO: Instead of removing detections partially extending out of the boundary,
            # try cropping it using gpd.clip()
            # Remove detections that extend beyond the image bounds: keep only the index
            # labels of the rows that are fully within the polygon
            within_bounds_indices = rd_gdf.within(region_set_polygon)
            within_bounds_indices_true = within_bounds_indices[
                within_bounds_indices
            ].index

            # Subset the RegionDetections object keeping only the valid indices calculated before
            filtered_rd = rd.subset_detections(within_bounds_indices_true)
            list_of_filtered_regions.append(filtered_rd)

        list_of_filtered_region_sets.append(
            RegionDetectionsSet(list_of_filtered_regions)
        )

        # Update region_idx to point to the image bounds of the next rds
        region_idx += num_of_regions_in_a_set

    return list_of_filtered_region_sets
417 |
--------------------------------------------------------------------------------
/tree_detection_framework/preprocessing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-forest-observatory/tree-detection-framework/661fded8c3cd3ac7585ec9bb5d4606778bca7494/tree_detection_framework/preprocessing/__init__.py
--------------------------------------------------------------------------------
/tree_detection_framework/preprocessing/derived_geodatasets.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict, namedtuple
2 | from pathlib import Path
3 | from typing import Any, List, Optional, Union
4 |
5 | import fiona
6 | import fiona.transform
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 | import rasterio
10 | import shapely.geometry
11 | import torch
12 | from PIL import Image
13 | from shapely.affinity import affine_transform
14 | from torch.utils.data import DataLoader, Dataset
15 | from torchgeo.datamodules import GeoDataModule
16 | from torchgeo.datasets import (
17 | IntersectionDataset,
18 | RasterDataset,
19 | VectorDataset,
20 | stack_samples,
21 | )
22 | from torchgeo.datasets.utils import BoundingBox, array_to_tensor
23 | from torchgeo.samplers import GridGeoSampler, Units
24 | from torchvision import transforms
25 |
26 | from tree_detection_framework.constants import PATH_TYPE
27 |
# Bounds of a tile image from the `CustomImageDataset`, stored as (minx, maxx, miny, maxy)
bounding_box = namedtuple("bounding_box", "minx maxx miny maxy")
30 |
31 |
class CustomRasterDataset(RasterDataset):
    """
    Dataset of orthomosaic raster images, built on top of `torchgeo`'s `RasterDataset`.

    Attributes:
        is_image (bool): Marks the data being loaded as image data.
        separate_files (bool): True if every band is stored in its own file, else False.
        filename_glob (str): Glob pattern selecting the files of the directory to load.
    """

    is_image: bool = True
    separate_files: bool = False
    filename_glob: str = "*.tif"  # To match all TIFF files
45 |
46 |
class CustomVectorDataset(VectorDataset):
    """
    Vector dataset supplying labels for the raster data. Extends torchgeo's `VectorDataset`
    so that every sample also carries the label geometries expressed in pixel coordinates.
    """

    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
        """Fetch the rasterized mask, label geometries and metadata for a query window.

        The rasterization part follows TorchGeo's `VectorDataset.__getitem__`. On top of that,
        two extra keys are added to the returned sample:
            1. 'shapes': (geometry, label) pairs per tile, with geometry in pixel coordinates.
            2. 'bounding_boxes': [x_min, y_min, x_max, y_max] of every geometry, in pixel
               coordinates.

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            sample of mask and metadata at that index

        Raises:
            IndexError: if query is not found in the index
        """
        filepaths = [
            hit.object for hit in self.index.intersection(tuple(query), objects=True)
        ]
        if len(filepaths) == 0:
            raise IndexError(
                f"query: {query} not found in index with bounds: {self.bounds}"
            )

        # Gather (geometry, label) pairs from every file touching the query window
        labelled_geoms = []
        for path in filepaths:
            with fiona.open(path) as src:
                # Express the query window in the CRS of the file being read
                (minx, maxx), (miny, maxy) = fiona.transform.transform(
                    self.crs.to_dict(),
                    src.crs,
                    [query.minx, query.maxx],
                    [query.miny, query.maxy],
                )

                # Only features intersecting the window are read
                for feature in src.filter(bbox=(minx, miny, maxx, maxy)):
                    # Bring the geometry back into the dataset CRS
                    warped = fiona.transform.transform_geom(
                        src.crs, self.crs.to_dict(), feature["geometry"]
                    )
                    labelled_geoms.append((warped, self.get_label(feature)))

        # Pixel grid covering the query window at the dataset resolution
        n_cols = (query.maxx - query.minx) / self.res
        n_rows = (query.maxy - query.miny) / self.res
        transform = rasterio.transform.from_bounds(
            query.minx, query.miny, query.maxx, query.maxy, n_cols, n_rows
        )
        grid_shape = (round(n_rows), round(n_cols))

        if labelled_geoms:
            raster = rasterio.features.rasterize(
                labelled_geoms, out_shape=grid_shape, transform=transform
            )
        else:
            # Nothing to burn in: use the fill value and dtype rasterize would default to
            raster = np.zeros(grid_shape, dtype=np.uint8)

        # array_to_tensor copes with the uint16/uint32 arrays rasterize may produce
        masks = array_to_tensor(raster).to(self.dtype)

        # --- Additions relative to the torchgeo implementation ---

        # The inverse of the pixel->geo transform maps geo coordinates to pixel coordinates
        geo_to_pixel = ~transform

        # Turn the fiona geometries into shapely objects, then move them into pixel space
        as_shapely = [
            (shapely.geometry.shape(geom), label) for geom, label in labelled_geoms
        ]
        pixel_transformed_shapes = [
            (affine_transform(geom, geo_to_pixel.to_shapely()), label)
            for geom, label in as_shapely
        ]

        # Axis-aligned box (x_min, y_min, x_max, y_max) of each geometry in pixel coordinates
        bounding_boxes = [
            [x_min, y_min, x_max, y_max]
            for x_min, y_min, x_max, y_max in (
                geom.bounds for geom, _ in pixel_transformed_shapes
            )
        ]

        sample = {
            "mask": masks,
            "crs": self.crs,
            "bounds": query,
            "shapes": pixel_transformed_shapes,
            "bounding_boxes": bounding_boxes,
        }

        if self.transforms is None:
            return sample
        return self.transforms(sample)
149 |
150 |
class CustomImageDataset(Dataset):
    """
    Dataset that cuts ordinary (non-georeferenced) images into square tiles.

    Every item is a dict with the keys "image" (tile tensor), "metadata" (dict describing the
    source image), "bounds" (pixel bounds of the tile as a `bounding_box`) and "crs" (always
    None). Tiles that run past the right/bottom edge of the image are padded with white.
    """

    def __init__(
        self,
        images_dir: Union[PATH_TYPE, List[str]],
        chip_size: int,
        chip_stride: int,
        labels_dir: Optional[List[str]] = None,
    ):
        """
        Dataset for creating a dataloader from a folder of individual images, with an option to create tiles.

        Args:
            images_dir (Union[Path, List[str]]): Path to the folder containing image files, or list of paths to image files.
            chip_size (int): Dimension of each image chip (width, height) in pixels.
            chip_stride (int): Stride to take while chipping the images (horizontal, vertical) in pixels.
            labels_dir (Optional[List[str]]): List of paths to annotation .geojson files corresponding to the images.
                Should have same file name as the image.

        Raises:
            ValueError: If `chip_size` or `chip_stride` is not positive, if no image files are
                found in `images_dir`, or if labels are given and one is missing for an image.
        """
        # A zero stride would make `range` fail with an unrelated message and a negative one
        # would silently produce an empty dataset, so reject both up front.
        if chip_size <= 0 or chip_stride <= 0:
            raise ValueError(
                "'chip_size' and 'chip_stride' must be positive, "
                f"got chip_size={chip_size} and chip_stride={chip_stride}"
            )
        self.chip_size = chip_size
        self.chip_stride = chip_stride

        if not isinstance(images_dir, list):
            self.images_dir = Path(images_dir)
            # Get a list of all image paths
            image_extensions = [".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif"]
            self.image_paths = sorted(
                path
                for path in self.images_dir.glob("*")
                if path.suffix.lower() in image_extensions
            )

            if len(self.image_paths) == 0:
                raise ValueError(f"No image files found in {self.images_dir}")
        else:
            # An explicit list of files is used as given (no extension filtering)
            self.image_paths = images_dir

        self.labels_paths = labels_dir
        self.tile_metadata = self._get_metadata()

    def _get_metadata(self):
        """Enumerate every tile of every image.

        Returns:
            list: One `(tile_idx, img_path, label_path, x, y)` tuple per tile, where `tile_idx`
                restarts at 0 for each image, `label_path` is None when no labels were given,
                and `(x, y)` is the top-left pixel of the tile.

        Raises:
            ValueError: If labels were given but none matches an image's file name.
        """
        # Index the label files by file name once instead of scanning the list per image.
        # `setdefault` keeps the first occurrence when a file name is duplicated.
        labels_by_name = {}
        for label in self.labels_paths or []:
            labels_by_name.setdefault(Path(label).name, label)

        metadata = []
        for img_path in self.image_paths:
            label_path = None
            if self.labels_paths:
                # The label must share the image's file name, with a `.geojson` extension.
                # `Path(...)` makes this work for both `str` and `Path` entries.
                expected_label_name = Path(img_path).stem + ".geojson"
                label_path = labels_by_name.get(expected_label_name)
                if label_path is None:
                    raise ValueError(f"Label file not found for image: {img_path}")

            with Image.open(img_path) as img:
                img_width, img_height = img.size

            # Tile index is unique within an image and restarts for every new image
            tile_idx = 0
            for y in range(0, img_height, self.chip_stride):
                for x in range(0, img_width, self.chip_stride):
                    metadata.append((tile_idx, img_path, label_path, x, y))
                    tile_idx += 1
        return metadata

    def __len__(self):
        """Total number of tiles across all images."""
        return len(self.tile_metadata)

    def __getitem__(self, idx):
        """Load the tile at position `idx` of the flat tile list.

        Args:
            idx (int): Index into `self.tile_metadata`.

        Returns:
            dict: "image" (tile converted with torchvision's `ToTensor`), "metadata" (dict with
                "image_index", "source_image", "image_bounds" and, when a label exists,
                "annotations"), "bounds" (tile bounds including any padded region) and "crs"
                (None).
        """
        tile_idx, img_path, label_path, x, y = self.tile_metadata[idx]

        with Image.open(img_path) as img:
            img_width, img_height = img.size

            # Part of the tile that actually lies inside the image
            tile_width = min(self.chip_size, img_width - x)
            tile_height = min(self.chip_size, img_height - y)

            # Crop before converting so only the tile, not the whole image, is turned into RGB
            img_section = img.crop((x, y, x + tile_width, y + tile_height)).convert(
                "RGB"
            )

        if tile_width == self.chip_size and tile_height == self.chip_size:
            # The tile fits entirely within the image
            tile = img_section
        else:
            # Paste the partial crop onto a white square of shape 'chip_size'
            tile = Image.new("RGB", (self.chip_size, self.chip_size), (255, 255, 255))
            tile.paste(img_section, (0, 0))

        # Convert to tensor
        tile = transforms.ToTensor()(tile)

        metadata = {
            # NOTE: despite the key name, this is the index of the tile within its image
            "image_index": tile_idx,
            "source_image": str(img_path),
            "image_bounds": bounding_box(
                0,
                float(img_width),
                float(img_height),
                0,
            ),
        }
        # Only advertise annotations when a label file was actually matched; checking
        # `self.labels_paths` instead would store the string "None" for an empty label list.
        if label_path is not None:
            metadata["annotations"] = str(label_path)

        return {
            "image": tile,
            "metadata": metadata,
            # Bounds includes bounding box values for the whole tile including white padded region if any
            "bounds": bounding_box(
                float(x),
                float(x + self.chip_size),
                float(y + self.chip_size),
                float(y),
            ),
            "crs": None,
        }

    @staticmethod
    def collate_as_defaultdict(batch):
        """Collate a list of items into one batch.

        Images are stacked into a single tensor; "metadata", "bounds" and "crs" are kept as
        per-item lists. The result is a `defaultdict` that yields None for any other key.
        """
        # Stack images from batch into a single tensor
        images = torch.stack([item["image"] for item in batch])
        # Collect the remaining fields as lists
        metadata = [item["metadata"] for item in batch]
        bounds = [item["bounds"] for item in batch]
        crs = [item["crs"] for item in batch]
        return defaultdict(
            lambda: None,
            {"image": images, "metadata": metadata, "bounds": bounds, "crs": crs},
        )
289 |
290 |
class CustomDataModule(GeoDataModule):
    """
    Data module that pairs raster imagery with vector labels for train/val/test splits.

    Each split is an intersection of a `CustomRasterDataset` and a `CustomVectorDataset`, and is
    served through a `GridGeoSampler` using the configured tile `size` and `stride`.

    NOTE(review): the training sampler is created without `units` (sampler default), whereas the
    validation and test samplers use `Units.CRS` -- confirm this difference is intended.
    """

    def __init__(
        self,
        output_res: float,
        train_raster_path: str,
        vector_label_name: str,
        train_vector_path: str,
        size: int,
        stride: int,
        batch_size: int = 2,
        val_raster_path: Optional[str] = None,
        val_vector_path: Optional[str] = None,
        test_raster_path: Optional[str] = None,
        test_vector_path: Optional[str] = None,
    ) -> None:
        super().__init__(dataset_class=IntersectionDataset)

        # Raster inputs per split
        self.train_raster_path = train_raster_path
        self.val_raster_path = val_raster_path
        self.test_raster_path = test_raster_path

        # Vector label inputs per split
        self.train_vector_path = train_vector_path
        self.val_vector_path = val_vector_path
        self.test_vector_path = test_vector_path

        # Dataset and sampling configuration
        self.output_res = output_res
        self.vector_label_name = vector_label_name
        self.size = size
        self.stride = stride
        self.batch_size = batch_size

    def create_intersection_dataset(
        self, raster_path: str, vector_path: str
    ) -> IntersectionDataset:
        """Build the raster/label intersection dataset for one split."""
        labels = CustomVectorDataset(
            paths=vector_path, res=self.output_res, label_name=self.vector_label_name
        )
        imagery = CustomRasterDataset(paths=raster_path, res=self.output_res)
        # `&` on geo datasets yields an IntersectionDataset
        return imagery & labels

    def setup(self, stage=None):
        """Create the datasets needed by the stage the Trainer is in."""
        if stage == "fit":
            self.train_data = self.create_intersection_dataset(
                self.train_raster_path, self.train_vector_path
            )
        if stage in ("validate", "fit"):
            self.val_data = self.create_intersection_dataset(
                self.val_raster_path, self.val_vector_path
            )
        if stage == "test":
            self.test_data = self.create_intersection_dataset(
                self.test_raster_path, self.test_vector_path
            )

    def _grid_dataloader(self, dataset, **sampler_kwargs) -> DataLoader:
        """Wrap `dataset` in a grid-sampled DataLoader using this module's tiling settings."""
        grid_sampler = GridGeoSampler(
            dataset, size=self.size, stride=self.stride, **sampler_kwargs
        )
        return DataLoader(
            dataset,
            sampler=grid_sampler,
            collate_fn=stack_samples,
            batch_size=self.batch_size,
        )

    def train_dataloader(self) -> DataLoader:
        # No `units` given here, so the sampler's default applies
        return self._grid_dataloader(self.train_data)

    def val_dataloader(self) -> DataLoader:
        return self._grid_dataloader(self.val_data, units=Units.CRS)

    def test_dataloader(self) -> DataLoader:
        return self._grid_dataloader(self.test_data, units=Units.CRS)

    def on_after_batch_transfer(self, batch, dataloader_idx: int):
        # Hand the batch back untouched
        return batch
379 |
--------------------------------------------------------------------------------
/tree_detection_framework/preprocessing/preprocessing.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import random
4 | from pathlib import Path
5 | from typing import List, Optional, Union
6 |
7 | import matplotlib.pyplot as plt
8 | import pyproj
9 | import rasterio
10 | import shapely
11 | import torch
12 | from torch.utils.data import DataLoader
13 | from torchgeo.datasets import IntersectionDataset, stack_samples, unbind_samples
14 | from torchgeo.samplers import GridGeoSampler, Units
15 | from torchvision.transforms import ToPILImage
16 |
17 | from tree_detection_framework.constants import ARRAY_TYPE, BOUNDARY_TYPE, PATH_TYPE
18 | from tree_detection_framework.preprocessing.derived_geodatasets import (
19 | CustomImageDataset,
20 | CustomRasterDataset,
21 | CustomVectorDataset,
22 | )
23 | from tree_detection_framework.utils.geospatial import get_projected_CRS
24 | from tree_detection_framework.utils.raster import plot_from_dataloader
25 |
# Configure the root logger when this module is imported: INFO level, with each record
# formatted as "<timestamp> - <level> - <message>".
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
29 |
30 |
def create_spatial_split(
    region_to_be_split: BOUNDARY_TYPE, split_fractions: ARRAY_TYPE
) -> List[shapely.MultiPolygon]:
    """Creates non-overlapping spatial splits.

    NOTE: This function is a placeholder and is not implemented yet; calling it always raises.

    Args:
        region_to_be_split (BOUNDARY_TYPE):
            A spatial region to be split up. May be defined as a shapely object, geopandas object,
            or a path to a geospatial file. In any case, the union of all the elements will be taken.
        split_fractions (ARRAY_TYPE):
            A sequence of fractions to split the input region into. If they don't sum to 1, the total
            will be normalized.

    Returns:
        List[shapely.MultiPolygon]:
            A list of regions representing spatial splits of the input. The area of each one is
            controlled by the corresponding element in split_fractions.

    Raises:
        NotImplementedError: Always, until the function is implemented.
    """
    raise NotImplementedError()
51 |
52 |
def create_dataloader(
    raster_folder_path: PATH_TYPE,
    chip_size: float,
    chip_stride: Optional[float] = None,
    chip_overlap_percentage: Optional[float] = None,
    use_units_meters: bool = False,
    region_of_interest: Optional[BOUNDARY_TYPE] = None,
    output_resolution: Optional[float] = None,
    output_CRS: Optional[pyproj.CRS] = None,
    vector_label_folder_path: Optional[PATH_TYPE] = None,
    vector_label_attribute: Optional[str] = None,
    batch_size: int = 1,
) -> DataLoader:
    """
    Create a tiled dataloader using torchgeo. Contains raster data and optionally vector labels.

    Args:
        raster_folder_path (PATH_TYPE): Path to the folder or raster files
        chip_size (float):
            Dimension of the chip. May be pixels or meters, based on `use_units_meters`.
        chip_stride (Optional[float], optional):
            Stride of the chip. May be pixels or meters, based on `use_units_meters`. If used,
            `chip_overlap_percentage` should not be set. Defaults to None.
        chip_overlap_percentage (Optional[float], optional):
            Percent overlap of the chip from 0-100. If used, `chip_stride` should not be set.
            A non-zero value takes precedence over `chip_stride`. Defaults to None.
        use_units_meters (bool, optional):
            Use units of meters rather than pixels when interpreting the `chip_size` and `chip_stride`.
            If the raster CRS is geographic, the data is reprojected to a meters-based CRS.
            Defaults to False.
        region_of_interest (Optional[BOUNDARY_TYPE], optional):
            Only data from this spatial region will be included in the dataloader.
            NOTE: currently accepted but not used by this function. Defaults to None.
        output_resolution (Optional[float], optional):
            Spatial resolution the data in meters/pixel. If un-set, will be the resolution of the
            first raster data that is read. Defaults to None.
        output_CRS: (Optional[pyproj.CRS], optional):
            The coordinate reference system to use for the output data. If un-set, will be the CRS
            of the first tile found.
            NOTE: currently accepted but not used by this function. Defaults to None.
        vector_label_folder_path (Optional[PATH_TYPE], optional):
            A folder of geospatial vector files that will be used for the label. If un-set, the
            dataloader will not be labeled. Defaults to None.
        vector_label_attribute (Optional[str], optional):
            Attribute to read from the vector data, such as the class or instance ID. Defaults to None.
        batch_size (int, optional):
            Number of images to load in a batch. Defaults to 1.

    Returns:
        DataLoader:
            A dataloader containing tiles from the raster data and optionally corresponding labels
            from the vector data.

    Raises:
        ValueError: If neither `chip_stride` nor `chip_overlap_percentage` is provided, or if the
            resulting stride is not positive (e.g. an overlap of 100% or more).
    """

    # changes: 1. bounding box included in every sample as a df / np array
    # 2. TODO: float or uint8 images
    # match with the param dict from the model, else error out

    # Resolve the stride before touching any data so bad arguments fail fast.
    # An explicit overlap of 0% is a valid request (stride == chip size); it must not be
    # confused with "no overlap given". When both an explicit stride and a 0% overlap are
    # passed, the explicit stride is kept, as before.
    if chip_overlap_percentage is not None and (
        chip_overlap_percentage != 0 or chip_stride is None
    ):
        chip_stride = chip_size * (1 - chip_overlap_percentage / 100.0)
        logging.info(f"Calculated stride based on overlap: {chip_stride}")
    elif chip_stride is None:
        raise ValueError(
            "Either 'chip_stride' or 'chip_overlap_percentage' must be provided."
        )

    if chip_stride <= 0:
        raise ValueError(
            f"The chip stride must be positive, got {chip_stride}. "
            "Check 'chip_stride' / 'chip_overlap_percentage'."
        )

    has_labels = vector_label_folder_path is not None

    def _build_datasets(crs=None):
        """Create the raster dataset and, if labels were given, the vector dataset."""
        raster = CustomRasterDataset(
            paths=raster_folder_path, crs=crs, res=output_resolution
        )
        vector = (
            CustomVectorDataset(
                paths=vector_label_folder_path,
                crs=crs,
                res=output_resolution,
                label_name=vector_label_attribute,
            )
            if has_labels
            else None
        )
        return raster, vector

    # Image data and (optional) label data in the CRS of the source files
    raster_dataset, vector_dataset = _build_datasets()

    units = Units.CRS if use_units_meters else Units.PIXELS
    logging.info(f"Units = {units}")

    if use_units_meters and raster_dataset.crs.is_geographic:
        # Reproject the dataset to a meters-based CRS
        logging.info("Projecting to meters-based CRS...")
        lat, lon = raster_dataset.bounds[2], raster_dataset.bounds[0]

        # Return a new projected CRS value with meters units
        projected_crs = get_projected_CRS(lat, lon)

        # Type conversion to rasterio.crs
        projected_crs = rasterio.crs.CRS.from_wkt(projected_crs.to_wkt())

        # Recreate both datasets in the projected CRS. `output_resolution` is passed again
        # here; previously it was silently dropped whenever this branch was taken.
        raster_dataset, vector_dataset = _build_datasets(crs=projected_crs)

    # Combine raster and label data if labels were given. Otherwise, use the raster alone.
    final_dataset = (
        IntersectionDataset(raster_dataset, vector_dataset)
        if has_labels
        else raster_dataset
    )

    logging.info(f"Stride = {chip_stride}")

    # GridGeoSampler to get contiguous tiles
    sampler = GridGeoSampler(
        final_dataset, size=chip_size, stride=chip_stride, units=units
    )
    dataloader = DataLoader(
        final_dataset, batch_size=batch_size, sampler=sampler, collate_fn=stack_samples
    )

    return dataloader
179 |
180 |
def create_image_dataloader(
    images_dir: Union[PATH_TYPE, List[str]],
    chip_size: int,
    chip_stride: Optional[int] = None,
    chip_overlap_percentage: Optional[float] = None,
    labels_dir: Optional[List[str]] = None,
    batch_size: int = 1,
) -> DataLoader:
    """
    Create a dataloader for a folder of normal images (e.g., JPGs), tiling them into smaller patches.

    Args:
        images_dir (Union[Path, List[str]]):
            Path to the folder containing image files, or list of paths to image files.
        chip_size (int):
            Size of the tiles (width, height) in pixels.
        chip_stride (Optional[int], optional):
            Stride of the tiling (horizontal, vertical) in pixels.
        chip_overlap_percentage (Optional[float], optional):
            Percent overlap of the chip from 0-100. If used, `chip_stride` should not be set.
            A non-zero value takes precedence over `chip_stride`.
        labels_dir (Optional[List[str]], optional):
            List of paths to tree crown label files corresponding to the images.
            This will be used as ground truth during evaluation
        batch_size (int, optional):
            Number of tiles in a batch. Defaults to 1.

    Returns:
        DataLoader: A dataloader containing the tiles and associated metadata.

    Raises:
        ValueError: If neither `chip_stride` nor `chip_overlap_percentage` is provided, or if the
            resulting stride is smaller than one pixel (e.g. an overlap close to 100%).
    """

    logging.info("Units set in PIXELS")

    # An explicit overlap of 0% is a valid request (stride == chip size); it must not be
    # confused with "no overlap given". When both an explicit stride and a 0% overlap are
    # passed, the explicit stride is kept, as before.
    if chip_overlap_percentage is not None and (
        chip_overlap_percentage != 0 or chip_stride is None
    ):
        # Strides are whole pixels, so the fractional part is truncated
        chip_stride = int(chip_size * (1 - chip_overlap_percentage / 100.0))
        logging.info(f"Calculated stride based on overlap: {chip_stride}")
    elif chip_stride is None:
        raise ValueError(
            "Either 'chip_stride' or 'chip_overlap_percentage' must be provided."
        )

    # A stride below one pixel cannot advance the tiling window
    if chip_stride < 1:
        raise ValueError(
            f"The chip stride must be at least 1 pixel, got {chip_stride}. "
            "Check 'chip_stride' / 'chip_overlap_percentage'."
        )

    dataset = CustomImageDataset(
        images_dir=images_dir,
        chip_size=chip_size,
        chip_stride=chip_stride,
        labels_dir=labels_dir,
    )
    # Tiles are served in dataset order; the custom collate keeps metadata as per-item lists
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=CustomImageDataset.collate_as_defaultdict,
    )
    return dataloader
237 |
238 |
def visualize_dataloader(dataloader: DataLoader, n_tiles: int):
    """Show samples from the dataloader.

    The dataloader's sampler is expected to yield the keys used to index
    `dataloader.dataset` (the tile bounds produced by the samplers in this module).

    Args:
        dataloader (DataLoader): The dataloader to visualize.
        n_tiles (int):
            The number of randomly-sampled tiles to show. If it exceeds the number of
            available tiles, all tiles are shown.
    """
    # Get a list of all tile bounds from the sampler (materialised once and reused below)
    list_of_bboxes = list(dataloader.sampler)
    n_available = len(list_of_bboxes)

    # `random.sample` raises when asked for more items than exist, so clamp the request,
    # matching what `save_dataloader_contents` does.
    if n_tiles > n_available:
        logging.warning(
            f"Requested {n_tiles} tiles but only {n_available} are available; showing all."
        )
    n_to_show = min(n_tiles, n_available)

    # Get a random sample of index values to visualize
    tile_indices = random.sample(range(n_available), n_to_show)

    for i in tile_indices:
        sample_bbox = list_of_bboxes[i]

        # Get the referenced sample from the underlying dataset
        sample = dataloader.dataset[sample_bbox]

        # Plot the sample image.
        plot_from_dataloader(sample)
        plt.axis("off")
        plt.show()
262 |
263 |
def save_dataloader_contents(
    dataloader: DataLoader,
    save_folder: PATH_TYPE,
    n_tiles: Optional[int] = None,
    random_sample: bool = False,
):
    """Save contents of the dataloader to a folder.

    For every saved tile, `tile_<i>.png` holds the image and `tile_<i>.json` holds its CRS and
    bounds, plus the crown polygons when the dataset is an `IntersectionDataset`.

    Args:
        dataloader (DataLoader):
            Dataloader to save the contents of.
        save_folder (PATH_TYPE):
            Folder to save data to. Will be created if it doesn't exist.
        n_tiles (Optional[int], optional):
            Number of tiles to save. Whether they are the first tiles or random is controlled by
            `random_sample`. If unset, all tiles will be saved. Defaults to None.
        random_sample (bool, optional):
            If `n_tiles` is set, should the tiles be randomly sampled rather than taken from the
            beginning of the dataloader. Defaults to False.
    """
    # Function-scope import keeps this block self-contained
    from itertools import islice

    # Create save directory if it doesn't exist
    destination_folder = Path(save_folder)
    destination_folder.mkdir(parents=True, exist_ok=True)

    transform_to_pil = ToPILImage()

    # TODO: handle batch_size > 1
    def _iter_samples():
        """Yield individual samples, flattening each batch as it is produced."""
        for batch in dataloader:
            yield from unbind_samples(batch)

    if n_tiles is None:
        # Save everything, streaming one batch at a time
        selected_samples = _iter_samples()
    elif random_sample:
        # Random sampling needs the full population in memory.
        # If `n_tiles` is greater than available samples, include all samples.
        all_samples = list(_iter_samples())
        selected_samples = random.sample(all_samples, min(n_tiles, len(all_samples)))
    else:
        # Take the first `n_tiles` lazily, so the rest of the dataloader is never loaded
        selected_samples = islice(_iter_samples(), max(n_tiles, 0))

    # Whether samples carry label geometries under the "shapes" key
    has_labels = isinstance(dataloader.dataset, IntersectionDataset)

    # Counter for saved tiles
    saved_tiles_count = 0

    for sample in selected_samples:
        image = sample["image"]
        # Scale by 255 and clip to the [0, 1] range before converting to a PIL image
        image_tensor = torch.clamp(image / 255.0, min=0, max=1)
        pil_image = transform_to_pil(image_tensor)

        # Save the image tile
        pil_image.save(destination_folder / f"tile_{saved_tiles_count}.png")

        # Prepare tile metadata
        metadata = {
            "crs": sample["crs"].to_string(),
            "bounds": list(sample["bounds"]),
        }

        # If dataset includes labels, save crown metadata
        if has_labels:
            metadata["crowns"] = [
                {"ID": tree_id, "crown": polygon.wkt}
                for polygon, tree_id in sample["shapes"]
            ]

        # Save metadata to a JSON file
        with open(destination_folder / f"tile_{saved_tiles_count}.json", "w") as f:
            json.dump(metadata, f, indent=4)

        saved_tiles_count += 1

    print(f"Saved {saved_tiles_count} tiles to {save_folder}")
348 |
--------------------------------------------------------------------------------
/tree_detection_framework/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-forest-observatory/tree-detection-framework/661fded8c3cd3ac7585ec9bb5d4606778bca7494/tree_detection_framework/utils/__init__.py
--------------------------------------------------------------------------------
/tree_detection_framework/utils/benchmarking.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import xml.etree.ElementTree as ET
3 | from glob import glob
4 | from pathlib import Path
5 | from typing import List, Union
6 |
7 | import geopandas as gpd
8 | import numpy as np
9 | from shapely.geometry import box
10 | from torch.utils.data import DataLoader
11 |
12 | from tree_detection_framework.constants import PATH_TYPE
13 | from tree_detection_framework.detection.detector import Detector
14 | from tree_detection_framework.evaluation.evaluate import (
15 | compute_matched_ious,
16 | compute_precision_recall,
17 | )
18 | from tree_detection_framework.postprocessing.postprocessing import single_region_NMS
19 | from tree_detection_framework.preprocessing.preprocessing import create_image_dataloader
20 |
21 | logging.basicConfig(level=logging.INFO)
22 |
23 |
def get_neon_gt(
    images_dir: PATH_TYPE, annotations_dir: PATH_TYPE
) -> dict[str, dict[str, List[box]]]:
    """
    Extract ground truth bounding boxes from NEON XML annotations.

    Images without a matching `<image stem>.xml` annotation file are skipped. Images are
    processed in sorted order so the result does not depend on filesystem listing order.

    Args:
        images_dir (PATH_TYPE): Directory containing image tiles (`*.tif`).
        annotations_dir (PATH_TYPE): Directory containing XML annotation files.
    Returns:
        dict: A dictionary mapping image paths to a dictionary with "gt" key containing ground truth boxes.
    """
    annotations_root = Path(annotations_dir)
    corner_tags = ("xmin", "ymin", "xmax", "ymax")
    mappings = {}

    for path in sorted(Path(images_dir).glob("*.tif")):
        # The annotation shares the image's file name (without extension)
        annot_fname = annotations_root / f"{path.stem}.xml"
        if not annot_fname.exists():
            continue

        # Load XML file
        root = ET.parse(annot_fname).getroot()

        # Extract bounding boxes
        gt_boxes = []
        for obj in root.findall(".//object"):
            bndbox = obj.find("bndbox")
            if bndbox is None:
                continue
            # Parse as float so both "12" and "12.0" style coordinates are accepted
            xmin, ymin, xmax, ymax = (
                float(bndbox.find(tag).text) for tag in corner_tags
            )
            gt_boxes.append(box(xmin, ymin, xmax, ymax))

        # Add the ground truth boxes to the mappings
        mappings[str(path)] = {"gt": gt_boxes}
    return mappings
63 |
64 |
def get_detectree2_gt(dataloader, tile_size=1000) -> dict[str, dict[str, List[box]]]:
    """Extract ground truth bounding boxes from Detectree2 annotations.

    Args:
        dataloader: Dataloader whose batches expose a "metadata" list; each entry must provide
            "source_image" and "annotations" (path to the annotation file).
        tile_size (int): Height used to flip the annotations vertically. Defaults to 1000.

    Returns:
        dict: A dictionary mapping image paths to a dictionary with "gt" key containing the
            axis-aligned ground truth boxes.
    """
    mappings = {}
    for batch in dataloader:
        # Look at every item of the batch, not just the first, so batch sizes > 1 work
        for tile_metadata in batch["metadata"]:
            img_path = tile_metadata["source_image"]
            # Several tiles can come from one image; its annotations only need reading once
            if img_path in mappings:
                continue

            gt_gdf = gpd.read_file(tile_metadata["annotations"])
            # Perform a vertical flip of the annotations to match the expected format. This is
            # due to differences between the i,j and x,y indexing conventions.
            gt_gdf.geometry = gt_gdf.transform(
                lambda x: np.concatenate((x[:, 0:1], tile_size - x[:, 1:2]), axis=1)
            )

            # Convert each geometry to its axis-aligned bounding box polygon
            bounding_boxes = [box(*geom.bounds) for geom in gt_gdf.geometry]

            mappings[img_path] = {"gt": bounding_boxes}
    return mappings
82 |
83 |
def get_neon_dataloader(image_paths: List[str]) -> DataLoader:
    """Build the dataloader used for the NEON benchmark images.

    Args:
        image_paths (List[str]): Paths of the image files to load.

    Returns:
        DataLoader: Dataloader tiling the images into non-overlapping 400x400 chips.
    """
    # NEON images have a standard size of 400x400, so chip size and stride are both 400
    neon_tile_size = 400
    return create_image_dataloader(
        image_paths,
        chip_size=neon_tile_size,
        chip_stride=neon_tile_size,
    )
93 |
94 |
def get_detectree2_dataloader(
    images_dir: Union[PATH_TYPE, List[PATH_TYPE]],
    annotations_dir: Union[PATH_TYPE, List[PATH_TYPE]],
) -> DataLoader:
    """Create a Detectree2 dataloader from one or more image/annotation directories.

    Args:
        images_dir: Directory, or list/tuple of directories, containing the image files.
        annotations_dir: Directory, or list/tuple of directories, containing the annotation
            files. Annotations are matched to images by file name downstream.

    Returns:
        DataLoader: Dataloader over all images, tiled into non-overlapping 1000x1000 chips.
    """

    def _as_dirs(value):
        """Normalise a single directory or a collection of directories to a list of Paths."""
        items = value if isinstance(value, (list, tuple)) else [value]
        return [Path(item) for item in items]

    def _list_files(directories):
        """List regular files of each directory, sorted within a directory for determinism."""
        return [
            entry
            for directory in directories
            for entry in sorted(directory.glob("*"))
            # Sub-directories cannot be opened as images or annotations
            if entry.is_file()
        ]

    # Collect all image and annotation paths
    img_paths = _list_files(_as_dirs(images_dir))
    ann_paths = _list_files(_as_dirs(annotations_dir))

    # Create dataloader using all images
    dataloader = create_image_dataloader(
        images_dir=img_paths, chip_size=1000, chip_stride=1000, labels_dir=ann_paths
    )
    return dataloader
126 |
127 |
def get_benchmark_detections(
    dataset_name: str,
    images_dir: Union[PATH_TYPE, List[PATH_TYPE]],
    annotations_dir: Union[PATH_TYPE, List[PATH_TYPE]],
    detectors: dict[str, Detector],
    nms_threshold: Optional[float] = None,
    min_confidence: float = 0.5,
    nms_on_polygons: bool = False,
) -> dict[str, dict[str, List[box]]]:
    """
    Load ground truth, create dataloader, and run detectors on the images from the benchmark dataset.

    Args:
        dataset_name (str): Name of the dataset ("neon" or "detectree2").
        images_dir (PATH_TYPE, List[PATH_TYPE]): Directory or list of directories to image files.
        annotations_dir (PATH_TYPE, List[PATH_TYPE]): Directory or list of directories to annotation files.
        detectors (dict):
            Dictionary of detector instances to be evaluated. The keys select how the output is
            read and must be one of "deepforest", "sam2" or "detectree2".
        nms_threshold (float):
            Non-maximum suppression threshold. If None, NMS is skipped. Defaults to None.
        min_confidence (float):
            Minimum confidence threshold for detections, passed to NMS. It is only used when
            `nms_threshold` is set. Defaults to 0.5.
        nms_on_polygons (bool): Should NMS be run on polygons before converting them to bounding boxes. Defaults to False.
    Returns:
        dict: A dictionary mapping image paths to a dictionary with detector names and the corresponding output boxes.

    Raises:
        ValueError: If `dataset_name` or any key of `detectors` is not recognised.
    """
    # Detectors whose boxes live in a separate "bbox" column vs. directly in the geometry column
    bbox_column_detectors = ("sam2", "detectree2")
    geometry_column_detectors = ("deepforest",)

    # Validate all names before doing any loading or inference, so that a typo is reported
    # immediately rather than after the detectors have already been run.
    if dataset_name not in ("neon", "detectree2"):
        raise ValueError(f"Unknown dataset: {dataset_name}")
    for name in detectors:
        if name not in bbox_column_detectors + geometry_column_detectors:
            raise ValueError(f"Unknown detector: {name}")

    if dataset_name == "neon":
        mappings = get_neon_gt(images_dir, annotations_dir)
        dataloader = get_neon_dataloader(list(mappings.keys()))
    else:
        dataloader = get_detectree2_dataloader(images_dir, annotations_dir)
        mappings = get_detectree2_gt(dataloader)

    for name, detector in detectors.items():
        # Get predictions from every detector
        logging.info(f"Running detector: {name}")
        region_detection_sets, filenames, _ = detector.predict_raw_drone_images(
            dataloader
        )

        # Add predictions to the mappings so that it looks like:
        # {"image_path_1": {"gt": gt_boxes, "detector_name_1": [boxes], ...},
        #  "image_path_2": {"gt": gt_boxes, "detector_name_1": [boxes], ...}, ...}
        for filename, rds in zip(filenames, region_detection_sets):
            if nms_threshold is not None:

                # If we don't want to do polygon NMS and if the detection was originally a polygon,
                # instead use the data in the "bbox" column which specifies an axis-aligned bounding box.
                if not nms_on_polygons and name in bbox_column_detectors:
                    rds = rds.update_geometry_column("bbox")

                rds = single_region_NMS(
                    rds.get_region_detections(0),
                    threshold=nms_threshold,
                    min_confidence=min_confidence,
                )
            gdf = rds.get_data_frame()

            # Add the detections to the mappings dictionary
            if name in geometry_column_detectors:
                mappings[filename][name] = list(gdf.geometry)
            else:
                mappings[filename][name] = list(gdf["bbox"])

    return mappings
195 |
196 |
def evaluate_detections(detections_dict: dict[str, dict[str, List[box]]]):
    """Step 2: Compute precision and recall for each detector.

    For every detector, precision and recall are computed per image and then averaged
    over all images. The F1-score is derived from those two averages. The results are
    printed, nothing is returned.

    Args:
        detections_dict (dict): Dictionary mapping image paths to a dictionary
            with detector names and the corresponding output boxes. Each sub-dictionary
            must also contain a "gt" entry with the ground truth boxes for that image.
    """
    # Nothing to evaluate: avoid an IndexError when looking up the first image below.
    if not detections_dict:
        logging.warning("No detections were provided, nothing to evaluate.")
        return

    img_paths = list(detections_dict.keys())
    # Get the list of detectors, which are keys of the sub-dictionary.
    # The first image is taken as representative of all images.
    detector_names = [
        key for key in detections_dict[img_paths[0]].keys() if key != "gt"
    ]
    logging.info(f"Detectors to be evaluated: {detector_names}")
    for detector in detector_names:
        all_predictions_P = []
        all_predictions_R = []
        for img in img_paths:
            gt_boxes = detections_dict[img]["gt"]
            pred_boxes = detections_dict[img][detector]
            iou_output = compute_matched_ious(gt_boxes, pred_boxes)
            P, R = compute_precision_recall(iou_output, len(gt_boxes), len(pred_boxes))
            all_predictions_P.append(P)
            all_predictions_R.append(R)

        # Average the per-image scores
        P = np.mean(all_predictions_P)
        R = np.mean(all_predictions_R)
        # F1 is undefined (0 / 0) when both precision and recall are zero. Report 0 in
        # that case instead of propagating a NaN and a numpy runtime warning.
        if (P + R) == 0:
            F1 = 0.0
        else:
            F1 = (2 * P * R) / (P + R)
        print(f"'{detector}': Precision={P}, Recall={R}, F1-Score={F1}")
224 |
--------------------------------------------------------------------------------
/tree_detection_framework/utils/detection.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import urllib
4 |
5 | import pandas as pd
6 | from deepforest.utilities import DownloadProgressBar
7 |
8 |
def use_release_df(
    save_dir=os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../../data/"),
    prebuilt_model="NEON",
    check_release=True,
):
    """
    Check the existence of, or download the latest model release from github
    Args:
        save_dir: Directory to save the model and the release record ("current_release.csv") in.
            NOTE(review): the default resolves three levels above this file, confirm that this
            is the intended data folder.
        prebuilt_model: Currently only accepts "NEON", but could be expanded to include other prebuilt models. The local model will be called <prebuilt_model>.pt on disk.
        check_release (logical): whether to check github for a recent model release. In cases where you are hitting the github API rate limit, set to False and any local model will be used. If no model has been downloaded an error will raise.

    Returns: release_tag, output_path (str): the URL of the release and the path to the downloaded model

    Raises:
        ValueError: If check_release is False and no release record exists in save_dir.
    """
    # A bare "import urllib" does not guarantee that the "request" submodule is loaded
    import urllib.request

    os.makedirs(save_dir, exist_ok=True)

    # Naming based on pre-built model
    output_path = os.path.join(save_dir, prebuilt_model + ".pt")
    # Build the path with join so that save_dir works with or without a trailing separator
    release_record_path = os.path.join(save_dir, "current_release.csv")

    if check_release:
        # Find latest github tag release from the DeepForest repo
        _json = json.loads(
            urllib.request.urlopen(
                urllib.request.Request(
                    "https://api.github.com/repos/Weecology/DeepForest/releases/latest",
                    headers={"Accept": "application/vnd.github.v3+json"},
                )
            ).read()
        )
        # The first asset of the release is taken to be the model
        asset = _json["assets"][0]
        url = asset["browser_download_url"]

        # Check the release tagged locally
        try:
            release_txt = pd.read_csv(release_record_path)
        except Exception:
            # No (readable) record means nothing has been downloaded so far
            release_txt = pd.DataFrame({"current_release": [None]})

        # Download the current release if it doesn't exist locally
        if not release_txt.current_release[0] == _json["html_url"]:

            print(
                "Downloading model from DeepForest release {}, see {} "
                "for details".format(_json["tag_name"], _json["html_url"])
            )

            with DownloadProgressBar(
                unit="B", unit_scale=True, miniters=1, desc=url.split("/")[-1]
            ) as t:
                urllib.request.urlretrieve(
                    url, filename=output_path, reporthook=t.update_to
                )

            print("Model was downloaded and saved to {}".format(output_path))

            # record the release tag locally
            release_txt = pd.DataFrame({"current_release": [_json["html_url"]]})
            release_txt.to_csv(release_record_path)
        else:
            print(
                "Model from DeepForest release {} was already downloaded. "
                "Loading model from file.".format(_json["html_url"])
            )

        return _json["html_url"], output_path
    else:
        # Without contacting github, rely on the record written by a previous download
        try:
            release_txt = pd.read_csv(release_record_path)
        except Exception as err:
            raise ValueError(
                "Check release argument is {}, but no release "
                "has been previously downloaded".format(check_release)
            ) from err

        return release_txt.current_release[0], output_path
85 |
--------------------------------------------------------------------------------
/tree_detection_framework/utils/geometric.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import shapely
4 | from contourpy import contour_generator
5 |
6 |
def get_shapely_transform_from_matrix(matrix_transform: np.ndarray):
    """
    Flatten a 2D transformation matrix into the coefficient list used by shapely's affine
    transform: [a, b, d, e, xoff, yoff]

    Args:
        matrix_transform (np.ndarray):
            (2, 3) or (3, 3) 2D transformation matrix such that the matrix-vector product produces
            the transformed value. Only the first two rows are read.

    Returns:
        list: The six coefficients, the 2x2 linear part in row-major order followed by the
            x and y offsets taken from the last column.
    """
    # (row, column) of each coefficient, listed in the order shapely expects them
    coefficient_locations = (
        (0, 0),  # a
        (0, 1),  # b
        (1, 0),  # d
        (1, 1),  # e
        (0, 2),  # xoff
        (1, 2),  # yoff
    )
    return [matrix_transform[row, col] for row, col in coefficient_locations]
25 |
26 |
def mask_to_shapely(
    mask: np.ndarray, simplify_tolerance: float = 0, backend: str = "contourpy"
) -> shapely.MultiPolygon:
    """
    Vectorize the positive regions of a binary mask into a shapely geometry, optionally
    simplifying the result.

    Args:
        mask (np.ndarray): A (n, m) mask with boolean values.
        simplify_tolerance (float): Tolerance for simplifying polygons. A value of 0 means no simplification.
        backend (str): The backend to use for contour extraction. Choose from "cv2" and "contourpy". Defaults to contourpy.

    Returns:
        shapely.MultiPolygon: The geometry representing the positive regions. Note that an
            empty shapely.Polygon is returned when no element of the mask is set, and that the
            "contourpy" backend returns whatever shapely.unary_union produces.

    Raises:
        ValueError: If the backend is not one of the supported options. This is only checked
            for masks that have at least one element set.
    """
    # A mask without any positive element short-circuits to an empty geometry
    if not np.any(mask):
        return shapely.Polygon()

    if backend == "cv2":
        # Extract the contours with OpenCV, which requires an integer image
        found_contours, _ = cv2.findContours(
            mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
        )

        parts = []
        for raw_contour in found_contours:
            vertices = np.squeeze(raw_contour)
            # Only contours that are a list of at least three 2D points can form a polygon
            if vertices.ndim == 2 and vertices.shape[0] >= 3:
                candidate = shapely.Polygon(vertices)

                if isinstance(candidate, shapely.MultiPolygon):
                    # Keep each of the individual polygons
                    parts.extend(candidate.geoms)
                elif isinstance(candidate, shapely.Polygon):
                    parts.append(candidate)

        # Merge everything that was found into one geometry
        geometry = shapely.MultiPolygon(parts)

    elif backend == "contourpy":
        # Fill the region between 0.5 and infinity, which selects the True elements
        generator = contour_generator(z=mask, fill_type="ChunkCombinedOffsetOffset")
        filled = generator.filled(0.5, np.inf)

        # One set of polygons is built per chunk from its points and offsets
        chunk_polygons = []
        for points, offsets, outer_offsets in zip(*filled):
            chunk_polygons.append(
                shapely.from_ragged_array(
                    shapely.GeometryType.POLYGON, points, (offsets, outer_offsets)
                )
            )

        geometry = shapely.unary_union(chunk_polygons)

    else:
        raise ValueError(
            f"Unsupported backend: {backend}. Choose 'cv2' or 'contourpy'."
        )

    # The optional simplification is shared by both backends
    if simplify_tolerance > 0:
        geometry = geometry.simplify(simplify_tolerance)

    return geometry
100 |
--------------------------------------------------------------------------------
/tree_detection_framework/utils/geospatial.py:
--------------------------------------------------------------------------------
1 | import pyproj
2 |
3 |
def get_projected_CRS(
    lat: float, lon: float, assume_western_hem: bool = True
) -> pyproj.CRS:
    """
    Returns a projected Coordinate Reference System (CRS) based on latitude and longitude.

    The CRS is the WGS84 UTM zone containing the point, which is EPSG:326XX in the northern
    hemisphere and EPSG:327XX in the southern hemisphere, with XX being the zone number. The
    special zones around Norway and Svalbard are not considered.

    Args:
        lat (float): Latitude in degrees.
        lon (float): Longitude in degrees.
        assume_western_hem (bool): Assumes the longitude is in the Western Hemisphere, so a
            positive longitude is negated. Defaults to True.

    Returns:
        pyproj.CRS: The projected CRS corresponding to the UTM zone for the given latitude and longitude.
    """

    if assume_western_hem and lon > 0:
        lon = -lon

    # Zones are 6 degrees wide and numbered from 1, starting at 180 degrees west. The floor
    # division is used on purpose: round() rounds halves to the nearest even number, which
    # assigns points exactly on a zone boundary inconsistently and maps -180 to the
    # non-existent zone 0. The modulo wraps the longitude into [0, 360).
    utm_zone = int(((lon + 180) % 360) // 6) + 1
    # Guard against floating point round-off in the modulo producing zone 61
    utm_zone = min(utm_zone, 60)

    # The equator is assigned to the northern hemisphere
    hemisphere_base = 32600 if lat >= 0 else 32700

    epsg_code = hemisphere_base + utm_zone
    crs = pyproj.CRS.from_epsg(epsg_code)
    return crs
24 |
--------------------------------------------------------------------------------
/tree_detection_framework/utils/raster.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 | from typing import Optional
3 |
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import pyproj
7 | import rasterio
8 | import rasterio as rio
9 | import rasterio.plot
10 | from rasterio.warp import Resampling, calculate_default_transform, reproject
11 |
12 | from tree_detection_framework.constants import PATH_TYPE
13 |
14 |
15 | # Copied from https://github.com/open-forest-observatory/geograypher/blob/2900ede9a00ac8bdce22c43e4abb6d74876390f6/geograypher/utils/geospatial.py#L333
def load_downsampled_raster_data(dataset_filename: PATH_TYPE, downsample_factor: float):
    """Read a raster from disk at a reduced spatial resolution

    Args:
        dataset_filename (PATH_TYPE): Path to the raster
        downsample_factor (float): Downsample factor of 10 means that pixels are 10 times larger

    Returns:
        np.array: The downsampled array in the rasterio (c, h, w) convention
        rio.DatasetReader: The reader for the file. It is returned open and is not modified.
        rio.Transform: The transform of the raster scaled to match the downsampled array
    """
    # Opening the file only creates a handle, the pixels are not loaded yet
    reader = rio.open(dataset_filename)

    # Size of the output, truncated to whole pixels
    target_height = int(reader.height / downsample_factor)
    target_width = int(reader.width / downsample_factor)

    # Let rasterio resample to the requested shape while reading
    pixels = reader.read(
        out_shape=(reader.count, target_height, target_width),
        resampling=rio.enums.Resampling.bilinear,
    )

    # The scale is derived from the shape that was actually read rather than from the
    # requested factor, since the output size was truncated to integers
    x_scale = reader.width / pixels.shape[-1]
    y_scale = reader.height / pixels.shape[-2]
    scaled_transform = reader.transform * reader.transform.scale(x_scale, y_scale)

    return pixels, reader, scaled_transform
47 |
48 |
def reproject_raster(
    input_file: PATH_TYPE, output_file: PATH_TYPE, dst_crs: PATH_TYPE
) -> PATH_TYPE:
    """Reproject a raster to another CRS, writing the result to a new file

    If the raster is already in the requested CRS, nothing is written and the path of the
    input file is returned instead.

    Args:
        input_file (PATH_TYPE): Path to the raster to reproject
        output_file (PATH_TYPE): Path to write the reprojected raster to
        dst_crs (PATH_TYPE): The CRS to reproject to. NOTE(review): this is compared with the
            CRS of the raster and handed to rasterio as a CRS, so the annotation looks wrong.

    Returns:
        PATH_TYPE: The path of the file that holds the data in the requested CRS. This is
            `output_file` if a reprojection took place and `input_file` otherwise.
    """
    # Follows https://rasterio.readthedocs.io/en/latest/topics/reproject.html
    with rasterio.open(input_file, "r") as source:
        # Nothing to do if the data is already in the requested CRS
        if dst_crs == source.crs:
            return input_file

        # Determine the transform and the size of the raster in the new CRS
        new_transform, new_width, new_height = calculate_default_transform(
            source.crs, dst_crs, source.width, source.height, *source.bounds
        )

        # The output shares the metadata of the input, except for the georeferencing
        output_meta = source.meta.copy()
        output_meta.update(
            crs=dst_crs, transform=new_transform, width=new_width, height=new_height
        )

        with rasterio.open(output_file, "w", **output_meta) as destination:
            # Bands are reprojected one at a time. Note that rasterio bands are 1-indexed.
            for band_offset in range(source.count):
                band_number = band_offset + 1
                reproject(
                    source=rasterio.band(source, band_number),
                    destination=rasterio.band(destination, band_number),
                    src_transform=source.transform,
                    src_crs=source.crs,
                    dst_transform=new_transform,
                    dst_crs=dst_crs,
                    resampling=Resampling.nearest,
                )

    # Hand back the location the reprojected data was written to
    return output_file
94 |
95 |
def show_raster(
    raster_file_path: PATH_TYPE,
    downsample_factor: float = 10.0,
    plt_ax: Optional[plt.axes] = None,
    CRS: Optional[pyproj.CRS] = None,
):
    """Show a raster, optionally downsampling or reprojecting it

    Args:
        raster_file_path (PATH_TYPE):
            Path to the raster file
        downsample_factor (float):
            How much to downsample the raster before visualization, this makes it faster and consume
            less memory
        plt_ax (Optional[plt.axes], optional):
            Axes to plot on, otherwise the current ones are used. Defaults to None.
        CRS (Optional[pyproj.CRS], optional):
            The CRS to reproject the data to if set. Ignored with a warning if the raster is
            not georeferenced. Defaults to None.
    """
    # Check if the file is georeferenced
    with rasterio.open(raster_file_path) as src:
        crs_available = src.crs is not None

    # Handle cases where no CRS is available
    if not crs_available and CRS is not None:
        print("Warning: No CRS found in the raster. Proceeding without reprojection.")
        CRS = None

    temp_output_file = None
    try:
        # If the CRS is set, ensure the data matches it
        if CRS is not None:
            # Create a temporary file to write to. It is deleted when it is closed.
            temp_output_file = tempfile.NamedTemporaryFile(suffix=".tif")
            # Reproject the raster. If the CRS was the same as requested, the original raster path will
            # be returned. Otherwise, the reprojected raster will be written to the temp file and that
            # path will be returned.
            raster_file_path = reproject_raster(
                input_file=raster_file_path,
                output_file=temp_output_file.name,
                dst_crs=CRS,
            )
        # Load the downsampled image
        img, dataset, transform = load_downsampled_raster_data(
            raster_file_path, downsample_factor=downsample_factor
        )
        # The pixels are in memory now, so release the file handle instead of leaking it
        dataset.close()
    finally:
        # Remove the temporary file deterministically, even if loading failed
        if temp_output_file is not None:
            temp_output_file.close()

    # Plot the image
    rio.plot.show(source=img, transform=transform, ax=plt_ax)
145 |
146 |
def plot_from_dataloader(sample):
    """
    Draws the image of one dataset sample on a new figure.

    Args:
        sample (dict): A dictionary containing the tile to plot. The 'image' key should have a tensor of shape (C, H, W).

    Returns:
        matplotlib.figure.Figure: A figure containing the plotted image.

    Raises:
        ValueError: If the image does not have 1, 3 or 4 channels.
    """
    # Move the channels last, (C, H, W) -> (H, W, C), and convert to numpy
    pixels = sample["image"].permute(1, 2, 0).numpy()

    # The figure is created up front and handed back to the caller at the end
    fig, ax = plt.subplots()

    # The number of channels decides how the data is displayed
    n_channels = pixels.shape[2]
    if n_channels not in (1, 3, 4):
        raise ValueError(f"Cannot plot image with {n_channels} channels")

    if n_channels == 1:
        # Scalar data: use the default matplotlib colormap and add a colorbar
        mappable = ax.imshow(pixels)
        plt.colorbar(mappable, ax=ax)
    else:
        # RGB(A) data. Values above 1 in a non-uint8 array are taken to mean that the data
        # is in the 0-255 range, so it is cast to uint8 to be displayed properly.
        if pixels.dtype != np.uint8 and np.max(pixels) > 1:
            pixels = pixels.astype(np.uint8)
        ax.imshow(pixels)

    return fig
184 |
--------------------------------------------------------------------------------