├── .gitignore ├── LICENSE ├── README.md ├── figures ├── confusion_matrix.png ├── confusion_matrix_simplified.png ├── naip_example.png ├── naip_masks_example.png ├── osm_logo.png ├── rsc_diagram.drawio ├── rsc_diagram.drawio.png ├── rsc_road_small.svg └── samples │ ├── correct │ ├── paved │ │ ├── sample_00000.png │ │ ├── sample_00001.png │ │ ├── sample_00002.png │ │ ├── sample_00003.png │ │ └── sample_00004.png │ └── unpaved │ │ ├── sample_00000.png │ │ ├── sample_00001.png │ │ ├── sample_00002.png │ │ ├── sample_00003.png │ │ └── sample_00004.png │ └── incorrect │ ├── paved │ ├── sample_00000.png │ ├── sample_00001.png │ ├── sample_00002.png │ ├── sample_00003.png │ └── sample_00004.png │ └── unpaved │ ├── sample_00000.png │ ├── sample_00001.png │ ├── sample_00002.png │ ├── sample_00003.png │ └── sample_00004.png ├── notebooks ├── 00_create_naip_on_aws_gpkg.ipynb ├── 01_explore_osm_surface.ipynb ├── 02_data_prep.ipynb ├── 03_dataset_to_gpkg.ipynb ├── 04_road_color_analysis.ipynb ├── 05_segmentation.ipynb ├── 06_hierarchical_loss.ipynb └── 07_model_calibration.ipynb ├── pyproject.toml ├── requirements.in ├── requirements.txt ├── rsc ├── artifacts │ ├── __init__.py │ ├── __main__.py │ ├── accuracy_obsc_handler.py │ ├── auc_handler.py │ ├── base.py │ ├── confusion_matrix_handler.py │ ├── obsc_compare_handler.py │ └── samples_handler.py ├── common │ ├── __init__.py │ ├── aws_naip.py │ ├── geometric_median.py │ └── utils.py ├── inference │ ├── __init__.py │ ├── fetch.py │ └── mass_inference_dataset.py ├── osm │ ├── README.md │ ├── __init__.py │ ├── osm_element.py │ ├── osm_element_factory.py │ ├── osm_network.py │ ├── osm_overpass_api.py │ └── overpass_api │ │ ├── __init__.py │ │ ├── osm_overpass_api.py │ │ └── road_network.py └── train │ ├── __init__.py │ ├── color_jitter_nohuesat.py │ ├── data_augmentation.py │ ├── dataset.py │ ├── mcnn.py │ ├── mcnn_loss.py │ ├── patch.py │ ├── plmcnn.py │ ├── preprocess.py │ └── rsc_hxe_loss.py └── scripts ├── evaluate.py ├── parse_features.py ├── perform_query.py ├── pre_inference.py ├── run_overpass_api.sh ├── sample_augmentations.py ├── train.py └── train_optuna.py /.gitignore: -------------------------------------------------------------------------------- 1 | .style.yapf 2 | *_wip\.py 3 | cache/ 4 | .vscode 5 | .devcontainer 6 | *\.pyc 7 | data/ 8 | *.code-workspace 9 | *.tar.gz 10 | *.npz 11 | build/ 12 | *.egg-info 13 | venv/ 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Jonathan Dalrymple 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OpenStreetMap Road Surface Classifier 2 | 3 | ![OSM Logo](figures/osm_logo.png) 4 | 5 | This project leverages machine learning (ML) to tag "drivable ways" (roads) in OpenStreetMap **with > 95% accuracy**. The main focus at the moment is automated tagging of road surface type (paved vs. unpaved; though skill is also shown for asphalt, concrete, gravel, etc.), with other helpful tags such as number of lanes to be added in the future. 6 | 7 | > :warning: **Much of this code is under active development. Breaking changes are to be expected.** 8 | 9 | ## Table of Contents 10 | - [Why Should I Care...](#why-should-i-care) 11 | - [Integration with the OSM Community](#integration-with-the-osm-community) 12 | - [Python Environment](#python-environment) 13 | - [Training Dataset](#training-dataset) 14 | - [Model Architecture](#model-architecture) 15 | - [Model Training](#model-training) 16 | - [Training Results](#training-results) 17 | - [Credits](#credits) 18 | 19 | ## Why Should I Care... 20 | 21 | ### ... About OpenStreetMap? 22 | 23 | Most geospatial data, like those provided in Google Maps, are locked behind licensing and paywalls. Maps should not be proprietary. OpenStreetMap aims to change this by providing open geospatial data for everyone. [See more here](https://wiki.openstreetmap.org/wiki/FAQ#Why_OpenStreetMap?). 24 | 25 | ### ... About this Project? 26 | 27 | Road surface type is critical for routing applications to generate useful routes. For example, nominal driving speeds are much slower on unpaved roads than on paved roads. For bicycles, unpaved routes may need to be avoided completely. In any case, lacking knowledge of road surface type can lead any OSM-based routing application to choose suboptimal routes [if the assumed default surface values are incorrect](https://wiki.openstreetmap.org/wiki/Key:surface#Default_values). Widespread labeling of road surface types can increase overall confidence in OSM-based routers as a viable routing solution for cars and bicycles alike. 28 | 29 | ## Integration with the OSM Community 30 | 31 | I foresee the following possible integrations with the OpenStreetMap community: 32 | - Provide a dataset that can augment OSM routing solutions, such as [Project OSRM](https://project-osrm.org/), [Valhalla](https://github.com/valhalla/valhalla), and [cycle.travel](https://cycle.travel/). 33 | - Integrate the above dataset into editors such as [JOSM](https://josm.openstreetmap.de/) or [Rapid](https://rapideditor.org/). I'm prototyping a plugin for JOSM at the moment. 34 | 35 | ## Python Environment 36 | I recommend [pip-tools](https://github.com/jazzband/pip-tools) to manage the environment for this project. The gist of setting up a Python environment for this repo: 37 | ```bash 38 | $ cd /.../road_surface_classifier # cd into project directory 39 | $ python3 -m venv ./venv # create Python virtual env 40 | $ . venv/bin/activate # activate env 41 | (venv) $ python3 -m pip install pip-tools # install pip-tools 42 | (venv) $ pip-sync # automatically installs dependencies!
43 | ``` 44 | 45 | ## Training Dataset 46 | The dataset used in this project was prepared by the process outlined in the [project notebooks](./notebooks). It is the aggregation of OpenStreetMap data w/ [National Agriculture Imagery Program (NAIP)](https://www.usgs.gov/centers/eros/science/usgs-eros-archive-aerial-photography-national-agriculture-imagery-program-naip) imagery, which is public domain. I additionally have a routine that generates pseudo-truth segmentation masks which the model learns to predict. 47 | 48 | The dataset is currently generated automatically, but in some cases the roads are partially or fully hidden by vegetation (though the model is explicitly trained to predict this obscuration). Also of note: there is no guarantee the labels set in OSM are correct, so we must trust that the vast majority of them are. A big source of confusion, for example, is asphalt vs. concrete; it would not surprise me if there are many mislabeled examples of these within OSM. 49 | 50 | 51 | > :heavy_exclamation_mark: NAIP imagery is limited to the United States. While there are other public domain imagery sources that can be used, none have global coverage. 52 | 53 |
54 | ![NAIP Imagery Example](figures/naip_example.png) 55 | *Examples of NAIP imagery over roads. These are Web Mercator tiles at zoom level 16 (~2.3 meters per pixel). However, the model is trained on raw NAIP data, which includes NIR and is 1 meter per pixel. (source: USGS National Agriculture Imagery Program)* 56 |
57 | 58 | To support the MaskCNN architecture (_see below_), binary masks were also generated in order to tell the model "where to look": 59 | 60 |
61 | ![NAIP Imagery + Masks Example](figures/naip_masks_example.png) 62 | *Examples of NAIP imagery over roads with OSM labels (paved vs. unpaved) and generated binary masks from OSM data. (source: USGS National Agriculture Imagery Program [imagery]; OpenStreetMap [labels])* 63 |
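For the curious, the mask-generation idea can be sketched with GDAL/OGR. This is an illustrative sketch only, not the repo's actual routine (see the data prep notebook); the function name, buffer width, and chip size are assumptions:

```python
# Sketch: burn a buffered OSM way (road centerline) into a binary mask
# aligned with an image chip. Names and the buffer width are assumptions.
import numpy as np
from osgeo import gdal, ogr, osr


def rasterize_road_mask(way_wkt: str, geotransform, srs_wkt: str,
                        size: int = 256, buffer_m: float = 8.0) -> np.ndarray:
    srs = osr.SpatialReference()
    srs.ImportFromWkt(srs_wkt)  # assumes a projected SRS, so the buffer is in meters

    # In-memory vector layer holding the buffered road centerline
    vec_ds = ogr.GetDriverByName('Memory').CreateDataSource('')
    layer = vec_ds.CreateLayer('road', srs=srs, geom_type=ogr.wkbPolygon)
    feat = ogr.Feature(layer.GetLayerDefn())
    feat.SetGeometry(ogr.CreateGeometryFromWkt(way_wkt).Buffer(buffer_m))
    layer.CreateFeature(feat)

    # In-memory raster aligned with the image chip; rasterize the buffer into it
    ras_ds = gdal.GetDriverByName('MEM').Create('', size, size, 1, gdal.GDT_Byte)
    ras_ds.SetGeoTransform(geotransform)
    ras_ds.SetProjection(srs_wkt)
    gdal.RasterizeLayer(ras_ds, [1], layer, burn_values=[255])
    return ras_ds.ReadAsArray()  # uint8 mask: 255 on-road, 0 elsewhere
```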
64 | 65 | ## Model Architecture 66 | 67 | I'm currently using a MaskCNN model largely based on [Liu et al.: _Masked convolutional neural network for supervised learning problems_](https://par.nsf.gov/servlets/purl/10183705). 68 | - Instead of multiplication, I concatenate the predicted mask into the classifier backbone. 69 | - I'm using a ResNet-18 backbone for the encoder, decoder, and classifier alike. 70 | - By using such a small encoder, the model can even run inference on a CPU! (A minimal code sketch of this idea follows the diagram below.)
72 | ![MaskCNN architecture diagram](figures/rsc_diagram.drawio.png) 73 | *Quick diagram of the MaskCNN architecture used here. The NAIP imagery is combined with a mask created from OSM vector data, which in turn is used to generate the predicted segmentation mask. The image and segmentation mask are then fed into the classifier model.* 74 |
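To make the bullets above concrete, here is a minimal PyTorch sketch of the concatenation idea. The names and channel counts are hypothetical; the actual implementation lives in [`rsc/train`](./rsc/train):

```python
# Sketch of a MaskCNN-style forward pass: predict a road mask, then
# concatenate it onto the input (rather than multiplying, as in the paper)
# before classification. Channel count is an assumption (RGB + NIR + OSM mask).
import torch
import torch.nn as nn
from torchvision.models import resnet18


class MaskCNNSketch(nn.Module):

    def __init__(self, in_ch: int = 5, num_classes: int = 2):
        super().__init__()
        # Stand-in encoder-decoder that predicts a 1-channel road mask
        self.seg = nn.Sequential(nn.Conv2d(in_ch, 16, 3, padding=1),
                                 nn.ReLU(), nn.Conv2d(16, 1, 1))
        # ResNet-18 classifier; first conv widened to accept image + predicted mask
        self.cls = resnet18(num_classes=num_classes)
        self.cls.conv1 = nn.Conv2d(in_ch + 1, 64, 7, stride=2, padding=3, bias=False)

    def forward(self, x):
        mask = torch.sigmoid(self.seg(x))  # predicted segmentation mask
        z = torch.cat([x, mask], dim=1)    # concatenate, don't multiply
        return self.cls(z), mask           # class logits + mask (for its own loss term)
```

The predicted mask gets its own loss term during training; keeping that term's weight modest is exactly the "trick" described below.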
75 | 76 | The benefit of this model over a plain ResNet is the ability to tell the model what the mask should look like. This tells the classifier "where to look" (i.e. I care about _this_ road in the image, not _that_ one). 77 | 78 | The trick is to not force the appearance of this mask too much, because then (1) the model stops looking outside the mask after the concatenation step and (2) the model will care more about the mask than about the classification result! 79 | 80 | ## Model Training 81 | 82 | Training is currently done w/ [PyTorch Lightning](https://www.pytorchlightning.ai/), see [`train.py`](./scripts/train.py). 83 | 84 | > :heavy_exclamation_mark: I don't recommend training this model without a dedicated GPU configured with CUDA. I know some people use [Google Colab](https://colab.research.google.com/), but I'm not familiar with it. 85 | 86 | ## Training Results 87 | 88 | **To read the confusion matrices below**: for each box, read "when the true label is X, the model predicts Y __% of the time". (A short sketch of this row normalization follows the first matrix.) 89 | 90 | ### Paved vs. Unpaved 91 | 92 |
93 | ![Confusion matrix for paved vs. unpaved model](figures/confusion_matrix_simplified.png) 94 | *Not bad! The model gets each category right over 95% of the time.* 95 |
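For reference, the row normalization described above is easy to reproduce with scikit-learn. A toy sketch (not the repo's evaluation code; see [`rsc/artifacts`](./rsc/artifacts)):

```python
# Toy example of a row-normalized confusion matrix:
# cm[x, y] = fraction of true-label-x samples that the model predicts as y.
import numpy as np
from sklearn.metrics import confusion_matrix

y_true = np.array([0, 0, 0, 1, 1, 1])  # e.g. 0 = paved, 1 = unpaved
y_pred = np.array([0, 0, 1, 1, 1, 1])
cm = confusion_matrix(y_true, y_pred, normalize='true')
print(cm)  # each row sums to 1.0
```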
96 | 97 | ### Multiclass 98 |
99 | ![Confusion matrix for full multiclass model](figures/confusion_matrix.png) 100 | *Given the imagery resolution, the frequent obscuration by vegetation, and the often-incorrect truth labels, this is impressive. The model clearly shows skill in predicting a wider range of classes than just paved vs. unpaved.* 101 |
102 | 103 | ## Credits 104 | 105 | #### MaskCNN Paper 106 | Liu, L. Y. F., Liu, Y., & Zhu, H. (2020). Masked convolutional neural network for supervised learning problems. Stat, 9(1), e290. 107 | 108 | ## License 109 | [MIT](https://choosealicense.com/licenses/mit/) © 2024 Jonathan Dalrymple -------------------------------------------------------------------------------- /figures/confusion_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/confusion_matrix.png -------------------------------------------------------------------------------- /figures/confusion_matrix_simplified.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/confusion_matrix_simplified.png -------------------------------------------------------------------------------- /figures/naip_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/naip_example.png -------------------------------------------------------------------------------- /figures/naip_masks_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/naip_masks_example.png -------------------------------------------------------------------------------- /figures/osm_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/osm_logo.png -------------------------------------------------------------------------------- /figures/rsc_diagram.drawio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/rsc_diagram.drawio.png -------------------------------------------------------------------------------- /figures/rsc_road_small.svg: -------------------------------------------------------------------------------- [SVG markup not preserved in this dump] -------------------------------------------------------------------------------- /figures/samples/correct/paved/sample_00000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/correct/paved/sample_00000.png -------------------------------------------------------------------------------- /figures/samples/correct/paved/sample_00001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/correct/paved/sample_00001.png -------------------------------------------------------------------------------- /figures/samples/correct/paved/sample_00002.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/correct/paved/sample_00002.png -------------------------------------------------------------------------------- /figures/samples/correct/paved/sample_00003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/correct/paved/sample_00003.png -------------------------------------------------------------------------------- /figures/samples/correct/paved/sample_00004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/correct/paved/sample_00004.png -------------------------------------------------------------------------------- /figures/samples/correct/unpaved/sample_00000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/correct/unpaved/sample_00000.png -------------------------------------------------------------------------------- /figures/samples/correct/unpaved/sample_00001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/correct/unpaved/sample_00001.png -------------------------------------------------------------------------------- /figures/samples/correct/unpaved/sample_00002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/correct/unpaved/sample_00002.png -------------------------------------------------------------------------------- /figures/samples/correct/unpaved/sample_00003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/correct/unpaved/sample_00003.png -------------------------------------------------------------------------------- /figures/samples/correct/unpaved/sample_00004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/correct/unpaved/sample_00004.png -------------------------------------------------------------------------------- /figures/samples/incorrect/paved/sample_00000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/incorrect/paved/sample_00000.png -------------------------------------------------------------------------------- /figures/samples/incorrect/paved/sample_00001.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/incorrect/paved/sample_00001.png -------------------------------------------------------------------------------- /figures/samples/incorrect/paved/sample_00002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/incorrect/paved/sample_00002.png -------------------------------------------------------------------------------- /figures/samples/incorrect/paved/sample_00003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/incorrect/paved/sample_00003.png -------------------------------------------------------------------------------- /figures/samples/incorrect/paved/sample_00004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/incorrect/paved/sample_00004.png -------------------------------------------------------------------------------- /figures/samples/incorrect/unpaved/sample_00000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/incorrect/unpaved/sample_00000.png -------------------------------------------------------------------------------- /figures/samples/incorrect/unpaved/sample_00001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/incorrect/unpaved/sample_00001.png -------------------------------------------------------------------------------- /figures/samples/incorrect/unpaved/sample_00002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/incorrect/unpaved/sample_00002.png -------------------------------------------------------------------------------- /figures/samples/incorrect/unpaved/sample_00003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/incorrect/unpaved/sample_00003.png -------------------------------------------------------------------------------- /figures/samples/incorrect/unpaved/sample_00004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/incorrect/unpaved/sample_00004.png -------------------------------------------------------------------------------- /notebooks/00_create_naip_on_aws_gpkg.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "### NAIP On AWS\n", 9 | "\n", 10 | "This Jupyter notebook provides code to scrape the 
[NAIP on AWS](https://registry.opendata.aws/naip/) manifest and create a GPKG file that provides a geospatial footprint for all available imagery.\n", 11 | "\n", 12 | "Then, if we are interested in a given OSM ID, it's easy to look up which NAIP images intersect this ID and download them directly." 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "metadata": {}, 19 | "outputs": [ 20 | { 21 | "name": "stderr", 22 | "output_type": "stream", 23 | "text": [ 24 | "100%|██████████| 1150/1150 [00:00<00:00, 125232.36it/s]\n" 25 | ] 26 | } 27 | ], 28 | "source": [ 29 | "import pathlib\n", 30 | "\n", 31 | "from tqdm import tqdm\n", 32 | "from osgeo import ogr, osr\n", 33 | "ogr.UseExceptions()\n", 34 | "\n", 35 | "from rsc.common import aws_naip\n", 36 | "from rsc.common.aws_naip import AWS_PATH\n", 37 | "\n", 38 | "# Get NAIP manifest\n", 39 | "manifest = aws_naip.get_naip_manifest()\n", 40 | "\n", 41 | "# Filter out shapefiles to download\n", 42 | "shp = [e for e in manifest if e.split('.')[-1].lower() in \\\n", 43 | " ('shp', 'dbf', 'shx', 'prj', 'sbn')]\n", 44 | "\n", 45 | "# Fetch all the shapefiles\n", 46 | "for object_name in tqdm(shp):\n", 47 | " aws_naip.get_naip_file(object_name)" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 2, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "def conv(s: str) -> str:\n", 57 | " \"\"\" Convert object paths as seen in manifest to those that might be seen in the shapefiles.\n", 58 | " It's silly they don't match.\"\"\"\n", 59 | " return '%s.tif' % '_'.join(s.split('_')[:6])\n", 60 | "\n", 61 | "# Read the manifest, and convert the TIF files to those that might be seen in the shapefile metadata\n", 62 | "with open(AWS_PATH / 'manifest.txt', 'r') as f:\n", 63 | " mani = {conv(pathlib.Path(p).stem): p for p in (e.strip() for e in f.readlines()) if p.endswith('.tif')}" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 3, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "# Do all the work!\n", 73 | "\n", 74 | "# Create SRS (EPSG:4326: WGS-84 decimal degrees)\n", 75 | "srs = osr.SpatialReference()\n", 76 | "srs.ImportFromEPSG(4326)\n", 77 | "srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)\n", 78 | "\n", 79 | "# Create GPKG file for writing\n", 80 | "driver: ogr.Driver = ogr.GetDriverByName('GPKG')\n", 81 | "ds_w: ogr.DataSource = driver.CreateDataSource(str(AWS_PATH / 'naip_on_aws.gpkg'))\n", 82 | "layer_w: ogr.Layer = ds_w.CreateLayer('footprints', srs=srs, geom_type=ogr.wkbPolygon)\n", 83 | "\n", 84 | "# Define output fields\n", 85 | "state_field = ogr.FieldDefn('STATE', ogr.OFTString)\n", 86 | "band_field = ogr.FieldDefn('BAND', ogr.OFTString)\n", 87 | "usgs_id_field = ogr.FieldDefn('USGSID', ogr.OFTString)\n", 88 | "src_img_date_field = ogr.FieldDefn('SRCIMGDATE', ogr.OFTString)\n", 89 | "filename_field = ogr.FieldDefn('FILENAME', ogr.OFTString)\n", 90 | "object_field = ogr.FieldDefn('OBJECT', ogr.OFTString)\n", 91 | "\n", 92 | "# Create output fields in layer\n", 93 | "layer_w.CreateField(state_field)\n", 94 | "layer_w.CreateField(band_field)\n", 95 | "layer_w.CreateField(usgs_id_field)\n", 96 | "layer_w.CreateField(src_img_date_field)\n", 97 | "layer_w.CreateField(filename_field)\n", 98 | "layer_w.CreateField(object_field)\n", 99 | "\n", 100 | "# Get layer feature definition to load in features\n", 101 | "feat_defn = layer_w.GetLayerDefn()\n", 102 | "\n", 103 | "# Loop through all fetched shapefiles\n", 104 | "for p in 
AWS_PATH.rglob('*.shp'):\n", 105 |     "\n", 106 |     "    # Load them in OGR, get layer and spatial reference\n", 107 |     "    ds_r: ogr.DataSource = ogr.Open(str(p))\n", 108 |     "    layer_r: ogr.Layer = ds_r.GetLayer()\n", 109 |     "    srs_r: osr.SpatialReference = layer_r.GetSpatialRef()\n", 110 |     "    srs_r.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)\n", 111 |     "\n", 112 |     "    # Loop through the features in the layer\n", 113 |     "    for _ in range(layer_r.GetFeatureCount()):\n", 114 |     "        feat_r: ogr.Feature = layer_r.GetNextFeature()\n", 115 |     "\n", 116 |     "        # Quickly crosscheck with manifest. Skip if not in there\n", 117 |     "        filename = feat_r.GetFieldAsString('FileName')\n", 118 |     "        filename_conv = conv(filename.split('.')[0])\n", 119 |     "        if filename_conv not in mani:\n", 120 |     "            continue\n", 121 |     "\n", 122 |     "        # Parse remaining metadata\n", 123 |     "        try:\n", 124 |     "            state = feat_r.GetFieldAsString('ST')\n", 125 |     "        except RuntimeError:\n", 126 |     "            state = feat_r.GetFieldAsString('QUADST')\n", 127 |     "        band = feat_r.GetFieldAsString('Band')\n", 128 |     "        usgs_id = feat_r.GetFieldAsString('USGSID')\n", 129 |     "        src_img_date = feat_r.GetFieldAsString('SrcImgDate')\n", 130 |     "\n", 131 |     "        # Fetch geometry and convert to desired spatial reference\n", 132 |     "        trans = osr.CoordinateTransformation(srs_r, srs)\n", 133 |     "        geom = ogr.CreateGeometryFromWkt(feat_r.GetGeometryRef().ExportToWkt())\n", 134 |     "        geom.Transform(trans)\n", 135 |     "\n", 136 |     "        # Create our new feature\n", 137 |     "        feat_w = ogr.Feature(feat_defn)\n", 138 |     "        feat_w.SetGeometry(geom)\n", 139 |     "        feat_w.SetField('STATE', state)\n", 140 |     "        feat_w.SetField('BAND', band)\n", 141 |     "        feat_w.SetField('USGSID', usgs_id)\n", 142 |     "        feat_w.SetField('SRCIMGDATE', src_img_date)\n", 143 |     "        feat_w.SetField('FILENAME', filename)\n", 144 |     "        feat_w.SetField('OBJECT', mani[filename_conv])\n", 145 |     "\n", 146 |     "        # Save!\n", 147 |     "        layer_w.CreateFeature(feat_w)\n", 148 |     "\n", 149 |     "        # Cleanup features\n", 150 |     "        feat_w = None\n", 151 |     "        feat_r = None\n", 152 |     "\n", 153 |     "    # Cleanup read dataset\n", 154 |     "    layer_r = None\n", 155 |     "    ds_r = None\n", 156 |     "\n", 157 |     "# Cleanup write dataset\n", 158 |     "layer_w = None\n", 159 |     "ds_w = None" 160 |    ] 161 |   } 162 |  ], 163 |  "metadata": { 164 |   "kernelspec": { 165 |    "display_name": "Python 3.10.8 64-bit", 166 |    "language": "python", 167 |    "name": "python3" 168 |   }, 169 |   "language_info": { 170 |    "codemirror_mode": { 171 |     "name": "ipython", 172 |     "version": 3 173 |    }, 174 |    "file_extension": ".py", 175 |    "mimetype": "text/x-python", 176 |    "name": "python", 177 |    "nbconvert_exporter": "python", 178 |    "pygments_lexer": "ipython3", 179 |    "version": "3.10.8" 180 |   }, 181 |   "orig_nbformat": 4, 182 |   "vscode": { 183 |    "interpreter": { 184 |     "hash": "4d102384ded633c24f1031e288c2ecf1ababc4ef37e402995ad37064232eefd1" 185 |    } 186 |   } 187 |  }, 188 |  "nbformat": 4, 189 |  "nbformat_minor": 2 190 | } -------------------------------------------------------------------------------- /notebooks/01_explore_osm_surface.ipynb: -------------------------------------------------------------------------------- 1 | { 2 |  "cells": [ 3 |   { 4 |    "attachments": {}, 5 |    "cell_type": "markdown", 6 |    "metadata": {}, 7 |    "source": [ 8 |     "### Exploring OSM `surface` tags\n", 9 |     "\n", 10 |     "In this notebook, I perform a custom OSM Overpass API query for all \"driveable\" roads that have surface labels.
My main curiosity is which `surface` tags appear the most, and if I can identify an easy set of `surface` tag values that will account for the overwhelming majority of roads in the US." 11 | ] 12 | }, 13 | { 14 | "attachments": {}, 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "First, I kick off a Docker container to host the OSM Overpass API. This query requires far too much data to use any Overpass API that is hosted online." 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 20, 24 | "metadata": {}, 25 | "outputs": [ 26 | { 27 | "name": "stdout", 28 | "output_type": "stream", 29 | "text": [ 30 | "7b463d293ec20d6fab599eb3b2dd15f0463d89f6192df9ff25a9ec32d8317b6e\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "%%bash\n", 36 | "# https://hub.docker.com/r/wiktorn/overpass-api\n", 37 | "# http://localhost:12345/api/interpreter\n", 38 | "docker run \\\n", 39 | " -e OVERPASS_META=yes \\\n", 40 | " -e OVERPASS_MODE=init \\\n", 41 | " -e OVERPASS_PLANET_URL=file:///data/gis/us-latest.osm.bz2 \\\n", 42 | " -e OVERPASS_RULES_LOAD=10 \\\n", 43 | " -e OVERPASS_SPACE=55000000000 \\\n", 44 | " -e OVERPASS_MAX_TIMEOUT=86400 \\\n", 45 | " -v /data/gis:/data/gis \\\n", 46 | " -v /data/gis/overpass_db:/db \\\n", 47 | " -p 12345:80 \\\n", 48 | " -d --rm --name overpass_usa wiktorn/overpass-api:latest" 49 | ] 50 | }, 51 | { 52 | "attachments": {}, 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "Let's define a class that will allow us to perform an OSM Overpass API query for drivable road networks:" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 14, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "from rsc.osm.overpass_api import OSMOverpassQuery, OSMOverpassResult\n", 66 | "\n", 67 | "\n", 68 | "class OSMCustomOverpassQuery(OSMOverpassQuery):\n", 69 | " \"\"\" Custom OSM Overpass API query for (hopefully) drivable road networks \"\"\"\n", 70 | "\n", 71 | " __slots__ = ['_highway_tags']\n", 72 | "\n", 73 | " DEFAULT_HIGHWAY_TAGS = [\n", 74 | " 'motorway', 'motorway_link', 'motorway_junction', 'trunk',\n", 75 | " 'trunk_link', 'primary', 'primary_link', 'secondary', 'secondary_link',\n", 76 | " 'tertiary', 'tertiary_link', 'unclassified', 'residential'\n", 77 | " ]\n", 78 | "\n", 79 | " def __init__(self, **kwargs):\n", 80 | " super().__init__(**kwargs)\n", 81 | " self._highway_tags = kwargs.get('highway_tags',\n", 82 | " self.DEFAULT_HIGHWAY_TAGS)\n", 83 | "\n", 84 | " def perform_query(self) -> 'OSMOverpassResult':\n", 85 | " \"\"\" Perform an OSM Overpass API Request! \"\"\"\n", 86 | " return OSMOverpassResult(self._perform_query())\n", 87 | "\n", 88 | " @property\n", 89 | " def _query_str(self) -> str:\n", 90 | " return f\"\"\"\n", 91 | " [out:{self._format}]\n", 92 | " [timeout:{self._timeout}]\n", 93 | " [maxsize:2147483648];\n", 94 | " (way[\"highway\"]\n", 95 | " [\"area\"!~\"yes\"]\n", 96 | " [\"access\"!~\"private\"]\n", 97 | " [\"highway\"~\"{'|'.join(self._highway_tags)}\"]\n", 98 | " [\"motor_vehicle\"!~\"no\"]\n", 99 | " [\"motorcar\"!~\"no\"]\n", 100 | " [\"surface\"!~\"\"]\n", 101 | " [\"service\"!~\"alley|driveway|emergency_access|parking|parking_aisle|private\"]\n", 102 | " (poly:'{self._poly_query_str}');\n", 103 | " >;\n", 104 | " );\n", 105 | " out;\n", 106 | " \"\"\"" 107 | ] 108 | }, 109 | { 110 | "attachments": {}, 111 | "cell_type": "markdown", 112 | "metadata": {}, 113 | "source": [ 114 | "Now we can perform the query. 
This is a *very* broad query and therefore takes quite a bit of time." 115 |    ] 116 |   }, 117 |   { 118 |    "cell_type": "code", 119 |    "execution_count": 17, 120 |    "metadata": {}, 121 |    "outputs": [ 122 |     { 123 |      "name": "stdout", 124 |      "output_type": "stream", 125 |      "text": [ 126 |       "OSM data found at /data/gis/us_road_surface/us_w_road_surface.osm\n" 127 |      ] 128 |     } 129 |    ], 130 |    "source": [ 131 |     "import pathlib\n", 132 |     "out_path = pathlib.Path('/data/gis/us_road_surface/us_w_road_surface.osm')\n", 133 |     "\n", 134 |     "if not out_path.is_file():\n", 135 |     "    # Setup custom query to local interpreter\n", 136 |     "    # Set a very long timeout\n", 137 |     "    q = OSMCustomOverpassQuery(format='xml', timeout=24 * 60 * 60)\n", 138 |     "    q.set_endpoint('http://localhost:12345/api/interpreter')\n", 139 |     "\n", 140 |     "    # Use rough USA bounds for query\n", 141 |     "    with open('/data/gis/us_wkt.txt', 'r') as f:\n", 142 |     "        us_wkt = f.read()\n", 143 |     "    q.set_poly_from_wkt(us_wkt)\n", 144 |     "\n", 145 |     "    # Perform query and save! This will take a long time.\n", 146 |     "    print('Performing query...')\n", 147 |     "    result = q.perform_query()\n", 148 |     "    print('Saving to file...')\n", 149 |     "    result.to_file(out_path)\n", 150 |     "else:\n", 151 |     "    print('OSM data found at %s' % str(out_path))" 152 |    ] 153 |   }, 154 |   { 155 |    "attachments": {}, 156 |    "cell_type": "markdown", 157 |    "metadata": {}, 158 |    "source": [ 159 |     "At this point we can stop our Docker container hosting the Overpass API:" 160 |    ] 161 |   }, 162 |   { 163 |    "cell_type": "code", 164 |    "execution_count": 21, 165 |    "metadata": {}, 166 |    "outputs": [ 167 |     { 168 |      "name": "stdout", 169 |      "output_type": "stream", 170 |      "text": [ 171 |       "overpass_usa\n" 172 |      ] 173 |     } 174 |    ], 175 |    "source": [ 176 |     "%%bash\n", 177 |     "docker stop overpass_usa" 178 |    ] 179 |   }, 180 |   { 181 |    "attachments": {}, 182 |    "cell_type": "markdown", 183 |    "metadata": {}, 184 |    "source": [ 185 |     "And now we use ogr2ogr to convert the OSM file we downloaded to a GPKG file for easier GIS processing (the OGR OSM driver is very limited). This can be done in Python too, but the command line tool is easier for simple file conversions." 186 |    ] 187 |   }, 188 |   { 189 |    "cell_type": "code", 190 |    "execution_count": 31, 191 |    "metadata": {}, 192 |    "outputs": [ 193 |     { 194 |      "name": "stdout", 195 |      "output_type": "stream", 196 |      "text": [ 197 |       "File found: /data/gis/us_road_surface/us_w_road_surface.gpkg\n" 198 |      ] 199 |     } 200 |    ], 201 |    "source": [ 202 |     "%%bash\n", 203 |     "DATA_DIR=/data/gis/us_road_surface\n", 204 |     "OSM_PATH=$DATA_DIR/us_w_road_surface.osm\n", 205 |     "GPKG_PATH=$DATA_DIR/us_w_road_surface.gpkg\n", 206 |     "if [ !
-f $GPKG_PATH ]; then\n", 207 |     "    echo \"Converting $OSM_PATH to $GPKG_PATH...\"\n", 208 |     "    ogr2ogr $GPKG_PATH $OSM_PATH lines\n", 209 |     "else\n", 210 |     "    echo \"File found: $GPKG_PATH\"\n", 211 |     "fi" 212 |    ] 213 |   }, 214 |   { 215 |    "attachments": {}, 216 |    "cell_type": "markdown", 217 |    "metadata": {}, 218 |    "source": [ 219 |     "Let's parse the GPKG file to understand what surface types we are dealing with:" 220 |    ] 221 |   }, 222 |   { 223 |    "cell_type": "code", 224 |    "execution_count": 2, 225 |    "metadata": {}, 226 |    "outputs": [ 227 |     { 228 |      "name": "stdout", 229 |      "output_type": "stream", 230 |      "text": [ 231 |       "Loading & filtering dataset features...\n" 232 |      ] 233 |     }, 234 |     { 235 |      "name": "stderr", 236 |      "output_type": "stream", 237 |      "text": [ 238 |       "100%|██████████| 2458128/2458128 [02:15<00:00, 18099.85it/s]\n" 239 |      ] 240 |     } 241 |    ], 242 |    "source": [ 243 |     "from osgeo import gdal, osr\n", 244 |     "from tqdm import tqdm\n", 245 |     "\n", 246 |     "# Variable to store feature data that we load\n", 247 |     "feature_data = []\n", 248 |     "\n", 249 |     "# WGS84 Spatial Reference: all of OSM is EPSG:4326\n", 250 |     "srs_wgs84 = osr.SpatialReference()\n", 251 |     "srs_wgs84.ImportFromEPSG(4326)\n", 252 |     "\n", 253 |     "# Dataset is ogr OSM parsed file with \"lines\" layer exported\n", 254 |     "ds = gdal.OpenEx('/data/gis/us_road_surface/us_w_road_surface.gpkg')\n", 255 |     "layer = ds.GetLayer()\n", 256 |     "feature_count = layer.GetFeatureCount()\n", 257 |     "print('Loading & filtering dataset features...')\n", 258 |     "for idx in tqdm(range(feature_count)):\n", 259 |     "    # Get geometry, OSM ID, highway, and surface tag from each way\n", 260 |     "    feature = layer.GetNextFeature()\n", 261 |     "    highway = str(feature.GetField(2))\n", 262 |     "    wkt_str = feature.GetGeometryRef().ExportToWkt()\n", 263 |     "    osm_id = int(feature.GetField(0))\n", 264 |     "    other_tags = str(feature.GetField(8))\n", 265 |     "\n", 266 |     "    # NOTE: parsing the misc. tags field is messy. This is about\n", 267 |     "    # as good as it gets.\n", 268 |     "    tags_dict = dict([[f.replace('\"', '') for f in e.split('\"=>\"')]\n", 269 |     "                      for e in other_tags.split('\",\"')])\n", 270 |     "    surface_type = tags_dict.get('surface', 'unknown')\n", 271 |     "\n", 272 |     "    # Add to the feature data\n", 273 |     "    feature_data.append([osm_id, wkt_str, highway, surface_type])\n", 274 |     "\n", 275 |     "# Close dataset\n", 276 |     "layer = None\n", 277 |     "ds = None" 278 |    ] 279 |   }, 280 |   { 281 |    "attachments": {}, 282 |    "cell_type": "markdown", 283 |    "metadata": {}, 284 |    "source": [ 285 |     "Empirically, if I take any `surface` labels that appear more than 1000 times, I get a reasonable set of labels that covers the vast majority of cases. Nice!"
286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": 3, 291 | "metadata": {}, 292 | "outputs": [ 293 | { 294 | "data": { 295 | "text/plain": [ 296 | "asphalt 1639190\n", 297 | "unpaved 266035\n", 298 | "paved 230889\n", 299 | "concrete 153371\n", 300 | "gravel 105499\n", 301 | "dirt 28363\n", 302 | "concrete:plates 14310\n", 303 | "compacted 8585\n", 304 | "paving_stones 2562\n", 305 | "ground 2353\n", 306 | "bricks 1751\n", 307 | "Name: surface, dtype: int64" 308 | ] 309 | }, 310 | "execution_count": 3, 311 | "metadata": {}, 312 | "output_type": "execute_result" 313 | } 314 | ], 315 | "source": [ 316 | "import pandas as pd\n", 317 | "\n", 318 | "df = pd.DataFrame(feature_data, columns=['osm_id', 'wkt', 'highway', 'surface']).set_index('osm_id')\n", 319 | "unique_surface = df['surface'].value_counts()\n", 320 | "unique_surface = unique_surface[unique_surface > 1000]\n", 321 | "unique_surface" 322 | ] 323 | }, 324 | { 325 | "attachments": {}, 326 | "cell_type": "markdown", 327 | "metadata": {}, 328 | "source": [ 329 | "Let's see what percentage of `surface`-tagged roads this set of labels covers:" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 4, 335 | "metadata": {}, 336 | "outputs": [ 337 | { 338 | "name": "stdout", 339 | "output_type": "stream", 340 | "text": [ 341 | "Filtered surface tags account for 99.8% of all driveable ways.\n" 342 | ] 343 | } 344 | ], 345 | "source": [ 346 | "pc = unique_surface.sum() / len(feature_data)\n", 347 | "print(f'Filtered surface tags account for {pc:.1%} of all driveable ways.')" 348 | ] 349 | }, 350 | { 351 | "attachments": {}, 352 | "cell_type": "markdown", 353 | "metadata": {}, 354 | "source": [ 355 | "I can live with that.\n", 356 | "\n", 357 | "Let's save these filtered tags into a GPKG file for our classifier dataset prep." 
358 |    ] 359 |   }, 360 |   { 361 |    "cell_type": "code", 362 |    "execution_count": null, 363 |    "metadata": {}, 364 |    "outputs": [], 365 |    "source": [ 366 |     "# Start a new dataset\n", "from osgeo import ogr\n", "\n", "# NOTE: assumed output path; this was not defined earlier in the notebook\n", "filtered_gpkg_path = out_path.parent / 'us_w_road_surface_filtered.gpkg'\n", "\n", 367 |     "print('Saving filtered features...')\n", 368 |     "driver = ogr.GetDriverByName('GPKG')\n", 369 |     "ds = driver.CreateDataSource(str(filtered_gpkg_path))\n", 370 |     "layer = ds.CreateLayer('roads', srs=srs_wgs84, geom_type=ogr.wkbLineString)\n", 371 |     "\n", 372 |     "# Define fields\n", 373 |     "id_field = ogr.FieldDefn('osmid', ogr.OFTInteger64)\n", 374 |     "highway_field = ogr.FieldDefn('highway', ogr.OFTString)\n", 375 |     "surface_field = ogr.FieldDefn('surface', ogr.OFTString)\n", 376 |     "for field in (id_field, highway_field, surface_field):\n", 377 |     "    layer.CreateField(field)\n", 378 |     "\n", 379 |     "# Add features\n", 380 |     "feature_defn = layer.GetLayerDefn()\n", 381 |     "for idx, (osm_id, wkt_str, highway,\n", 382 |     "          surface_type) in tqdm(enumerate(feature_data)):\n", 383 |     "\n", "    # Keep only the common surface labels identified above\n", "    if surface_type not in unique_surface.index:\n", "        continue\n", "\n", 384 |     "    # New feature\n", 385 |     "    feat = ogr.Feature(feature_defn)\n", 386 |     "\n", 387 |     "    # Set geometry\n", 388 |     "    geom = ogr.CreateGeometryFromWkt(wkt_str)\n", 389 |     "    feat.SetGeometry(geom)\n", 390 |     "\n", 391 |     "    # Set fields\n", 392 |     "    feat.SetField('osmid', osm_id)\n", 393 |     "    feat.SetField('highway', highway)\n", 394 |     "    feat.SetField('surface', surface_type)\n", 395 |     "\n", 396 |     "    # Flush\n", 397 |     "    layer.CreateFeature(feat)\n", 398 |     "    feat = None\n", 399 |     "\n", 400 |     "# Close dataset\n", 401 |     "layer = None\n", 402 |     "ds = None" 403 |    ] 404 |   } 405 |  ], 406 |  "metadata": { 407 |   "kernelspec": { 408 |    "display_name": "venv", 409 |    "language": "python", 410 |    "name": "python3" 411 |   }, 412 |   "language_info": { 413 |    "codemirror_mode": { 414 |     "name": "ipython", 415 |     "version": 3 416 |    }, 417 |    "file_extension": ".py", 418 |    "mimetype": "text/x-python", 419 |    "name": "python", 420 |    "nbconvert_exporter": "python", 421 |    "pygments_lexer": "ipython3", 422 |    "version": "3.10.10" 423 |   }, 424 |   "orig_nbformat": 4 425 |  }, 426 |  "nbformat": 4, 427 |  "nbformat_minor": 2 428 | } -------------------------------------------------------------------------------- /notebooks/03_dataset_to_gpkg.ipynb: -------------------------------------------------------------------------------- 1 | { 2 |  "cells": [ 3 |   { 4 |    "attachments": {}, 5 |    "cell_type": "markdown", 6 |    "metadata": {}, 7 |    "source": [ 8 |     "This is a quick notebook to allow me to plot the locations of all the images used in my dataset in QGIS." 9 |    ] 10 |   }, 11 |   { 12 |    "cell_type": "code", 13 |    "execution_count": 3, 14 |    "metadata": {}, 15 |    "outputs": [ 16 |     { 17 |      "data": { 18 |       "text/html": [ 19 |        "
\n", 20 | "\n", 33 | "\n", 34 | " \n", 35 | " \n", 36 | " \n", 37 | " \n", 38 | " \n", 39 | " \n", 40 | " \n", 41 | " \n", 42 | " \n", 43 | " \n", 44 | " \n", 45 | " \n", 46 | " \n", 47 | " \n", 48 | " \n", 49 | " \n", 50 | " \n", 51 | " \n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | "
highwaysurfacelonlatobject
osm_id
391020748primaryasphalt-117.65836033.596795ca/2020/60cm/rgbir/33117/m_3311727_ne_11_060_2...
14290622primary_linkasphalt-115.14747336.092396nv/2019/60cm/rgbir/36115/m_3611563_ne_11_060_2...
240534097motorwayasphalt-78.73924742.929866ny/2019/60cm/rgbir/42078/m_4207803_sw_17_060_2...
684838122secondaryasphalt-112.01296933.444280az/2017/60cm/rgbir/33112/m_3311240_ne_12_h_201...
13567313unclassifiedasphalt-88.84320432.081660ms/2020/60cm/rgbir/32088/m_3208858_nw_16_060_2...
..................
13853468unclassifiedunpaved-117.03971947.376109wa/2019/60cm/rgbir/47117/m_4711740_se_11_060_2...
14121032unclassifiedunpaved-98.94799442.185050ne/2020/60cm/rgbir/42098/m_4209849_sw_14_060_2...
8757986residentialunpaved-74.78723844.285550ny/2019/60cm/rgbir/44074/m_4407442_se_18_060_2...
19717227residentialunpaved-72.00388844.995127vt/2018/60cm/rgbir/44072/m_4407208_ne_18_060_2...
14125476unclassifiedunpaved-98.53651741.241572ne/2020/60cm/rgbir/41098/m_4109852_ne_14_060_2...
\n", 143 | "

46491 rows × 5 columns

\n", 144 | "
" 145 | ], 146 | "text/plain": [ 147 | " highway surface lon lat \\\n", 148 | "osm_id \n", 149 | "391020748 primary asphalt -117.658360 33.596795 \n", 150 | "14290622 primary_link asphalt -115.147473 36.092396 \n", 151 | "240534097 motorway asphalt -78.739247 42.929866 \n", 152 | "684838122 secondary asphalt -112.012969 33.444280 \n", 153 | "13567313 unclassified asphalt -88.843204 32.081660 \n", 154 | "... ... ... ... ... \n", 155 | "13853468 unclassified unpaved -117.039719 47.376109 \n", 156 | "14121032 unclassified unpaved -98.947994 42.185050 \n", 157 | "8757986 residential unpaved -74.787238 44.285550 \n", 158 | "19717227 residential unpaved -72.003888 44.995127 \n", 159 | "14125476 unclassified unpaved -98.536517 41.241572 \n", 160 | "\n", 161 | " object \n", 162 | "osm_id \n", 163 | "391020748 ca/2020/60cm/rgbir/33117/m_3311727_ne_11_060_2... \n", 164 | "14290622 nv/2019/60cm/rgbir/36115/m_3611563_ne_11_060_2... \n", 165 | "240534097 ny/2019/60cm/rgbir/42078/m_4207803_sw_17_060_2... \n", 166 | "684838122 az/2017/60cm/rgbir/33112/m_3311240_ne_12_h_201... \n", 167 | "13567313 ms/2020/60cm/rgbir/32088/m_3208858_nw_16_060_2... \n", 168 | "... ... \n", 169 | "13853468 wa/2019/60cm/rgbir/47117/m_4711740_se_11_060_2... \n", 170 | "14121032 ne/2020/60cm/rgbir/42098/m_4209849_sw_14_060_2... \n", 171 | "8757986 ny/2019/60cm/rgbir/44074/m_4407442_se_18_060_2... \n", 172 | "19717227 vt/2018/60cm/rgbir/44072/m_4407208_ne_18_060_2... \n", 173 | "14125476 ne/2020/60cm/rgbir/41098/m_4109852_ne_14_060_2... \n", 174 | "\n", 175 | "[46491 rows x 5 columns]" 176 | ] 177 | }, 178 | "execution_count": 3, 179 | "metadata": {}, 180 | "output_type": "execute_result" 181 | } 182 | ], 183 | "source": [ 184 | "# Import data\n", 185 | "\n", 186 | "import pathlib\n", 187 | "\n", 188 | "import pandas as pd\n", 189 | "\n", 190 | "features_path = pathlib.Path('/data/road_surface_classifier/features.csv')\n", 191 | "df_feat = pd.read_csv(features_path).set_index('osm_id')\n", 192 | "df_feat = df_feat.drop(columns=['wkt', 'x', 'y', 'ix', 'iy', 'length'])\n", 193 | "df_feat" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 7, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "# Create dataset of points\n", 203 | "from osgeo import ogr, osr\n", 204 | "\n", 205 | "# Create SRS (EPSG:4326: WGS-84 decimal degrees)\n", 206 | "srs = osr.SpatialReference()\n", 207 | "srs.ImportFromEPSG(4326)\n", 208 | "\n", 209 | "driver: ogr.Driver = ogr.GetDriverByName('GPKG')\n", 210 | "ds: ogr.DataSource = driver.CreateDataSource('/data/road_surface_classifier/features_pts.gpkg')\n", 211 | "layer: ogr.Layer = ds.CreateLayer('data', srs=srs, geom_type=ogr.wkbPoint)\n", 212 | "\n", 213 | "osm_id_field = ogr.FieldDefn('osm_id', ogr.OFTInteger64)\n", 214 | "highway_field = ogr.FieldDefn('highway', ogr.OFTString)\n", 215 | "surface_field = ogr.FieldDefn('surface', ogr.OFTString)\n", 216 | "\n", 217 | "layer.CreateField(osm_id_field)\n", 218 | "layer.CreateField(highway_field)\n", 219 | "layer.CreateField(surface_field)\n", 220 | "\n", 221 | "feature_defn = layer.GetLayerDefn()\n", 222 | "\n", 223 | "for _, row in df_feat.iterrows():\n", 224 | " feat = ogr.Feature(feature_defn)\n", 225 | "\n", 226 | " pt = ogr.Geometry(ogr.wkbPoint)\n", 227 | " pt.AddPoint_2D(row['lon'], row['lat'])\n", 228 | "\n", 229 | " feat.SetGeometry(pt)\n", 230 | " feat.SetField('osm_id', row.name)\n", 231 | " feat.SetField('highway', row['highway'])\n", 232 | " feat.SetField('surface', row['surface'])\n", 233 | " 
layer.CreateFeature(feat)\n", 234 | " pt = None\n", 235 | " feat = None\n", 236 | "\n", 237 | "layer = None # type: ignore\n", 238 | "ds = None # type: ignore" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 6, 244 | "metadata": {}, 245 | "outputs": [ 246 | { 247 | "name": "stderr", 248 | "output_type": "stream", 249 | "text": [ 250 | "Exception ignored in: \n", 251 | "Traceback (most recent call last):\n", 252 | " File \"/tmp/ipykernel_57433/3662002279.py\", line 16, in \n", 253 | "RuntimeError: sqlite3_exec(CREATE TABLE gpkg_extensions (table_name TEXT,column_name TEXT,extension_name TEXT NOT NULL,definition TEXT NOT NULL,scope TEXT NOT NULL,CONSTRAINT ge_tce UNIQUE (table_name, column_name, extension_name))) failed: attempt to write a readonly database\n" 254 | ] 255 | } 256 | ], 257 | "source": [ 258 | "# Create dataset of polygons from imagery\n", 259 | "import json\n", 260 | "from osgeo import gdal, ogr, osr\n", 261 | "\n", 262 | "gdal.UseExceptions()\n", 263 | "ogr.UseExceptions()\n", 264 | "\n", 265 | "naip_img_path = pathlib.Path('/data/road_surface_classifier/imagery')\n", 266 | "assert naip_img_path.is_dir()\n", 267 | "\n", 268 | "# Create SRS (EPSG:4326: WGS-84 decimal degrees)\n", 269 | "srs = osr.SpatialReference()\n", 270 | "srs.ImportFromEPSG(4326)\n", 271 | "\n", 272 | "driver: ogr.Driver = ogr.GetDriverByName('GPKG')\n", 273 | "ds: ogr.DataSource = driver.CreateDataSource('/data/road_surface_classifier/features_polys.gpkg')\n", 274 | "layer: ogr.Layer = ds.CreateLayer('data', srs=srs, geom_type=ogr.wkbPolygon)\n", 275 | "\n", 276 | "osm_id_field = ogr.FieldDefn('osm_id', ogr.OFTInteger64)\n", 277 | "highway_field = ogr.FieldDefn('highway', ogr.OFTString)\n", 278 | "surface_field = ogr.FieldDefn('surface', ogr.OFTString)\n", 279 | "\n", 280 | "layer.CreateField(osm_id_field)\n", 281 | "layer.CreateField(highway_field)\n", 282 | "layer.CreateField(surface_field)\n", 283 | "\n", 284 | "feature_defn = layer.GetLayerDefn()\n", 285 | "\n", 286 | "for osm_id, row in df_feat.iterrows():\n", 287 | " img_path = naip_img_path / str('%d.tif' % osm_id)\n", 288 | " assert img_path.exists()\n", 289 | "\n", 290 | " wgs84_extent = gdal.Info(str(img_path), format='json')['wgs84Extent']\n", 291 | " poly = ogr.CreateGeometryFromJson(json.dumps(wgs84_extent))\n", 292 | "\n", 293 | " feat = ogr.Feature(feature_defn)\n", 294 | "\n", 295 | " feat.SetGeometry(poly)\n", 296 | " feat.SetField('osm_id', row.name)\n", 297 | " feat.SetField('highway', row['highway'])\n", 298 | " feat.SetField('surface', row['surface'])\n", 299 | " layer.CreateFeature(feat)\n", 300 | " poly = None\n", 301 | " feat = None\n", 302 | "\n", 303 | "layer = None # type: ignore\n", 304 | "ds = None # type: ignore" 305 | ] 306 | } 307 | ], 308 | "metadata": { 309 | "kernelspec": { 310 | "display_name": "Python 3.10.6 ('rsd_env')", 311 | "language": "python", 312 | "name": "python3" 313 | }, 314 | "language_info": { 315 | "codemirror_mode": { 316 | "name": "ipython", 317 | "version": 3 318 | }, 319 | "file_extension": ".py", 320 | "mimetype": "text/x-python", 321 | "name": "python", 322 | "nbconvert_exporter": "python", 323 | "pygments_lexer": "ipython3", 324 | "version": "3.10.6" 325 | }, 326 | "orig_nbformat": 4, 327 | "vscode": { 328 | "interpreter": { 329 | "hash": "3c3e3338979283bc5980811c64bc074b42c7e88e72bf9f1fd3a7107f9ec2dee1" 330 | } 331 | } 332 | }, 333 | "nbformat": 4, 334 | "nbformat_minor": 2 335 | } 336 | -------------------------------------------------------------------------------- 
/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "road_surface_classifier" 3 | version = "0.9.0" 4 | description = "Automated surface tagging of roads in OpenStreetMap" 5 | readme = "README.md" 6 | requires-python = ">=3.7" 7 | license = { file = "LICENSE" } 8 | keywords = ["openstreetmap", "machine learning", "classifier"] 9 | authors = [{ name = "Jon Dalrymple", email = "j_dalrym2@hotmail.com" }] 10 | maintainers = [{ name = "Jon Dalrymple", email = "j_dalrym2@hotmail.com" }] 11 | 12 | # Classifiers help users find your project by categorizing it. 13 | # 14 | # For a list of valid classifiers, see https://pypi.org/classifiers/ 15 | classifiers = [ 16 |     "Development Status :: 3 - Alpha", 17 |     "Environment :: GPU", 18 |     "Environment :: GPU :: NVIDIA CUDA", 19 |     "Environment :: GPU :: NVIDIA CUDA :: 11", 20 |     "Framework :: Jupyter", 21 |     "Intended Audience :: Developers", 22 |     "Intended Audience :: Science/Research", 23 |     "License :: OSI Approved :: MIT License", 24 |     "Natural Language :: English", 25 |     "Operating System :: POSIX :: Linux", 26 |     "Programming Language :: Python :: 3", 27 |     "Programming Language :: Python :: 3.7", 28 |     "Programming Language :: Python :: 3.8", 29 |     "Programming Language :: Python :: 3.9", 30 |     "Programming Language :: Python :: 3.10", 31 |     "Programming Language :: Python :: 3.11", 32 |     "Programming Language :: Python :: 3 :: Only", 33 |     "Topic :: Scientific/Engineering", 34 |     "Topic :: Scientific/Engineering :: Artificial Intelligence", 35 |     "Topic :: Scientific/Engineering :: GIS", 36 |     "Topic :: Scientific/Engineering :: Image Processing", 37 |     "Topic :: Scientific/Engineering :: Image Recognition", 38 | ] 39 | 40 | # This field lists other packages that your project depends on to run. 41 | # Any package you put here will be installed by pip when your project is 42 | # installed, so they must be valid existing projects. 43 | # 44 | # For an analysis of this field vs pip's requirements files see: 45 | # https://packaging.python.org/discussions/install-requires-vs-requirements/ 46 | # TODO: double-check these + split for rsc submodules 47 | # TODO: For now, use requirements.txt / requirements.in 48 | dependencies = [] 49 | 50 | # List additional groups of dependencies here (e.g. development 51 | # dependencies). Users will be able to install these using the "extras" 52 | # syntax, for example: 53 | # 54 | #   $ pip install sampleproject[dev] 55 | # 56 | # Similar to `dependencies` above, these must be valid existing 57 | # projects.
58 | [project.optional-dependencies] 59 | dev = ["check-manifest"] 60 | test = ["coverage"] 61 | 62 | [project.urls] 63 | "Homepage" = "https://github.com/jdalrym2/road_surface_classifier" 64 | 65 | [tool.setuptools] 66 | packages = ["rsc"] 67 | 68 | [build-system] 69 | # These are the assumed default build requirements from pip: 70 | # https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support 71 | requires = ["setuptools>=43.0.0", "wheel"] 72 | build-backend = "setuptools.build_meta" 73 | -------------------------------------------------------------------------------- /requirements.in: -------------------------------------------------------------------------------- 1 | boto3 2 | numpy 3 | GDAL>3,<=3.8.3 4 | kornia 5 | matplotlib 6 | pandas 7 | Pillow 8 | pytorch_lightning 9 | optuna>=3.1.0 10 | mlflow 11 | requests 12 | scikit_learn 13 | scikit_image 14 | scipy 15 | torch 16 | torchvision 17 | tqdm 18 | ipykernel 19 | ipympl 20 | ipywidgets 21 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.10 3 | # by the following command: 4 | # 5 | # pip-compile --output-file=requirements.txt requirements.in 6 | # 7 | aiohttp==3.8.4 8 | # via fsspec 9 | aiosignal==1.3.1 10 | # via aiohttp 11 | alembic==1.10.4 12 | # via 13 | # mlflow 14 | # optuna 15 | asttokens==2.2.1 16 | # via stack-data 17 | async-timeout==4.0.2 18 | # via aiohttp 19 | attrs==23.1.0 20 | # via aiohttp 21 | backcall==0.2.0 22 | # via ipython 23 | blinker==1.6.2 24 | # via flask 25 | boto3==1.26.129 26 | # via -r requirements.in 27 | botocore==1.29.129 28 | # via 29 | # boto3 30 | # s3transfer 31 | certifi==2023.5.7 32 | # via requests 33 | charset-normalizer==3.1.0 34 | # via 35 | # aiohttp 36 | # requests 37 | click==8.1.3 38 | # via 39 | # databricks-cli 40 | # flask 41 | # mlflow 42 | cloudpickle==2.2.1 43 | # via mlflow 44 | cmaes==0.9.1 45 | # via optuna 46 | cmake==3.26.3 47 | # via triton 48 | colorlog==6.7.0 49 | # via optuna 50 | comm==0.1.3 51 | # via ipykernel 52 | contourpy==1.0.7 53 | # via matplotlib 54 | cycler==0.11.0 55 | # via matplotlib 56 | databricks-cli==0.17.6 57 | # via mlflow 58 | debugpy==1.6.7 59 | # via ipykernel 60 | decorator==5.1.1 61 | # via ipython 62 | docker==6.1.0 63 | # via mlflow 64 | entrypoints==0.4 65 | # via mlflow 66 | executing==1.2.0 67 | # via stack-data 68 | filelock==3.12.0 69 | # via 70 | # torch 71 | # triton 72 | flask==2.3.2 73 | # via mlflow 74 | fonttools==4.39.3 75 | # via matplotlib 76 | frozenlist==1.3.3 77 | # via 78 | # aiohttp 79 | # aiosignal 80 | fsspec[http]==2023.5.0 81 | # via pytorch-lightning 82 | gdal==3.8.3 83 | # via -r requirements.in 84 | gitdb==4.0.10 85 | # via gitpython 86 | gitpython==3.1.31 87 | # via mlflow 88 | greenlet==2.0.2 89 | # via sqlalchemy 90 | gunicorn==20.1.0 91 | # via mlflow 92 | idna==3.4 93 | # via 94 | # requests 95 | # yarl 96 | imageio==2.30.0 97 | # via scikit-image 98 | importlib-metadata==6.6.0 99 | # via mlflow 100 | ipykernel==6.22.0 101 | # via 102 | # -r requirements.in 103 | # ipywidgets 104 | ipympl==0.9.3 105 | # via -r requirements.in 106 | ipython==8.13.2 107 | # via 108 | # ipykernel 109 | # ipympl 110 | # ipywidgets 111 | ipython-genutils==0.2.0 112 | # via ipympl 113 | ipywidgets==8.0.6 114 | # via 115 | # -r requirements.in 116 | # ipympl 117 | itsdangerous==2.1.2 118 | # via flask 119 | jedi==0.18.2 120 | # via 
ipython 121 | jinja2==3.1.2 122 | # via 123 | # flask 124 | # mlflow 125 | # torch 126 | jmespath==1.0.1 127 | # via 128 | # boto3 129 | # botocore 130 | joblib==1.2.0 131 | # via scikit-learn 132 | jupyter-client==8.2.0 133 | # via ipykernel 134 | jupyter-core==5.3.0 135 | # via 136 | # ipykernel 137 | # jupyter-client 138 | jupyterlab-widgets==3.0.7 139 | # via ipywidgets 140 | kiwisolver==1.4.4 141 | # via matplotlib 142 | kornia==0.6.12 143 | # via -r requirements.in 144 | lazy-loader==0.2 145 | # via scikit-image 146 | lightning-utilities==0.8.0 147 | # via pytorch-lightning 148 | lit==16.0.3 149 | # via triton 150 | mako==1.2.4 151 | # via alembic 152 | markdown==3.4.3 153 | # via mlflow 154 | markupsafe==2.1.2 155 | # via 156 | # jinja2 157 | # mako 158 | # werkzeug 159 | matplotlib==3.7.1 160 | # via 161 | # -r requirements.in 162 | # ipympl 163 | # mlflow 164 | matplotlib-inline==0.1.6 165 | # via 166 | # ipykernel 167 | # ipython 168 | mlflow==2.3.1 169 | # via -r requirements.in 170 | mpmath==1.3.0 171 | # via sympy 172 | multidict==6.0.4 173 | # via 174 | # aiohttp 175 | # yarl 176 | nest-asyncio==1.5.6 177 | # via ipykernel 178 | networkx==3.1 179 | # via 180 | # scikit-image 181 | # torch 182 | numpy==1.24.3 183 | # via 184 | # -r requirements.in 185 | # cmaes 186 | # contourpy 187 | # imageio 188 | # ipympl 189 | # matplotlib 190 | # mlflow 191 | # optuna 192 | # pandas 193 | # pyarrow 194 | # pytorch-lightning 195 | # pywavelets 196 | # scikit-image 197 | # scikit-learn 198 | # scipy 199 | # tifffile 200 | # torchmetrics 201 | # torchvision 202 | nvidia-cublas-cu11==11.10.3.66 203 | # via 204 | # nvidia-cudnn-cu11 205 | # nvidia-cusolver-cu11 206 | # torch 207 | nvidia-cuda-cupti-cu11==11.7.101 208 | # via torch 209 | nvidia-cuda-nvrtc-cu11==11.7.99 210 | # via torch 211 | nvidia-cuda-runtime-cu11==11.7.99 212 | # via torch 213 | nvidia-cudnn-cu11==8.5.0.96 214 | # via torch 215 | nvidia-cufft-cu11==10.9.0.58 216 | # via torch 217 | nvidia-curand-cu11==10.2.10.91 218 | # via torch 219 | nvidia-cusolver-cu11==11.4.0.1 220 | # via torch 221 | nvidia-cusparse-cu11==11.7.4.91 222 | # via torch 223 | nvidia-nccl-cu11==2.14.3 224 | # via torch 225 | nvidia-nvtx-cu11==11.7.91 226 | # via torch 227 | oauthlib==3.2.2 228 | # via databricks-cli 229 | optuna==3.1.1 230 | # via -r requirements.in 231 | packaging==23.1 232 | # via 233 | # docker 234 | # ipykernel 235 | # kornia 236 | # lightning-utilities 237 | # matplotlib 238 | # mlflow 239 | # optuna 240 | # pytorch-lightning 241 | # scikit-image 242 | # torchmetrics 243 | pandas==2.0.1 244 | # via 245 | # -r requirements.in 246 | # mlflow 247 | parso==0.8.3 248 | # via jedi 249 | pexpect==4.8.0 250 | # via ipython 251 | pickleshare==0.7.5 252 | # via ipython 253 | pillow==9.5.0 254 | # via 255 | # -r requirements.in 256 | # imageio 257 | # ipympl 258 | # matplotlib 259 | # scikit-image 260 | # torchvision 261 | platformdirs==3.5.0 262 | # via jupyter-core 263 | prompt-toolkit==3.0.38 264 | # via ipython 265 | protobuf==4.22.4 266 | # via mlflow 267 | psutil==5.9.5 268 | # via ipykernel 269 | ptyprocess==0.7.0 270 | # via pexpect 271 | pure-eval==0.2.2 272 | # via stack-data 273 | pyarrow==11.0.0 274 | # via mlflow 275 | pygments==2.15.1 276 | # via ipython 277 | pyjwt==2.6.0 278 | # via databricks-cli 279 | pyparsing==3.0.9 280 | # via matplotlib 281 | python-dateutil==2.8.2 282 | # via 283 | # botocore 284 | # jupyter-client 285 | # matplotlib 286 | # pandas 287 | pytorch-lightning==2.0.2 288 | # via -r requirements.in 289 | 
pytz==2023.3 290 | # via 291 | # mlflow 292 | # pandas 293 | pywavelets==1.4.1 294 | # via scikit-image 295 | pyyaml==6.0 296 | # via 297 | # mlflow 298 | # optuna 299 | # pytorch-lightning 300 | pyzmq==25.0.2 301 | # via 302 | # ipykernel 303 | # jupyter-client 304 | querystring-parser==1.2.4 305 | # via mlflow 306 | requests==2.30.0 307 | # via 308 | # -r requirements.in 309 | # databricks-cli 310 | # docker 311 | # fsspec 312 | # mlflow 313 | # torchvision 314 | s3transfer==0.6.1 315 | # via boto3 316 | scikit-image==0.20.0 317 | # via -r requirements.in 318 | scikit-learn==1.2.2 319 | # via 320 | # -r requirements.in 321 | # mlflow 322 | scipy==1.10.1 323 | # via 324 | # -r requirements.in 325 | # mlflow 326 | # scikit-image 327 | # scikit-learn 328 | six==1.16.0 329 | # via 330 | # asttokens 331 | # databricks-cli 332 | # python-dateutil 333 | # querystring-parser 334 | smmap==5.0.0 335 | # via gitdb 336 | sqlalchemy==2.0.12 337 | # via 338 | # alembic 339 | # mlflow 340 | # optuna 341 | sqlparse==0.4.4 342 | # via mlflow 343 | stack-data==0.6.2 344 | # via ipython 345 | sympy==1.11.1 346 | # via torch 347 | tabulate==0.9.0 348 | # via databricks-cli 349 | threadpoolctl==3.1.0 350 | # via scikit-learn 351 | tifffile==2023.4.12 352 | # via scikit-image 353 | torch==2.0.0 354 | # via 355 | # -r requirements.in 356 | # kornia 357 | # pytorch-lightning 358 | # torchmetrics 359 | # torchvision 360 | # triton 361 | torchmetrics==0.11.4 362 | # via pytorch-lightning 363 | torchvision==0.15.1 364 | # via -r requirements.in 365 | tornado==6.3.1 366 | # via 367 | # ipykernel 368 | # jupyter-client 369 | tqdm==4.65.0 370 | # via 371 | # -r requirements.in 372 | # optuna 373 | # pytorch-lightning 374 | traitlets==5.9.0 375 | # via 376 | # comm 377 | # ipykernel 378 | # ipympl 379 | # ipython 380 | # ipywidgets 381 | # jupyter-client 382 | # jupyter-core 383 | # matplotlib-inline 384 | triton==2.0.0 385 | # via torch 386 | typing-extensions==4.5.0 387 | # via 388 | # alembic 389 | # lightning-utilities 390 | # pytorch-lightning 391 | # sqlalchemy 392 | # torch 393 | tzdata==2023.3 394 | # via pandas 395 | urllib3==1.26.15 396 | # via 397 | # botocore 398 | # docker 399 | # requests 400 | wcwidth==0.2.6 401 | # via prompt-toolkit 402 | websocket-client==1.5.1 403 | # via docker 404 | werkzeug==2.3.3 405 | # via flask 406 | wheel==0.40.0 407 | # via 408 | # nvidia-cublas-cu11 409 | # nvidia-cuda-cupti-cu11 410 | # nvidia-cuda-runtime-cu11 411 | # nvidia-curand-cu11 412 | # nvidia-cusparse-cu11 413 | # nvidia-nvtx-cu11 414 | widgetsnbextension==4.0.7 415 | # via ipywidgets 416 | yarl==1.9.2 417 | # via aiohttp 418 | zipp==3.15.0 419 | # via importlib-metadata 420 | 421 | # The following packages are considered to be unsafe in a requirements file: 422 | # setuptools 423 | -------------------------------------------------------------------------------- /rsc/artifacts/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ RSC submodule to simplify artifact generation during inference over a dataset """ 4 | import pathlib 5 | 6 | import torch 7 | 8 | # Get PyTorch device to use 9 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 10 | 11 | 12 | def find_best_model(results_dir: pathlib.Path) -> pathlib.Path: 13 | """ 14 | Find the best model checkpoint (by validation loss) in a results directory. 
15 | 16 | Args: 17 | results_dir (pathlib.Path): Results directory 18 | 19 | Returns: 20 | pathlib.Path: Path to checkpoint with minimum validation loss. 21 | """ 22 | # Parse checkpoints from results dir 23 | paths = list(results_dir.glob('*.ckpt')) 24 | path_stems = [e.stem for e in paths] 25 | path_metrics = [e.split('-') for e in path_stems] 26 | # Extract validation losses from filenames 27 | val_losses = [] 28 | for idx, metrics in enumerate(path_metrics): 29 | for metric in metrics: 30 | if metric.startswith('val_loss'): 31 | val_loss = float(metric.split('=')[-1]) 32 | val_losses.append((idx, val_loss)) 33 | 34 | # Find path that has minimum validation loss 35 | # By reversing val losses, if there is a tie 36 | # we fetch the last one 37 | min_idx, _ = min(reversed(val_losses), key=lambda v: v[1]) 38 | 39 | return paths[min_idx] 40 | 41 | 42 | from .base import ArtifactGenerator, ArtifactHandler
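As a quick sketch of how these exports fit together (paths, the CSV name, and the toy handler below are illustrative, not shipped by the repo): subclass `ArtifactHandler` — its `start` / `on_iter` / `save` hooks are defined in `rsc/artifacts/base.py` further down — and let `ArtifactGenerator` drive inference:

```python
import pathlib

from torch.utils.data import DataLoader

from rsc.artifacts import ArtifactGenerator, ArtifactHandler, find_best_model
from rsc.train.dataset import RoadSurfaceDataset
from rsc.train.plmcnn import PLMaskCNN
from rsc.train.preprocess import PreProcess


class BatchCounterHandler(ArtifactHandler):
    """Toy handler: count inferenced batches, write the total as an artifact."""

    def __init__(self):
        super().__init__()
        self.n_batches = 0

    def start(self, model, dataloader) -> None:
        self.n_batches = 0  # reset per run

    def on_iter(self, dl_iter, model_out) -> None:
        self.n_batches += 1  # called once per inferenced batch

    def save(self, output_dir: pathlib.Path) -> pathlib.Path:
        out = output_dir / 'batch_count.txt'
        out.write_text(f'{self.n_batches}\n')
        return out


ckpt = find_best_model(pathlib.Path('results'))  # picks the lowest-val_loss ckpt
model = PLMaskCNN.load_from_checkpoint(ckpt)
model.eval()

val_ds = RoadSurfaceDataset(pathlib.Path('dataset_val.csv'),  # hypothetical CSV
                            transform=PreProcess(),
                            n_channels=model.nc)
generator = ArtifactGenerator('artifacts_out', model,
                              DataLoader(val_ds, batch_size=64))
generator.add_handler(BatchCounterHandler())
generator.run(raise_on_error=True)
```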
-------------------------------------------------------------------------------- /rsc/artifacts/__main__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ Generate artifacts for a model """ 4 | 5 | import pathlib 6 | import argparse 7 | 8 | import torch 9 | from torch.utils.data import DataLoader 10 | 11 | from ..train.plmcnn import PLMaskCNN 12 | from ..train.dataset import RoadSurfaceDataset 13 | from ..train.preprocess import PreProcess 14 | 15 | from . import find_best_model 16 | from .base import ArtifactGenerator 17 | from .confusion_matrix_handler import ConfusionMatrixHandler 18 | from .accuracy_obsc_handler import AccuracyObscHandler 19 | from .obsc_compare_handler import ObscCompareHandler 20 | from .samples_handler import SamplesHandler 21 | from .auc_handler import AUCHandler 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser(description=__doc__) 26 | parser.add_argument('-i', 27 | '--input-ckpt', 28 | type=str, 29 | required=True, 30 | help='Path to input checkpoint file. ' 31 | 'If given a directory, attempts to ' 32 | 'find the best model in the directory.') 33 | parser.add_argument('-d', 34 | '--dataset-csv', 35 | type=str, 36 | required=True, 37 | help='Path to dataset CSV') 38 | parser.add_argument('-o', 39 | '--output-path', 40 | type=str, 41 | required=True, 42 | help='Path to output directory for artifacts. ' 43 | 'If it does not exist, one will be created.') 44 | parser.add_argument('-m', 45 | '--model-path', 46 | type=str, 47 | required=False, 48 | help='Path to model PTH file, ' 49 | 'if PLMaskCNN is not the model') 50 | parser.add_argument('-c', 51 | '--count', 52 | type=int, 53 | required=False, 54 | default=-1, 55 | help='Maximum number of images to run ' 56 | 'inference on. ' 57 | '-1: run inference on all images') 58 | parser.add_argument('--batch-size', 59 | type=int, 60 | required=False, 61 | default=64, 62 | help='Set dataloader batch size.') 63 | parser.add_argument('--num-workers', 64 | type=int, 65 | required=False, 66 | default=16, 67 | help='Set dataloader worker count.') 68 | parser.add_argument('--no-shuffle', 69 | action='store_true', 70 | help='If specified, do not shuffle ' 71 | 'the dataset when loading.') 72 | parser.add_argument('--raise-on-error', 73 | action='store_true', 74 | help='If specified, stop processing when ' 75 | 'encountering an error.') 76 | 77 | return parser.parse_args() 78 | 79 | 80 | if __name__ == '__main__': 81 | 82 | pargs = parse_args() 83 | 84 | # Parse relevant inputs up front 85 | csv_path = pathlib.Path(pargs.dataset_csv) 86 | assert csv_path.is_file() 87 | 88 | # Determine model checkpoint 89 | ckpt_path = pathlib.Path(pargs.input_ckpt) 90 | if ckpt_path.is_file(): 91 | # Load the checkpoint as-is 92 | pass 93 | elif ckpt_path.is_dir(): 94 | # Find the best model from the directory 95 | ckpt_path = find_best_model(ckpt_path) 96 | print('Found best model: %s' % str(ckpt_path)) 97 | else: 98 | raise ValueError( 99 | 'Could not find checkpoint path: %s. Must be a file or directory.' 100 | % str(ckpt_path)) 101 | 102 | # Load model 103 | if pargs.model_path is None: 104 | model = PLMaskCNN.load_from_checkpoint(ckpt_path) 105 | else: 106 | model_path = pathlib.Path(pargs.model_path) 107 | assert model_path.is_file() 108 | model: PLMaskCNN = torch.load(model_path) 109 | model = model.load_from_checkpoint(ckpt_path) # classmethod returns a new, checkpoint-loaded instance; re-assign it 110 | 111 | # Put model in eval mode 112 | model.eval() 113 | 114 | # Construct dataset 115 | val_ds = RoadSurfaceDataset(csv_path, 116 | transform=PreProcess(), 117 | n_channels=model.nc, 118 | limit=pargs.count) 119 | 120 | # Construct dataloader 121 | val_dl = DataLoader(val_ds, 122 | num_workers=pargs.num_workers, 123 | batch_size=pargs.batch_size, 124 | shuffle=not pargs.no_shuffle) 125 | 126 | # Parse save directory, create if not there 127 | save_dir = pathlib.Path(pargs.output_path) 128 | if not save_dir.is_dir(): 129 | save_dir.mkdir(parents=False) 130 | 131 | # Generate artifacts from model 132 | generator = ArtifactGenerator(save_dir, model, val_dl) 133 | generator.add_handler(ConfusionMatrixHandler(simple=True)) 134 | generator.add_handler(ConfusionMatrixHandler(simple=False)) 135 | generator.add_handler(AccuracyObscHandler()) 136 | generator.add_handler(ObscCompareHandler()) 137 | generator.add_handler(SamplesHandler()) 138 | # generator.add_handler(AUCHandler()) # disabled for multiclass 139 | generator.run(raise_on_error=pargs.raise_on_error) 140 | -------------------------------------------------------------------------------- /rsc/artifacts/accuracy_obsc_handler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import pathlib 5 | from typing import Any, Sequence 6 | 7 | import numpy as np 8 | import matplotlib 9 | import matplotlib.pyplot as plt 10 | 11 | from torch.utils.data import DataLoader 12 | 13 | from .base import ArtifactHandler 14 | 15 | # Use non-GUI backend 16 | matplotlib.use('Agg') 17 | 18 | 19 | class AccuracyObscHandler(ArtifactHandler): 20 | """ Handler class to plot accuracy vs obscuration """ 21 | 22 | def __init__(self): 23 | super().__init__() 24 | 25 | self.labels: list[str] | None = None 26 | self.y_true_l = [] 27 | self.acc_l = [] 28 | self.y_true_obsc_l = [] 29 | 30 | def start(self, 
model: Any, dataloader: DataLoader) -> None: 31 | 32 | # Try to get labels from model 33 | self.labels = model.__dict__.get('labels') 34 | 35 | def on_iter(self, dl_iter: Sequence, model_out: Sequence) -> None: 36 | 37 | _, features = dl_iter 38 | features = features.cpu().detach().numpy() 39 | 40 | _, pred = model_out 41 | pred = pred.cpu().detach().numpy() 42 | 43 | # Get predicted label as argmax 44 | this_y_pred = np.argmax(pred[..., :-1], axis=1) 45 | this_y_true = np.argmax(features[..., :-1], axis=1) 46 | self.y_true_l.append(this_y_true) 47 | self.acc_l.append((this_y_pred == this_y_true).astype(int)) 48 | self.y_true_obsc_l.append(features[..., -1]) 49 | 50 | @staticmethod 51 | def compute_scores( 52 | n_labels: int, 53 | acc: np.ndarray, 54 | y_true: np.ndarray) -> tuple[float, np.ndarray, np.ndarray]: 55 | """ 56 | Compute accuracy scores and relevant data for plotting. 57 | 58 | Args: 59 | n_labels (int): Number of labels in the dataset 60 | acc (np.ndarray): Boolean array where True -> correct result 61 | y_true (np.ndarray): The true label array. Must be the same size as `acc` 62 | 63 | Returns: 64 | tuple[float, np.ndarray, np.ndarray]: Output accuracy scores and relevant data. All in range [0, 1]. 65 | - `acc_all`: Accuracy over all data in `y_true` 66 | - `acc[label]`: Accuracy for [label] roads in `y_true` 67 | - `pc[label]`: Proportion of [label] in `y_true` 68 | """ 69 | assert len(acc) == len(y_true) 70 | 71 | # Count total + num correct 72 | num_total = len(acc) 73 | num_correct = np.count_nonzero(acc) 74 | 75 | # Compute accuracy scores 76 | try: 77 | acc_all = num_correct / num_total 78 | except ZeroDivisionError: 79 | acc_all = np.nan 80 | 81 | # Pre-allocate arrays for accuracy and proportion 82 | acc_l = np.zeros((n_labels,)) 83 | pc = np.zeros((n_labels,)) 84 | 85 | for label in range(n_labels): 86 | # Find indices relevant to the label 87 | idx_label = np.where(y_true == label)[0] 88 | 89 | # Compute accuracy of this label 90 | try: 91 | acc_l[label] = np.count_nonzero(acc[idx_label]) / len(idx_label) 92 | except ZeroDivisionError: 93 | acc_l[label] = np.nan 94 | 95 | # Compute proportion of this label 96 | try: 97 | pc[label] = len(idx_label) / len(acc) 98 | except ZeroDivisionError: 99 | pc[label] = np.nan 100 | 101 | return acc_all, acc_l, pc 102 | 103 | def save(self, output_dir) -> tuple[pathlib.Path, pathlib.Path]: 104 | 105 | assert self.labels is not None 106 | 107 | # Aggregate and organize 108 | y_true = np.concatenate(self.y_true_l) 109 | acc = np.concatenate(self.acc_l) 110 | y_true_obsc = np.concatenate(self.y_true_obsc_l) 111 | 112 | # Bins for obscuration: 0 -> 1 113 | bins = np.linspace(0, 1, 20 + 1) 114 | 115 | # Compute accuracy scores in bins 116 | acc_plt_b, counts_b = [], [] # binned 117 | acc_plt_c, counts_c = [], [] # cumulative 118 | for obsc_min, obsc_max in zip(bins[:-1], bins[1:]): 119 | # Get indices of interest (indices in bin + cumulative) 120 | idx_b = np.where((obsc_min < y_true_obsc) 121 | & (y_true_obsc <= obsc_max))[0] 122 | idx_c = np.where(y_true_obsc <= obsc_max)[0] 123 | 124 | # Compute accuracy scores and add to plot data: 125 | # Binned 126 | acc_all, acc_l, pc = self.compute_scores( 127 | len(self.labels), acc[idx_b], y_true[idx_b]) 128 | acc_plt_b.append( 129 | [e * 100 for e in (acc_all, *acc_l)]) 130 | counts_b.append(pc * 100) 131 | 132 | # Cumulative 133 | acc_all, acc_l, pc = self.compute_scores( 134 | len(self.labels), acc[idx_c], y_true[idx_c]) 135 | acc_plt_c.append( 136 | [e * 100 for e in (acc_all, *acc_l)]) 
137 | counts_c.append(pc * 100) 138 | 139 | # Create the plots! 140 | # Binned 141 | fig, ax = plt.subplots(2, 1, sharex=True, figsize=(12, 9)) 142 | ax[0].plot(bins[1:] * 100, [e[0] for e in acc_plt_b], 143 | 'k-*', 144 | linewidth=2) 145 | for i in range(1, len(acc_plt_b[0])): 146 | ax[0].plot(bins[1:] * 100, [e[i] for e in acc_plt_b], '-*') 147 | ax[0].set_title('Prediction Accuracy vs. Obscuration (Binned)') 148 | ax[0].set_ylim(None, 100) # type: ignore 149 | ax[0].set_ylabel(r'Accuracy [%]') 150 | ax[0].grid() 151 | ax[1].plot(bins[1:] * 100, counts_b, '-*') 152 | ax[1].set_title('Label Proportion vs. Obscuration (Binned)') 153 | ax[1].set_ylim(0, 100) 154 | ax[1].grid() 155 | ax[1].set_ylabel(f'Proportion [%]') 156 | ax[1].set_xlabel(r'Obscuration [%]') 157 | fig.legend(('All', *[e.title() for e in self.labels])) 158 | binned_path = output_dir / 'acc_obsc_plot_binned.png' 159 | fig.savefig(str(binned_path)) 160 | plt.close(fig) 161 | 162 | # Cumulative 163 | fig, ax = plt.subplots(2, 1, sharex=True, figsize=(12, 9)) 164 | ax[0].plot(bins[1:] * 100, [e[0] for e in acc_plt_c], 165 | 'k-*', 166 | linewidth=2) 167 | for i in range(1, len(acc_plt_c[0])): 168 | ax[0].plot(bins[1:] * 100, [e[i] for e in acc_plt_c], '-*') 169 | ax[0].set_title('Prediction Accuracy vs. Obscuration (Cumulative)') 170 | ax[0].set_ylim(None, 100) # type: ignore 171 | ax[0].set_ylabel(r'Accuracy [%]') 172 | ax[0].grid() 173 | ax[1].plot(bins[1:] * 100, counts_c, '-*') 174 | ax[1].set_title('Label Proportion vs. Obscuration (Cumulative)') 175 | ax[1].set_ylim(0, 100) 176 | ax[1].grid() 177 | ax[1].set_ylabel(f'Proportion [%]') 178 | ax[1].set_xlabel(r'Obscuration [%]') 179 | fig.legend(('All', *[e.title() for e in self.labels])) 180 | cumul_path = output_dir / 'acc_obsc_plot_cumul.png' 181 | fig.savefig(str(cumul_path)) 182 | plt.close(fig) 183 | 184 | return binned_path, cumul_path 185 | -------------------------------------------------------------------------------- /rsc/artifacts/auc_handler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import pathlib 5 | from typing import Any, Sequence 6 | 7 | import numpy as np 8 | import matplotlib 9 | import matplotlib.pyplot as plt 10 | import sklearn.metrics 11 | 12 | import torch 13 | from torch.utils.data import DataLoader 14 | 15 | from .base import ArtifactHandler 16 | 17 | # Use non-GUI backend 18 | matplotlib.use('Agg') 19 | 20 | 21 | class AUCHandler(ArtifactHandler): 22 | """ Handler class to plot the ROC curve and compute AUC """ 23 | 24 | def __init__(self): 25 | super().__init__() 26 | 27 | self.y_true_l = [] # true labels 28 | self.y_pred_c = [] # confidences 29 | self.roc_auc = None 30 | 31 | def start(self, model: Any, dataloader: DataLoader) -> None: 32 | pass 33 | 34 | def on_iter(self, dl_iter: Sequence, model_out: Sequence) -> None: 35 | 36 | _, features = dl_iter 37 | features = features.cpu().detach().numpy() 38 | 39 | _, pred = model_out 40 | pred = torch.softmax(pred, dim=1) 41 | pred = pred.cpu().detach().numpy() 42 | 43 | # Get true label as argmax; stash class confidences for the ROC curve 44 | this_y_true = np.argmax(features[..., :-1], axis=1) 45 | self.y_true_l.append(this_y_true) 46 | self.y_pred_c.append(pred[..., :-1]) 47 | 48 | def save(self, output_dir) -> pathlib.Path: 49 | 50 | y_true_l = np.concatenate(self.y_true_l) 51 | # probability of the class with the *greater* / *positive* label 52 | y_pred_c = np.concatenate(self.y_pred_c)[:, 1] 53 | 54 | fig, ax = plt.subplots() 55 | 
ax.grid() 56 | disp = sklearn.metrics.RocCurveDisplay.from_predictions(y_true_l, 57 | y_pred_c, 58 | ax=ax) 59 | self.roc_auc = disp.roc_auc # type: ignore 60 | ax.set_title('ROC Curve') 61 | 62 | plt_path = output_dir / 'roc_curve.png' 63 | fig.savefig(str(plt_path)) 64 | plt.close(fig) 65 | 66 | return plt_path 67 | -------------------------------------------------------------------------------- /rsc/artifacts/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import pathlib 5 | import traceback 6 | from abc import ABC, abstractmethod 7 | from typing import Any, Sequence 8 | 9 | from tqdm import tqdm 10 | 11 | from torch.utils.data import DataLoader 12 | 13 | from . import device 14 | 15 | 16 | class ArtifactHandler(ABC): 17 | """ Abstract class used to generate an artifact from an inference result """ 18 | 19 | def __init__(self): 20 | pass 21 | 22 | @abstractmethod 23 | def start(self, model: Any, dataloader: DataLoader) -> None: 24 | """ 25 | Handler startup code. Do any initalization here. 26 | 27 | Args: 28 | model (Any): Model that will be used 29 | dataloader (DataLoader): Dataloader that will be used 30 | """ 31 | pass 32 | 33 | @abstractmethod 34 | def on_iter(self, dl_iter: Sequence, model_out: Sequence) -> None: 35 | """ 36 | Called each time a batch is inferenced. Collect data for 37 | the artifact here. 38 | 39 | Args: 40 | dl_iter (Sequence): Output iterable from the dataloader. 41 | model_out (Sequence): Output iterable from `model.predict` 42 | """ 43 | pass 44 | 45 | @abstractmethod 46 | def save(self, output_dir: pathlib.Path) -> pathlib.Path: 47 | """ 48 | Given the collected state from on_iter, generate and 49 | save the relevant artifact. 50 | 51 | Args: 52 | output_dir (pathlib.Path): Output directory to artifact 53 | 54 | Returns: 55 | pathlib.Path: Path to saved artifact 56 | """ 57 | pass 58 | 59 | 60 | class ArtifactGenerator: 61 | """ Top-level class to handle multiple artifact generation """ 62 | 63 | def __init__(self, output_dir: str | pathlib.Path, model: Any, 64 | dataloader: DataLoader): 65 | """ 66 | Create an artifact generator 67 | 68 | Args: 69 | output_dir (str | pathlib.Path): Output directory to dump artifacts 70 | model (Any): RSC model for inference 71 | dataloader (DataLoader): Dataloader to run inference over 72 | """ 73 | self.handlers: list[ArtifactHandler] = [] 74 | self.is_active: list[bool] = [] 75 | self.output_dir = pathlib.Path(output_dir) 76 | self.output_dir.mkdir(parents=False, exist_ok=True) 77 | self.model = model 78 | self.model.to(device) 79 | self.dataloader = dataloader 80 | 81 | def add_handler(self, handler: ArtifactHandler): 82 | """ 83 | Add a handler to the generator. All handlers should be added 84 | before `run` is called. 85 | 86 | Args: 87 | handler (ArtifactHandler): Handler to add to the generator. 88 | 89 | Raises: 90 | ValueError: If the handler is not a subclass of `ArtifactHandler` 91 | """ 92 | if not issubclass(handler.__class__, ArtifactHandler): 93 | raise ValueError( 94 | f'Input handler must be a subclass of {ArtifactHandler.__name__}' 95 | ) 96 | self.handlers.append(handler) 97 | 98 | def _exec(self, idx, func, raise_on_error: bool) -> Any: 99 | """ Wrapper call to work with the handler. 
Provides 100 | a failsafe in the case the handler throws an exception 101 | such that inference can continue """ 102 | # Skip an inactive handler 103 | if not self.is_active[idx]: 104 | return 105 | try: 106 | # Attempt call 107 | return func() 108 | except Exception: 109 | # If we failed, raise if we want 110 | if raise_on_error: 111 | raise 112 | # Otherwise, print the exception and deactivate 113 | # the handler 114 | traceback.print_exc() 115 | print(f'Deactivating handler {idx:d}...') 116 | self.is_active[idx] = False 117 | 118 | def run(self, raise_on_error: bool = False): 119 | """ 120 | Run inference and pass data to each handler. 121 | 122 | Args: 123 | raise_on_error (bool, optional): If any handler throws an error, 124 | whether or not to raise. Defaults to False. 125 | """ 126 | 127 | # Ready? 128 | self.is_active = [True for _ in self.handlers] 129 | 130 | # Set... 131 | for idx, h in enumerate(self.handlers): 132 | self._exec(idx, lambda: h.start(self.model, self.dataloader), 133 | raise_on_error) 134 | 135 | # Go! 136 | for dl_iter in tqdm(iter(self.dataloader)): 137 | x, _ = dl_iter 138 | 139 | # Extract just the image + location mask for inference 140 | # i.e. strip the last (probmask) channel 141 | x = x[:, :-1, :, :].to(device) 142 | 143 | # Get prediction from model 144 | model_out = self.model(x) 145 | 146 | # Run the handlers 147 | for idx, h in enumerate(self.handlers): 148 | self._exec(idx, lambda: h.on_iter(dl_iter, model_out), 149 | raise_on_error) 150 | 151 | # Finalize 152 | for idx, h in enumerate(self.handlers): 153 | self._exec(idx, lambda: h.save(self.output_dir), raise_on_error) 154 | -------------------------------------------------------------------------------- /rsc/artifacts/confusion_matrix_handler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import pathlib 5 | from typing import Any, Sequence 6 | 7 | import numpy as np 8 | import matplotlib 9 | import matplotlib.pyplot as plt 10 | import matplotlib.cm as cm 11 | from torch.utils.data import DataLoader 12 | from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay 13 | 14 | from .base import ArtifactHandler 15 | 16 | # Use non-GUI backend 17 | matplotlib.use('Agg') 18 | 19 | 20 | class ConfusionMatrixHandler(ArtifactHandler): 21 | """ Handler class to generate a confusion matrix """ 22 | 23 | def __init__(self, simple=False): 24 | super().__init__() 25 | 26 | # Adds "simple mode", where we are only 27 | # looking at paved / unpaved roads 28 | # I admit this short-circuits the 29 | # generalizability of this class, but it's 30 | # worth it 31 | self.simple = simple 32 | 33 | self.labels: list[str] | None = None 34 | self.y_true_l: list[Any] = [] 35 | self.y_pred_l: list[Any] = [] 36 | 37 | def start(self, model: Any, dataloader: DataLoader) -> None: 38 | 39 | # Try to get labels from model 40 | self.labels = model.__dict__.get('labels') 41 | if self.simple: 42 | self.labels = ['paved', 'unpaved'] 43 | 44 | def on_iter(self, dl_iter: Sequence, model_out: Sequence) -> None: 45 | 46 | # Parse iterable 47 | _, features = dl_iter 48 | 49 | # We extract just the image + location mask 50 | y_true = features.numpy()[..., :-1] 51 | if self.simple: 52 | y_true = np.sum(np.stack((y_true[..., 0:4], y_true[..., 4:8]), axis=1), axis=-1) 53 | y_true = np.argmax(y_true, axis=1) 54 | self.y_true_l.append(y_true) 55 | 56 | # Get prediction from model 57 | _, pred = model_out 58 | pred = 
pred.cpu().detach().numpy() 59 | 60 | # Lazy-init labels if we couldn't get them from the model 61 | if self.labels is None: 62 | self.labels = [f'Class {n+1:d}' for n in range(len(pred))] 63 | 64 | # Get predicted label as argmax 65 | if self.simple: 66 | y_pred = np.sum(np.stack((pred[..., 0:4], pred[..., 4:8]), axis=1), axis=-1) 67 | else: 68 | y_pred = pred[..., :-1] 69 | y_pred = np.argmax(y_pred, axis=1) 70 | self.y_pred_l.append(y_pred) 71 | 72 | def save(self, output_dir) -> pathlib.Path: 73 | 74 | # Aggregate and organize 75 | y_true = np.concatenate(self.y_true_l) 76 | y_pred = np.concatenate(self.y_pred_l) 77 | 78 | # Generate and save the confusion matrix 79 | c = ConfusionMatrixDisplay(confusion_matrix(y_true, 80 | y_pred, 81 | normalize='true'), 82 | display_labels=self.labels) 83 | if self.simple: 84 | output_path = output_dir / 'confusion_matrix_simple.png' 85 | else: 86 | output_path = output_dir / 'confusion_matrix.png' 87 | fig, ax = plt.subplots(figsize=(8, 8)) 88 | fig.subplots_adjust(left=0.2, bottom=0.2) 89 | try: 90 | c.plot(ax=ax, cmap=cm.Blues, # type: ignore 91 | xticks_rotation='vertical') 92 | except ValueError: 93 | print('Detected issue with plot! This is likely due to a mismatch of class labels and true labels.') 94 | raise 95 | fig.savefig(str(output_path)) 96 | plt.close(fig) 97 | 98 | return output_path -------------------------------------------------------------------------------- /rsc/artifacts/obsc_compare_handler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import pathlib 5 | from typing import Any, Sequence 6 | import torch 7 | 8 | import numpy as np 9 | import matplotlib 10 | import matplotlib.pyplot as plt 11 | 12 | from torch.utils.data import DataLoader 13 | 14 | from .base import ArtifactHandler 15 | 16 | # Use non-GUI backend 17 | matplotlib.use('Agg') 18 | 19 | 20 | class ObscCompareHandler(ArtifactHandler): 21 | """ Handler class to plot predicted vs. actual obscuration """ 22 | 23 | def __init__(self): 24 | super().__init__() 25 | 26 | self.y_pred_l = [] 27 | self.y_true_l = [] 28 | 29 | def start(self, model: Any, dataloader: DataLoader) -> None: 30 | pass 31 | 32 | def on_iter(self, dl_iter: Sequence, model_out: Sequence) -> None: 33 | _, features = dl_iter 34 | features = features.cpu().detach().numpy() 35 | 36 | # Get prediction from model 37 | _, pred = model_out 38 | pred = torch.sigmoid(pred[..., -1]) 39 | pred = pred.cpu().detach().numpy() 40 | 41 | # Get predicted label as argmax 42 | self.y_pred_l.append(pred) 43 | self.y_true_l.append(features[..., -1]) 44 | 45 | def save(self, output_dir) -> pathlib.Path: 46 | y_pred = np.concatenate(self.y_pred_l) * 100. 47 | y_true = np.concatenate(self.y_true_l) * 100. 48 | 49 | # Create the plot! 50 | fig, ax = plt.subplots(figsize=(12, 12)) 51 | ax.set_aspect('equal') 52 | ax.scatter(y_pred, y_true, 9) 53 | ax.set_xlabel(r'Predicted Obscuration [%]') 54 | ax.set_ylabel(r'Est. 
Obscuration by Logit Regression [%]') 55 | ax.set_title('Model Obscuration Prediction Accuracy') 56 | ax.grid() 57 | ax.plot((0, 100), (0, 100), '--k', linewidth=2) 58 | ax.legend(['Model Data', 'y = x']) 59 | 60 | output_path = output_dir / 'obsc_compare_plot.png' 61 | fig.savefig(str(output_path)) 62 | plt.close(fig) 63 | 64 | return output_path 65 | -------------------------------------------------------------------------------- /rsc/artifacts/samples_handler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | import pathlib 4 | from typing import Any, Sequence 5 | 6 | import numpy as np 7 | import matplotlib 8 | import matplotlib.pyplot as plt 9 | import torch 10 | from torch.utils.data import DataLoader 11 | 12 | from .base import ArtifactHandler 13 | 14 | # Use non-GUI backend 15 | matplotlib.use('Agg') 16 | 17 | 18 | class SamplesHandler(ArtifactHandler): 19 | 20 | def __init__(self, max_per_category: int = 50): 21 | super().__init__() 22 | self.max_per_category = max_per_category 23 | self.labels = [] 24 | self.samples = {} 25 | self.maxed_out = {} 26 | 27 | def start(self, model: Any, dataloader: DataLoader) -> None: 28 | # Try to get labels from model 29 | self.labels = model.__dict__.get('labels') 30 | assert self.labels is not None 31 | 32 | # Setup samples dictionary 33 | for correct in (True, False): 34 | self.samples[correct] = {} 35 | self.maxed_out[correct] = {} 36 | for label in self.labels: 37 | self.samples[correct][label] = [] 38 | self.maxed_out[correct][label] = False 39 | 40 | def on_iter(self, dl_iter: Sequence, model_out: Sequence) -> None: 41 | 42 | # Skip all this extra processing if we reached our maximum sample count anyway 43 | for label in self.labels: 44 | if any(not self.maxed_out[correct][label] 45 | for correct in (True, False)): 46 | break 47 | else: 48 | return # we did not break, so we must be maxed out 49 | 50 | x, features = dl_iter 51 | 52 | # Get mask and image 53 | xpm = x[:, (-1,), :, :] 54 | xm = x[:, (-2,), :, :] 55 | x = x[:, :-2, :, :] 56 | 57 | # Get true label 58 | y_true = features.cpu().detach().numpy() 59 | 60 | # Transform image and mask 61 | x_p = (np.moveaxis(x.numpy(), 1, -1) * 255.).astype( 62 | np.uint8) 63 | xm_p = np.moveaxis(xm.numpy(), 1, -1) 64 | xpm_p = np.moveaxis(xpm.numpy(), 1, -1) 65 | 66 | # Parse model prediction 67 | m, y_pred = model_out 68 | m_p = np.moveaxis(m.cpu().detach().numpy(), 1, -1) 69 | 70 | # True label (argmax) 71 | y_true_am: np.ndarray = np.argmax(y_true[:, 0:-1], 1) 72 | 73 | # Predicted label (argmax) 74 | y_pred_am = torch.argmax(y_pred[:, :-1], 1) 75 | y_pred_am = y_pred_am.cpu().detach().numpy() # type: ignore 76 | 77 | # Predicted obscuration (sigmoid) 78 | y_pred_obsc = torch.sigmoid(y_pred[:, -1]).cpu().detach().numpy() # type: ignore 79 | 80 | for i in range(x.shape[0]): 81 | 82 | # Are we correct? 
83 | correct = y_true_am[i] == y_pred_am[i] 84 | 85 | # Get current true label 86 | this_label = self.labels[y_true_am[i]] 87 | 88 | # Skip if above limit for category 89 | if len(self.samples[correct][this_label]) > self.max_per_category: 90 | self.maxed_out[correct][this_label] = True 91 | continue 92 | 93 | # If we got here, then take a sample 94 | self.samples[correct][this_label].append(( 95 | x_p[i, ...], 96 | xm_p[i, ...], 97 | xpm_p[i, ...], 98 | m_p[i, ...], 99 | y_true[i, ...], 100 | y_true_am[i, ...], 101 | y_pred_am[i, ...], 102 | y_pred_obsc[i, ...], 103 | )) 104 | 105 | def save(self, output_dir: pathlib.Path) -> pathlib.Path: 106 | 107 | # Create directory for samples 108 | samples_dir = output_dir / 'samples' 109 | samples_dir.mkdir(parents=False, exist_ok=True) 110 | 111 | for correct in (True, False): 112 | # Create directory for "correctness" 113 | correct_dir = samples_dir / ('correct' if correct else 'incorrect') 114 | correct_dir.mkdir(parents=False, exist_ok=True) 115 | 116 | for label in self.labels: 117 | # Create directory for label 118 | label_dir = correct_dir / str(label) 119 | label_dir.mkdir(parents=False, exist_ok=True) 120 | 121 | # Loop through all the samples 122 | for idx, (x_p, xm_p, xpm_p, m_p, \ 123 | y_true, y_true_am, y_pred_am, \ 124 | y_pred_obsc) in enumerate(self.samples[correct][label]): 125 | 126 | # Get num-channels (3 - RGB, 4 - RGB + NIR) 127 | _, _, nc = x_p.shape 128 | 129 | # Create figure 130 | fig, ax = plt.subplots(1, 131 | nc + 1, 132 | sharex=True, 133 | sharey=True, 134 | figsize=(15, 3.5)) 135 | for _ax in ax: 136 | _ax.xaxis.set_visible(False) 137 | _ax.yaxis.set_visible(False) 138 | 139 | # Allows us to support different 140 | # number of subplots in NIR case 141 | ax_idx = 0 142 | 143 | # Plot the result, highlighting the road with the mask 144 | ax[ax_idx].set_title( 145 | 'RGB\nTrue: %s; Pred: %s' % 146 | (self.labels[y_true_am], self.labels[y_pred_am])) 147 | ax[ax_idx].imshow(np.uint8(x_p[..., (0, 1, 2)])) 148 | ax_idx += 1 149 | 150 | # Plot color IR if relevant 151 | if nc > 3: 152 | ax[ax_idx].set_title('Color IR') 153 | ax[ax_idx].imshow(np.uint8(x_p[..., (3, 0, 1)])) 154 | ax_idx += 1 155 | 156 | ax[ax_idx].set_title('Combined Image + Mask\nObsc: %.1f%%; Pred: %.1f%%' % (y_true[-1] * 100, y_pred_obsc * 100)) 157 | ax[ax_idx].imshow( 158 | np.uint8(x_p[..., (0, 1, 2)] * (0.33 + 0.67 * xm_p))) 159 | ax_idx += 1 160 | 161 | ax[ax_idx].set_title('Combined Image \n+ True ProbMask') 162 | ax[ax_idx].imshow(np.uint8(0.5 * x_p[..., (0, 1, 2)])) 163 | ax[ax_idx].imshow(xpm_p, 164 | cmap='magma', 165 | vmin=0, 166 | vmax=1, 167 | alpha=0.33) 168 | ax_idx += 1 169 | 170 | ax[ax_idx].set_title('Combined Image \n+ Pred ProbMask') 171 | ax[ax_idx].imshow(np.uint8(0.5 * x_p[..., (0, 1, 2)])) 172 | ax[ax_idx].imshow(m_p, cmap='magma', vmin=0, vmax=1, alpha=0.33) 173 | 174 | # Save the figure 175 | output_path = label_dir / f'sample_{idx:05d}.png' 176 | fig.savefig(str(output_path)) 177 | plt.close(fig) 178 | 179 | return samples_dir -------------------------------------------------------------------------------- /rsc/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/rsc/common/__init__.py -------------------------------------------------------------------------------- /rsc/common/aws_naip.py: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ Some helpful functions to download NAIP data on AWS """ 4 | 5 | import pathlib 6 | from typing import List 7 | 8 | import boto3 9 | 10 | BASE_PATH = pathlib.Path('/data/road_surface_classifier') 11 | assert BASE_PATH.is_dir() 12 | AWS_PATH = BASE_PATH / 'naip_on_aws' 13 | AWS_PATH.mkdir(exist_ok=True) 14 | 15 | 16 | def naip_s3_fetch(bucket_name: str, object_name: str, output: pathlib.Path): 17 | """ 18 | Fetch an S3 object from a bucket, ensuring that RequestPayer=requester is set 19 | 20 | Args: 21 | bucket_name (str): AWS S3 bucket name 22 | object_name (str): Object name in bucket 23 | output (pathlib.Path): Path to save the output file 24 | 25 | Raises: 26 | FileExistsError: If the output file already exists 27 | """ 28 | # Prevent overwriting the output file 29 | if output.exists(): 30 | raise FileExistsError('File already exists: %s' % str(output)) 31 | 32 | # Use the boto3 s3 API to fetch the file object 33 | s3_client = boto3.client('s3') 34 | try: 35 | with open(output, 'wb') as f: 36 | s3_client.download_fileobj(bucket_name, 37 | object_name, 38 | f, 39 | ExtraArgs={'RequestPayer': 'requester'}) 40 | except: 41 | # If cancelled or something bad happens, 42 | # unlink the file we are trying to download and raise 43 | output.unlink() 44 | raise 45 | 46 | 47 | def get_naip_manifest(bucket_name='naip-analytic') -> List[str]: 48 | """ 49 | Read the bucket manifest for a given NAIP on AWS bucket. If the manifest 50 | doesn't yet exist, then fetch it from S3. 51 | 52 | Args: 53 | bucket_name (str, optional): NAIP on AWS bucket name. Defaults to 'naip-analytic'. 54 | 55 | Returns: 56 | List[str]: List of objects in manifest 57 | """ 58 | manifest_path = AWS_PATH / 'manifest.txt' 59 | 60 | # Fetch the manifest if it doesn't exist. 61 | if not manifest_path.exists(): 62 | naip_s3_fetch(bucket_name, 'manifest.txt', manifest_path) 63 | 64 | # Read it in! 65 | with open(manifest_path, 'r') as f: 66 | return [e.strip() for e in f.readlines()] 67 | 68 | 69 | def get_naip_file(object_name: str, 70 | bucket_name='naip-analytic') -> pathlib.Path: 71 | """ 72 | Get a NAIP file by object name. If it already exists, this will not download again. 73 | 74 | Args: 75 | object_name (str): Object name 76 | bucket_name (str, optional): NAIP on AWS Bucket name. Defaults to 'naip-analytic'. 77 | 78 | Returns: 79 | pathlib.Path: Path to the downloaded file. 
80 | """ 81 | # Get output path, return immediately if exists 82 | output_path = AWS_PATH / object_name 83 | if output_path.exists(): 84 | return output_path 85 | 86 | # Otherwise, create the parent directory, and fetch the file 87 | output_path.parent.mkdir(parents=True, exist_ok=True) 88 | naip_s3_fetch(bucket_name, object_name, output_path) 89 | 90 | return output_path 91 | -------------------------------------------------------------------------------- /rsc/common/geometric_median.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ Python implementation of geometric median as described in 4 | Yehuda Vardi and Cun-Hui Zhang's paper 5 | "The multivariate L1-median and associated data depth" 6 | """ 7 | from typing import Optional 8 | import numpy as np 9 | from scipy.spatial.distance import cdist, euclidean 10 | 11 | 12 | def geometric_median(X: np.ndarray, eps: float = 1e-5) -> np.ndarray: 13 | """ 14 | Python implementation of geometric median as described in 15 | Yehuda Vardi and Cun-Hui Zhang's paper 16 | "The multivariate L1-median and associated data depth" 17 | 18 | Taken from: https://stackoverflow.com/questions/30299267/geometric-median-of-multidimensional-points 19 | 20 | Released under zlib license. 21 | 22 | Args: 23 | X (np.ndarray): Input N-dimensional data 24 | eps (float, optional): Convergence tolerance. Defaults to 1e-5. 25 | 26 | Returns: 27 | np.ndarray: Geometric median value 28 | """ 29 | y = np.mean(X, 0) 30 | 31 | while True: 32 | D = cdist(X, [y]) 33 | nonzeros = (D != 0)[:, 0] 34 | 35 | Dinv = 1 / D[nonzeros] 36 | Dinvs = np.sum(Dinv) 37 | W = Dinv / Dinvs 38 | T = np.sum(W * X[nonzeros], 0) 39 | 40 | num_zeros = len(X) - np.sum(nonzeros) 41 | if num_zeros == 0: 42 | y1 = T 43 | elif num_zeros == len(X): 44 | return y 45 | else: 46 | R = (T - y) * Dinvs 47 | r = np.linalg.norm(R) 48 | rinv = 0 if r == 0 else num_zeros / r 49 | y1 = max(0, 1 - rinv) * T + min(1, rinv) * y 50 | 51 | if euclidean(y, y1) < eps: 52 | return y1 53 | 54 | y = y1
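A quick numeric illustration of the robustness property (values illustrative): unlike the mean, the L1-median barely moves when a single gross outlier is added:

```python
import numpy as np

from rsc.common.geometric_median import geometric_median

# Four corners of the unit square, plus one gross outlier
X = np.array([[0., 0.], [1., 0.], [0., 1.], [1., 1.], [100., 100.]])

print(X.mean(axis=0))       # [20.4 20.4] -- the mean is dragged toward the outlier
print(geometric_median(X))  # ~[0.79 0.79] -- the L1-median stays by the square
```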
-------------------------------------------------------------------------------- /rsc/inference/__init__.py: https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/rsc/inference/__init__.py -------------------------------------------------------------------------------- /rsc/inference/fetch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | import numpy as np 4 | from osgeo import ogr, osr 5 | import PIL.Image 6 | import PIL.ImageDraw 7 | 8 | from rsc.common.utils import imread_geotransform, imread_srs, imread, map_to_pix 9 | 10 | ogr.UseExceptions() 11 | 12 | # Get EPSG:4326 reference SRS 13 | srs_ref = osr.SpatialReference() 14 | srs_ref.ImportFromEPSG(4326) 15 | srs_ref.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER) 16 | 17 | 18 | def fetch(img_path, x1, y1, x2, y2, wkt): 19 | 20 | # Fetch the tile we need 21 | h, w = y2 - y1, x2 - x1 22 | im = imread(str(img_path), x_off=x1, y_off=y1, w=w, h=h) 23 | xform = imread_geotransform(str(img_path), x_off=x1, y_off=y1) 24 | 25 | # Get the spatial reference (they change across the tiles) 26 | srs = osr.SpatialReference() 27 | srs.ImportFromProj4(imread_srs(str(img_path))) 28 | srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER) 29 | im_trans = osr.CoordinateTransformation(srs_ref, srs) 30 | 31 | # Get list of points from WKT 32 | geom: ogr.Geometry = ogr.CreateGeometryFromWkt(wkt) 33 | geom.Transform(im_trans) 34 | pts = np.array( 35 | [geom.GetPoint_2D(idx) for idx in range(geom.GetPointCount())]) 36 | 37 | # Convert to image-space x, y 38 | ix, iy = map_to_pix(list(xform), pts[:, 0], pts[:, 1]) 39 | 40 | # Create a new image of the same shape, and draw a line 41 | # to create a mask 42 | mask_pil = PIL.Image.new('L', (w, h), color=0) 43 | d = PIL.ImageDraw.Draw(mask_pil) 44 | d.line( 45 | [(x, y) for x, y in zip(ix, iy)], # type: ignore 46 | fill=255, 47 | width=2, 48 | joint="curve") 49 | mask = np.array(mask_pil)[:, :, np.newaxis] 50 | 51 | return np.concatenate((im, mask), axis=-1) -------------------------------------------------------------------------------- /rsc/inference/mass_inference_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | import sqlite3 4 | import pathlib 5 | from typing import Any 6 | import pandas as pd 7 | import numpy as np 8 | 9 | import torch 10 | from torch.utils.data import Dataset 11 | 12 | import torch.multiprocessing 13 | 14 | from .fetch import fetch 15 | 16 | torch.multiprocessing.set_sharing_strategy('file_system') 17 | 18 | IMAGERY_PATH = pathlib.Path('/nfs/taranis/naip/BOULDER_COUNTY_NAIP_2019') 19 | assert IMAGERY_PATH.is_dir() 20 | 21 | 22 | class MassInferenceDataset(Dataset): 23 | def __init__(self, 24 | sqlite_path: pathlib.Path, 25 | transform, 26 | n_channels=4, 27 | limit=-1): 28 | 29 | # Connect to the sqlite database (read-only) and load the features table 30 | with sqlite3.connect('file:%s?mode=ro' % str(sqlite_path.resolve()), 31 | uri=True) as con: 32 | self.df = pd.read_sql('SELECT * FROM features;', 33 | con).set_index('osm_id') 34 | 35 | # Load dataframe, interpret length 36 | self.n_idxs = len(self.df) if limit == -1 else min(limit, len(self.df)) 37 | 38 | # Number of channels in the image 39 | self.n_channels = n_channels 40 | 41 | # Transformation object 42 | self.transform = transform 43 | 44 | def __len__(self): 45 | return self.n_idxs 46 | 47 | def __getitem__(self, idx): 48 | 49 | # Get row 50 | row: Any = self.df.iloc[idx] # type: ignore 51 | 52 | x1, y1, x2, y2 = [row[e].item() for e in ('x1', 'y1', 'x2', 'y2')] 53 | 54 | # Fetch the image chip + rasterized road mask 55 | im = fetch(IMAGERY_PATH / row['img'], x1, y1, x2, y2, row['wkt']) 56 | 57 | # Concat image and masks for output 58 | x = self.transform(im) 59 | 60 | return row.name, x 61 | -------------------------------------------------------------------------------- /rsc/osm/README.md: -------------------------------------------------------------------------------- 1 | # OSM Helper Module 2 | 3 | Helper module for working with OpenStreetMap data, with a focus on drivable road networks.
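## Example

A small usage sketch (IDs, coordinates, and tags below are illustrative):

```python
from rsc.osm import OSMElementFactory, OSMNode, OSMWay

# Build elements directly from keyword arguments
node = OSMNode(id=1, lat=40.015, lon=-105.271, tags={'highway': 'turning_circle'})
way = OSMWay(id=2, nodes=[1], tags={'highway': 'residential', 'surface': 'gravel'})

# JSON round-trip; the factory dispatches on the element's 'type' key
same_way = OSMWay.from_json(way.to_json())
el = OSMElementFactory.from_json_dict(node.to_json_dict())  # -> OSMNode
```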
4 | 5 | ## License 6 | [MIT](https://choosealicense.com/licenses/mit/) © 2024 Jonathan Dalrymple -------------------------------------------------------------------------------- /rsc/osm/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import sys 5 | import logging 6 | import functools 7 | 8 | # Configure logger 9 | _logger = logging.getLogger(__name__) 10 | _logger.setLevel(logging.DEBUG) # configure log level here 11 | _logger_handlers = [logging.StreamHandler(sys.stdout)] 12 | _logger_formatter = logging.Formatter( 13 | r'%(asctime)-15s %(levelname)s [%(module)s] %(message)s') 14 | _logger.handlers.clear() 15 | for h in _logger_handlers: 16 | h.setFormatter(_logger_formatter) 17 | _logger.addHandler(h) 18 | 19 | def get_logger(): 20 | """ Fetch the module logger """ 21 | return _logger 22 | 23 | def gdal_required(is_available): 24 | """ Decorator factory to guard functions / methods that require GDAL (OGR) """ 25 | def _decorator(func): 26 | @functools.wraps(func) 27 | def wrapper(*args, **kwargs): 28 | if not is_available: 29 | raise RuntimeError('OGR is required for this function / method.') 30 | return func(*args, **kwargs) 31 | return wrapper 32 | return _decorator 33 | 34 | # Exposed classes to user 35 | from .osm_element import OSMElement, OSMNode, OSMWay 36 | from .osm_element_factory import OSMElementFactory 37 | from .osm_network import OSMNetwork 38 | from . import overpass_api -------------------------------------------------------------------------------- /rsc/osm/osm_element.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ OSM element implementations (node, way, etc.) """ 4 | 5 | from __future__ import annotations 6 | 7 | import json 8 | from typing import List, Union, TypeVar, Type 9 | import warnings 10 | from abc import ABC, abstractmethod 11 | import xml.etree.ElementTree as ET 12 | 13 | _gdal_available = False 14 | try: 15 | from osgeo import ogr 16 | ogr.UseExceptions() 17 | _gdal_available = True 18 | except ModuleNotFoundError: 19 | warnings.warn('Could not find GDAL. Some methods may not be available.', ImportWarning) 20 | 21 | from . import gdal_required 22 | 23 | # Generic type for abstract classmethods of OSMElement 24 | T = TypeVar('T', bound='OSMElement') 25 | 26 | class OSMElement(ABC): 27 | """ Abstract class to describe an OSM element (node, way, etc.) """ 28 | __slots__ = ['id', 'tags'] 29 | 30 | TYPE = '' # holds OSM type string ('node', 'way', etc.) 31 | 32 | def __init__(self, **kwargs): 33 | """ Default constructor: sets attributes from kwargs """ 34 | # Set id and tags (universal to every OSM element) 35 | # NOTE: it's okay if tags is empty, which is why it 36 | # doesn't get the fancy attribute wrapper 37 | self._set_attr_from_kwargs(kwargs, 'id', int, default=None) 38 | self.tags = kwargs.get('tags', {}) 39 | 40 | def _set_attr_from_kwargs(self, 41 | kwargs: dict, 42 | attr_key: str, 43 | dt: type, 44 | default=None): 45 | """ Set an attribute from kwargs, emitting a warning if the default must be used """ 46 | try: 47 | self.__setattr__(attr_key, dt(kwargs[attr_key])) 48 | except (TypeError, ValueError): 49 | warnings.warn( 50 | f'Cannot cast input for \'{attr_key}\' to {dt}! Setting to {repr(default)}', 51 | RuntimeWarning) 52 | self.__setattr__(attr_key, default) 53 | except KeyError: 54 | warnings.warn( 55 | f'No {attr_key} specified in constructor! 
Setting to {repr(default)}', 56 | RuntimeWarning) 57 | self.__setattr__(attr_key, default) 58 | 59 | @staticmethod 60 | def _parse_xml_osm_tag(t: ET.Element) -> dict: 61 | """ Parse an XML OSM tag element into a dictionary """ 62 | assert t.tag == 'tag' 63 | d = t.attrib 64 | if 'k' not in d or 'v' not in d: 65 | warnings.warn(f'Tag element empty! {d}', RuntimeWarning) 66 | return {} 67 | return {d['k']: d['v']} 68 | 69 | @staticmethod 70 | def _create_xml_osm_tag(t: dict) -> List[ET.Element]: 71 | """ Create a list of ETree Element objects from a tags dictionary """ 72 | el_list = [] # type: List[ET.Element] 73 | for k, v in t.items(): 74 | el = ET.Element('tag') 75 | el.attrib['k'] = str(k) 76 | el.attrib['v'] = str(v) 77 | el_list.append(el) 78 | return el_list 79 | 80 | def __str__(self): 81 | return f'<{self.__class__.__name__}; id: {self.id}; tags: {len(self.tags)}>' 82 | 83 | def __repr__(self): 84 | return self.__str__() 85 | 86 | @abstractmethod 87 | def to_json_dict(self) -> dict: 88 | """ Export this element to a JSON dict (Python dictionary) """ 89 | return {'type': self.TYPE, 'id': self.id} 90 | 91 | @abstractmethod 92 | def to_xml(self) -> ET.Element: 93 | """ Export this element to an XML string """ 94 | pass 95 | 96 | def to_xml_str(self) -> str: 97 | """ Export this element to an XML string """ 98 | return ET.tostring(self.to_xml(), encoding='utf-8').decode('utf-8') 99 | 100 | @classmethod 101 | @abstractmethod 102 | def from_json_dict(cls: Type[T], json_dict: dict) -> T: 103 | """ Create this element from a JSON dict (Python dictionary) """ 104 | pass 105 | 106 | def to_json(self) -> str: 107 | """ Export this element to an JSON string """ 108 | return json.dumps(self.to_json_dict()) 109 | 110 | @classmethod 111 | def from_json(cls: Type[T], json_str: str) -> T: 112 | """ Create this element from a JSON string """ 113 | return cls.from_json_dict(json.loads(json_str)) 114 | 115 | @classmethod 116 | @abstractmethod 117 | def from_xml(cls: Type[T], tree: ET.Element) -> T: 118 | """ Create this element from an XML Element """ 119 | pass 120 | 121 | @classmethod 122 | def from_xml_str(cls: Type[T], xml_str: str) -> T: 123 | """ Create this element from an XML string """ 124 | return cls.from_xml(ET.fromstring(xml_str)) 125 | 126 | 127 | class OSMNode(OSMElement): 128 | """ Class to describe an OSM node object """ 129 | __slots__ = ['lat', 'lon'] 130 | 131 | TYPE = 'node' 132 | 133 | def __init__(self, **kwargs): 134 | super().__init__(**kwargs) 135 | 136 | # Set latitude and longitude 137 | self._set_attr_from_kwargs(kwargs, 'lat', float, default=0.0) 138 | self._set_attr_from_kwargs(kwargs, 'lon', float, default=0.0) 139 | 140 | def __str__(self): 141 | s = super().__str__()[:-1] 142 | s += f'; lat: {self.lat:.3f}; lon: {self.lon:.3f}>' 143 | return s 144 | 145 | @gdal_required(_gdal_available) 146 | def to_ogr_geom(self) -> ogr.Geometry: 147 | """ 148 | Convert this node to an OGR geometry object 149 | 150 | Note: 151 | Requires GDAL / OGR 152 | 153 | Returns: 154 | ogr.Geometry: OGR Geometry for node. 155 | """ 156 | pt = ogr.Geometry(ogr.wkbPoint) 157 | pt.AddPoint_2D(self.lon, self.lat) 158 | return pt 159 | 160 | @gdal_required(_gdal_available) 161 | def to_wkt(self) -> str: 162 | """ 163 | Get the WKT representation for this node. 
164 | 165 | Note: 166 | Requires GDAL / OGR 167 | 168 | Returns: 169 | str: WKT representation for point 170 | """ 171 | return self.to_ogr_geom().ExportToWkt() 172 | 173 | def to_json_dict(self) -> dict: 174 | """ Export this element to a JSON dict (Python dictionary) """ 175 | d = super().to_json_dict() 176 | d['lat'] = float(self.lat) 177 | d['lon'] = float(self.lon) 178 | if len(self.tags): 179 | d['tags'] = {} 180 | for k, v in self.tags.items(): 181 | d['tags'][k] = str(v) 182 | return d 183 | 184 | def to_xml(self) -> ET.Element: 185 | """ Export this element to an XML Element """ 186 | # Create OSM XML element 187 | el = ET.Element('node') 188 | 189 | # Add id if it exists 190 | if self.id is not None: 191 | el.attrib['id'] = str(self.id) 192 | 193 | # Add any tags 194 | for tag_element in self._create_xml_osm_tag(self.tags): 195 | el.append(tag_element) 196 | 197 | # Add any remaining attributes 198 | el.attrib['lat'] = str(self.lat) 199 | el.attrib['lon'] = str(self.lon) 200 | 201 | # Return as XML Element 202 | return el 203 | 204 | @classmethod 205 | def from_json_dict(cls: Type[T], json_dict: dict) -> T: 206 | """ Create this element from a JSON dict (Python dictionary) """ 207 | # Trick since the OSM keys match the attribute names exactly 208 | # i.e. 'id', 'lat', 'lon' 209 | # If this wasn't the case extra processing would be necessary 210 | return cls(**json_dict) 211 | 212 | @classmethod 213 | def from_xml(cls: Type[T], tree: ET.Element) -> T: 214 | """ Create this element from an XML Element """ 215 | # Assert OSM element type 216 | assert tree.tag == cls.TYPE 217 | 218 | # Trick since the OSM keys match the attribute names exactly 219 | # i.e. 'id', 'lat', 'lon' 220 | # If this wasn't the case extra processing would be necessary 221 | kwargs = tree.attrib.copy() 222 | 223 | # Parse any tags 224 | tags = {} 225 | for t in tree.findall('tag'): 226 | tags.update(cls._parse_xml_osm_tag(t)) 227 | kwargs['tags'] = tags # type: ignore 228 | 229 | return cls(**kwargs) 230 | 231 | 232 | class OSMWay(OSMElement): 233 | """ Class to describe an OSM way object """ 234 | __slots__ = ['nodes'] 235 | 236 | TYPE = 'way' 237 | 238 | def __init__(self, **kwargs): 239 | super().__init__(**kwargs) 240 | 241 | # Set nodes 242 | self._set_attr_from_kwargs(kwargs, 'nodes', list, default=[]) 243 | 244 | def __len__(self): 245 | return self.nodes.__len__() 246 | 247 | def __str__(self): 248 | s = super().__str__()[:-1] 249 | s += f'; nodes: {self.nodes.__len__()}>' 250 | return s 251 | 252 | @staticmethod 253 | def _parse_xml_osm_node(n: ET.Element) -> Union[int, None]: 254 | """ Parse an XML OSM nd element into a node id """ 255 | assert n.tag == 'nd' 256 | d = n.attrib 257 | try: 258 | return int(d['ref']) 259 | except (TypeError, ValueError): 260 | warnings.warn( 261 | f'Node element id cannot be cast to int! Setting to None. {d["ref"]}', 262 | RuntimeWarning) 263 | return None 264 | except KeyError: 265 | warnings.warn(f'Node element empty! {d}', RuntimeWarning) 266 | return None 267 | 268 | @staticmethod 269 | def _create_xml_osm_node(n: int) -> ET.Element: 270 | """ Create an ETree 'nd' Element from a node id """ 271 | el = ET.Element('nd') 272 | el.attrib['ref'] = str(n) 273 | return el 274 | 275 | def to_json_dict(self) -> dict: 276 | """ Export this element to a JSON dict (Python dictionary) """ 277 | d = super().to_json_dict() 278 | d['nodes'] = list(self.nodes) 279 | if len(self.tags): 280 | d['tags'] = {} 281 | for k, v in self.tags.items(): 282 | d['tags'][k] = str(v) 283 | return d 284 | 285 | def to_xml(self) -> ET.Element: 286 | """ Export this element to an XML Element """ 287 | # Create OSM XML element 288 | el = ET.Element('way') 289 | 290 | # Add id if it exists 291 | if self.id is not None: 292 | el.attrib['id'] = str(self.id) 293 | 294 | # Add any nodes 295 | for node in self.nodes: # type: ignore 296 | el.append(self._create_xml_osm_node(node)) 297 | 298 | # Add any tags 299 | for tag_element in self._create_xml_osm_tag(self.tags): 300 | el.append(tag_element) 301 | 302 | # Return as an element 303 | return el 304 | 305 | @classmethod 306 | def from_json_dict(cls: Type[T], json_dict: dict) -> T: 307 | """ Create this element from a JSON dict (Python dictionary) """ 308 | # Trick since the OSM keys match the attribute names exactly 309 | # If this wasn't the case extra processing would be necessary 310 | return cls(**json_dict) 311 | 312 | @classmethod 313 | def from_xml(cls, tree: ET.Element) -> OSMWay: 314 | """ Create this element from an XML Element """ 315 | # Assert OSM element type 316 | assert tree.tag == cls.TYPE 317 | 318 | # Kwargs to pass to class constructor 319 | kwargs = {} 320 | 321 | # Parse id 322 | if 'id' in tree.attrib: 323 | kwargs['id'] = tree.attrib['id'] 324 | 325 | # Parse any nodes 326 | nodes = [] 327 | for nd in tree.findall('nd'): 328 | node_id = cls._parse_xml_osm_node(nd) 329 | if node_id is not None: 330 | nodes.append(node_id) 331 | kwargs['nodes'] = nodes # type: ignore 332 | 333 | # Parse any tags 334 | tags = {} 335 | for t in tree.findall('tag'): 336 | tags.update(cls._parse_xml_osm_tag(t)) 337 | kwargs['tags'] = tags # type: ignore 338 | 339 | return cls(**kwargs) 340 | 
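To make the XML plumbing above concrete, a short parsing sketch (the way id, `nd` refs, and tags are illustrative):

```python
from rsc.osm import OSMWay

way = OSMWay.from_xml_str(
    '<way id="42">'
    '<nd ref="100"/><nd ref="101"/><nd ref="102"/>'
    '<tag k="highway" v="track"/><tag k="surface" v="dirt"/>'
    '</way>')

print(way)               # <OSMWay; id: 42; tags: 2; nodes: 3>
print(way.nodes)         # [100, 101, 102]
print(way.to_xml_str())  # serializes back to an equivalent <way> element
```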
-------------------------------------------------------------------------------- /rsc/osm/osm_element_factory.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import xml.etree.ElementTree as ET 5 | 6 | from .osm_element import OSMElement, OSMNode, OSMWay 7 | 8 | 9 | class OSMElementFactory: 10 | """ Factory class for OSM elements """ 11 | 12 | @staticmethod 13 | def from_json_dict(d: dict) -> OSMElement: 14 | # Read OSM 'type' key 15 | osm_type = d.get('type') 16 | 17 | # Create class 18 | if osm_type == 'node': 19 | return OSMNode.from_json_dict(d) 20 | elif osm_type == 'way': 21 | return OSMWay.from_json_dict(d) 22 | else: 23 | raise ValueError('Unknown OSM type: %s' % osm_type) 24 | 25 | @staticmethod 26 | def from_xml(el: ET.Element) -> OSMElement: 27 | # Read the element type from the XML tag 28 | osm_type = el.tag 29 | 30 | # Create class 31 | if osm_type == 'node': 32 | return OSMNode.from_xml(el) 33 | elif osm_type == 'way': 34 | return OSMWay.from_xml(el) 35 | else: 36 | raise ValueError('Unknown XML tag: %s' % osm_type) 37 | -------------------------------------------------------------------------------- /rsc/osm/osm_overpass_api.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ Classes to aid in OSM Overpass API Queries and Results """ 4 | 5 | from __future__ import annotations 6 | 7 | import json 8 | import pathlib 9 | import textwrap 10 | from abc import ABC, abstractmethod 11 | from typing import Iterable, Union 12 | import requests 13 | 14 | from osgeo import ogr 15 | 16 | 17 | class OSMOverpassQuery(ABC): 18 | """ Abstract class to perform an OSM Overpass API query """ 19 | 20 | __slots__ = ['_endpoint', '_timeout', '_format', '_poly'] 21 | 22 | def __init__(self, **kwargs): 23 | self._endpoint = kwargs.get( 24 | 'endpoint', 'https://lz4.overpass-api.de/api/interpreter') 25 | self._timeout = kwargs.get('timeout', 180) 26 | self._format = kwargs.get('format', 'json') 27 | self._poly = None 28 | 29 | def set_poly_from_list(self, poly_list: Iterable[tuple[float, 30 | float]]) -> None: 31 | """ 32 | Set the polygon query boundary from a list of coordinates 33 | 34 | Args: 35 | poly_list (Iterable[tuple[float, float]]): List of (lon, lat) coordinate pairs 36 | """ 37 | self._poly = [] 38 | for lon, lat in poly_list: 39 | self._poly.append((lon, lat)) 40 | 41 | def set_poly_from_bbox(self, lat0: float, lon0: float, lat1: float, 42 | lon1: float) -> None: 43 | """ 44 | Set Overpass Query Filter polygon by bounding box 45 | 46 | Args: 47 | lat0 (float): First latitude (can be min or max) 48 | lon0 (float): First longitude (can be min or max) 49 | lat1 (float): Second latitude (can be min or max) 50 | lon1 (float): Second longitude (can be min or max) 51 | """ 52 | # Find max, min latitude and longitude 53 | lat_max = max((lat0, lat1)) 54 | lat_min = min((lat0, lat1)) 55 | lon_max = max((lon0, lon1)) 56 | lon_min = min((lon0, lon1)) 57 | 58 | # Set polygon 59 | self._poly = [(lon_min, lat_max), (lon_max, lat_max), 60 | (lon_max, lat_min), (lon_min, lat_min), 61 | (lon_min, lat_max)] 62 | 63 | def set_poly_from_wkt(self, wkt_str: str) -> None: 64 | """ 65 | Set Overpass Query Filter polygon by WKT string 66 | 67 | Args: 68 | wkt_str (str): Input polygon in WKT format 69 | 70 | Raises: 71 | ValueError: If the input WKT string is not of a polygon 72 | """ 73 | # Load WKT string into JSON dict 74 | geojson_dict = json.loads( 75 | ogr.CreateGeometryFromWkt(wkt_str).ExportToJson()) 76 | 77 | # Sanity check geometry type 78 | if geojson_dict['type'] != 'Polygon': 79 | raise ValueError('Input WKT string must be a polygon! Got: %s' % 80 | geojson_dict['type']) 81 | 82 | # Set polygon (0 index is for linearRing inside polygon) 83 | self._poly = geojson_dict['coordinates'][0] 84 | 85 | @property 86 | def _poly_query_str(self) -> str: 87 | if self._poly is None: 88 | raise ValueError('Polygon not set!') 89 | return ' '.join([ 90 | ' '.join((f'{lat:.6f}', f'{lon:.6f}')) for lon, lat in self._poly 91 | ]) 92 | 93 | @property 94 | @abstractmethod 95 | def _query_str(self) -> str: 96 | return '' 97 | 98 | @property 99 | def endpoint(self) -> str: 100 | return self._endpoint 101 | 102 | def set_endpoint(self, endpoint: str) -> None: 103 | """ Set the Overpass API Endpoint """ 104 | self._endpoint = endpoint 105 | 106 | def _perform_query(self) -> requests.models.Response: 107 | """ Perform an OSM Overpass API Request!
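Renders the subclass-provided Overpass QL (with newlines stripped) and sends it via HTTP GET to the configured endpoint; raises requests.HTTPError on a non-2xx response.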
""" 108 | query_str = textwrap.dedent(self._query_str).replace('\n', '') 109 | result = requests.get(self.endpoint, params={'data': query_str}) 110 | result.raise_for_status() 111 | return result 112 | 113 | def perform_query(self) -> OSMOverpassResult: 114 | """ Perform an OSM Overpass API Request! """ 115 | return OSMOverpassResult(self._perform_query()) 116 | 117 | 118 | class OSMOverpassResult: 119 | """ Container class for OSM Overpass result data. Will need to be subclassed to be useful. """ 120 | __slots__ = ['_result'] 121 | 122 | def __init__(self, result: requests.models.Response): 123 | self._result = result 124 | 125 | def to_file(self, output_path: Union[pathlib.Path, str]) -> None: 126 | """ 127 | Have the query result to file. The file must be of the same format 128 | as the query result. 129 | 130 | Args: 131 | output_path (Union[pathlib.Path, str]): Output file path 132 | 133 | Raises: 134 | ValueError: if format is not recognized 135 | """ 136 | # Get output path, sanity check dir 137 | output_path = pathlib.Path(output_path) 138 | assert output_path.parent.exists() 139 | 140 | # Assert output path suffix is valid 141 | this_format = self.format 142 | if 'json' in this_format: 143 | assert output_path.suffix.lower() == '.json' 144 | elif 'xml' in this_format: 145 | assert output_path.suffix.lower() in ('.xml', '.osm') 146 | else: 147 | raise ValueError('Unknown format: %s. Cannot export to file.' % 148 | str(this_format)) 149 | 150 | # Save to file! 151 | with open(output_path, 'wb') as f: 152 | f.write(self._result.content) 153 | 154 | @property 155 | def format(self) -> str: 156 | if 'Content-Type' in self._result.headers: 157 | return self._result.headers['Content-Type'].split('/')[-1] 158 | raise ValueError( 159 | 'Unknown format! 
Content-Type not found in result header.') 160 | -------------------------------------------------------------------------------- /rsc/osm/overpass_api/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | from .osm_overpass_api import OSMOverpassQuery, OSMOverpassResult-------------------------------------------------------------------------------- /rsc/osm/overpass_api/osm_overpass_api.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ Classes to aid in OSM Overpass API Queries and Results """ 4 | 5 | from __future__ import annotations 6 | 7 | import json 8 | import pathlib 9 | import textwrap 10 | from abc import ABC, abstractmethod 11 | from typing import Iterable, Union 12 | import requests 13 | 14 | from osgeo import ogr 15 | 16 | 17 | class OSMOverpassQuery(ABC): 18 | """ Abstract class to perform an OSM Overpass API query """ 19 | 20 | __slots__ = ['_endpoint', '_timeout', '_format', '_poly'] 21 | 22 | def __init__(self, **kwargs): 23 | self._endpoint = kwargs.get( 24 | 'endpoint', 'https://lz4.overpass-api.de/api/interpreter') 25 | self._timeout = kwargs.get('timeout', 180) 26 | self._format = kwargs.get('format', 'json') 27 | self._poly = None 28 | 29 | def set_poly_from_list(self, poly_list: Iterable[tuple[float, 30 | float]]) -> None: 31 | """ 32 | Set the polygon query boundary from a list of coordinates 33 | 34 | Args: 35 | poly_list (Iterable[tuple[float, float]]): List of (lon, lat) coordinate pairs 36 | """ 37 | self._poly = [] 38 | for lon, lat in poly_list: 39 | self._poly.append((lon, lat)) 40 | 41 | def set_poly_from_bbox(self, lat0: float, lon0: float, lat1: float, 42 | lon1: float) -> None: 43 | """ 44 | Set Overpass Query Filter polygon by bounding box 45 | 46 | Args: 47 | lat0 (float): First latitude (can be min or max) 48 | lon0 (float): First longitude (can be min or max) 49 | lat1 (float): Second latitude (can be min or max) 50 | lon1 (float): Second longitude (can be min or max) 51 | """ 52 | # Find max, min latitude and longitude 53 | lat_max = max((lat0, lat1)) 54 | lat_min = min((lat0, lat1)) 55 | lon_max = max((lon0, lon1)) 56 | lon_min = min((lon0, lon1)) 57 | 58 | # Set polygon 59 | self._poly = [(lon_min, lat_max), (lon_max, lat_max), 60 | (lon_max, lat_min), (lon_min, lat_min), 61 | (lon_min, lat_max)] 62 | 63 | def set_poly_from_wkt(self, wkt_str: str) -> None: 64 | """ 65 | Set Overpass Query Filter polygon by WKT string 66 | 67 | Args: 68 | wkt_str (str): Input polygon in WKT format 69 | 70 | Raises: 71 | ValueError: If the input WKT string is not of a polygon 72 | """ 73 | # Load WKT string into JSON dict 74 | geojson_dict = json.loads( 75 | ogr.CreateGeometryFromWkt(wkt_str).ExportToJson()) 76 | 77 | # Sanity check geometry type 78 | if geojson_dict['type'] != 'Polygon': 79 | raise ValueError('Input WKT string must be a polygon!
Got: %s' % 80 | geojson_dict['type']) 81 | 82 | # Set polygon (0 index is for linearRing inside polygon) 83 | self._poly = geojson_dict['coordinates'][0] 84 | 85 | @property 86 | def _poly_query_str(self) -> str: 87 | if self._poly is None: 88 | raise ValueError('Polygon not set!') 89 | return ' '.join([ 90 | ' '.join((f'{lat:.6f}', f'{lon:.6f}')) for lon, lat in self._poly 91 | ]) 92 | 93 | @property 94 | @abstractmethod 95 | def _query_str(self) -> str: 96 | return '' 97 | 98 | @property 99 | def endpoint(self) -> str: 100 | return self._endpoint 101 | 102 | def set_endpoint(self, endpoint: str) -> None: 103 | """ Set the Overpass API Endpoint """ 104 | self._endpoint = endpoint 105 | 106 | def _perform_query(self) -> requests.models.Response: 107 | """ Perform an OSM Overpass API Request! """ 108 | query_str = textwrap.dedent(self._query_str).replace('\n', '') 109 | result = requests.get(self.endpoint, params={'data': query_str}) 110 | result.raise_for_status() 111 | return result 112 | 113 | def perform_query(self) -> OSMOverpassResult: 114 | """ Perform an OSM Overpass API Request! """ 115 | return OSMOverpassResult(self._perform_query()) 116 | 117 | 118 | class OSMOverpassResult: 119 | """ Container class for OSM Overpass result data. Will need to be subclassed to be useful. """ 120 | __slots__ = ['_result'] 121 | 122 | def __init__(self, result: requests.models.Response): 123 | self._result = result 124 | 125 | def to_file(self, output_path: Union[pathlib.Path, str]) -> None: 126 | """ 127 | Write the query result to file. The file must be of the same format 128 | as the query result. 129 | 130 | Args: 131 | output_path (Union[pathlib.Path, str]): Output file path 132 | 133 | Raises: 134 | ValueError: if format is not recognized 135 | """ 136 | # Get output path, sanity check dir 137 | output_path = pathlib.Path(output_path) 138 | assert output_path.parent.exists() 139 | 140 | # Assert output path suffix is valid 141 | this_format = self.format 142 | if 'json' in this_format: 143 | assert output_path.suffix.lower() == '.json' 144 | elif 'xml' in this_format: 145 | assert output_path.suffix.lower() in ('.xml', '.osm') 146 | else: 147 | raise ValueError('Unknown format: %s. Cannot export to file.' % 148 | str(this_format)) 149 | 150 | # Save to file! 151 | with open(output_path, 'wb') as f: 152 | f.write(self._result.content) 153 | 154 | @property 155 | def format(self) -> str: 156 | if 'Content-Type' in self._result.headers: 157 | return self._result.headers['Content-Type'].split('/')[-1] 158 | raise ValueError( 159 | 'Unknown format!
Content-Type not found in result header.') 160 | -------------------------------------------------------------------------------- /rsc/osm/overpass_api/road_network.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import annotations 5 | import tempfile 6 | import pathlib 7 | 8 | from ..osm_network import OSMNetwork 9 | from .osm_overpass_api import OSMOverpassQuery, OSMOverpassResult 10 | 11 | 12 | class OSMRoadNetworkOverpassQuery(OSMOverpassQuery): 13 | """ Custom OSM Overpass API query for (hopefully) drivable road networks """ 14 | 15 | __slots__ = ['_highway_tags'] 16 | 17 | DEFAULT_HIGHWAY_TAGS = [ 18 | 'motorway', 'motorway_link', 'motorway_junction', 'trunk', 19 | 'trunk_link', 'primary', 'primary_link', 'secondary', 'secondary_link', 20 | 'tertiary', 'tertiary_link', 'unclassified', 'residential' 21 | ] 22 | 23 | def __init__(self, **kwargs): 24 | super().__init__(**kwargs) 25 | self._highway_tags = kwargs.get('highway_tags', 26 | self.DEFAULT_HIGHWAY_TAGS) 27 | 28 | def perform_query(self) -> OSMRoadNetworkOverpassResult: 29 | """ Perform an OSM Overpass API Request! """ 30 | return OSMRoadNetworkOverpassResult(self._perform_query()) 31 | 32 | @property 33 | def _query_str(self) -> str: 34 | # NOTE: OSMnx instead uses 35 | # ["highway"!~"abandoned|bridleway|bus_guideway|construction|corridor|cycleway|elevator|escalator|footway|path|pedestrian|planned|platform|proposed|raceway|service|steps|track"] 36 | 37 | return f""" 38 | [out:{self._format}] 39 | [timeout:{self._timeout}]; 40 | (way["highway"] 41 | ["area"!~"yes"] 42 | ["access"!~"private"] 43 | ["highway"~"{'|'.join(self._highway_tags)}"] 44 | ["motor_vehicle"!~"no"] 45 | ["motorcar"!~"no"] 46 | ["service"!~"alley|driveway|emergency_access|parking|parking_aisle|private"] 47 | (poly:'{self._poly_query_str}'); 48 | >; 49 | ); 50 | out; 51 | """ 52 | 53 | 54 | class OSMRoadNetworkOverpassResult(OSMOverpassResult): 55 | """ Container class for OSM Overpass result data for (hopefully) drivable road networks """ 56 | 57 | def to_network(self) -> OSMNetwork: 58 | """ Convert result data to an OSM network """ 59 | with tempfile.TemporaryDirectory() as td: 60 | tmp_file = pathlib.Path(td, 'tmp') 61 | this_format = self.format 62 | if 'json' in this_format: 63 | tmp_file = tmp_file.with_suffix('.json') 64 | elif 'xml' in this_format: 65 | tmp_file = tmp_file.with_suffix('.osm') 66 | else: 67 | raise ValueError( 68 | 'Unknown format: \'%s\'! Cannot convert to OSM network.' % this_format) 69 | 70 | # Save the response content to file 71 | tmp_file.write_bytes(self._result.content) 72 | 73 | # Load the network!
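# NOTE: OSMNetwork.from_file is assumed to dispatch on the file suffix chosen above ('.json' vs '.osm'), which is why the suffix must match the Overpass response format.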
74 | n = OSMNetwork.from_file(tmp_file) 75 | 76 | return n 77 | -------------------------------------------------------------------------------- /rsc/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/rsc/train/__init__.py -------------------------------------------------------------------------------- /rsc/train/color_jitter_nohuesat.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Any, Dict, List, Optional, Tuple, Union, cast 3 | 4 | from torch import Tensor 5 | 6 | from kornia.augmentation import random_generator as rg 7 | from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D 8 | from kornia.enhance import (adjust_brightness_accumulative, 9 | adjust_contrast_with_mean_subtraction) 10 | 11 | 12 | class ColorJitterNoHueSat(IntensityAugmentationBase2D): 13 | r"""Apply a random transformation to the brightness and contrast of a tensor image (hue and saturation are left untouched). 14 | 15 | This implementation aligns with PIL. Hence, the output is close to TorchVision. 16 | 17 | .. image:: _static/img/ColorJitter.png 18 | 19 | Args: 20 | p: probability of applying the transformation. 21 | brightness: The brightness factor to apply. 22 | contrast: The contrast factor to apply. 23 | silence_instantiation_warning: if True, silence the warning at instantiation. 24 | same_on_batch: apply the same transformation across the batch. 25 | keepdim: whether to keep the output shape the same as input (True) or broadcast it 26 | to the batch form (False). 27 | Shape: 28 | - Input: :math:`(C, H, W)` or :math:`(B, C, H, W)`, Optional: :math:`(B, 3, 3)` 29 | - Output: :math:`(B, C, H, W)` 30 | """ 31 | 32 | def __init__( 33 | self, 34 | brightness: Union[Tensor, float, Tuple[float, float], 35 | List[float]] = 0.0, 36 | contrast: Union[Tensor, float, Tuple[float, float], List[float]] = 0.0, 37 | same_on_batch: bool = False, 38 | p: float = 1.0, 39 | keepdim: bool = False, 40 | silence_instantiation_warning: bool = False, 41 | ) -> None: 42 | super().__init__(p=p, 43 | same_on_batch=same_on_batch, 44 | keepdim=keepdim) 45 | 46 | if not silence_instantiation_warning: 47 | warnings.warn( 48 | "`ColorJitter` is now following Torchvision implementation.
Old " 49 | "behavior can be retrieved by instantiating `ColorJiggle`.", 50 | category=DeprecationWarning, 51 | ) 52 | 53 | self.brightness = brightness 54 | self.contrast = contrast 55 | self._param_generator = cast( 56 | rg.ColorJitterGenerator, 57 | rg.ColorJitterGenerator(brightness, contrast)) 58 | 59 | def apply_transform(self, 60 | input: Tensor, 61 | params: Dict[str, Tensor], 62 | flags: Dict[str, Any], 63 | transform: Optional[Tensor] = None) -> Tensor: 64 | 65 | transforms = [ 66 | lambda img: adjust_brightness_accumulative( 67 | img, params["brightness_factor"]), 68 | lambda img: adjust_contrast_with_mean_subtraction( 69 | img, params["contrast_factor"]) 70 | ] 71 | 72 | jittered = input 73 | for idx in (0, 1): 74 | t = transforms[idx] 75 | jittered = t(jittered) 76 | 77 | return jittered 78 | -------------------------------------------------------------------------------- /rsc/train/data_augmentation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | from typing import Tuple 4 | import torch 5 | import torch.nn as nn 6 | import kornia 7 | 8 | from .color_jitter_nohuesat import ColorJitterNoHueSat 9 | 10 | 11 | class DataAugmentation(nn.Module): 12 | 13 | def __init__(self, has_nir: bool=True): 14 | super().__init__() 15 | 16 | # Are we augmenting NIR as well? 17 | self.has_nir = has_nir 18 | 19 | # Random positional transformations 20 | self.transform_flip = nn.Sequential( 21 | kornia.augmentation.RandomHorizontalFlip(p=0.5), 22 | kornia.augmentation.RandomVerticalFlip(p=0.5)) 23 | 24 | # Random offset of mask to re-inforce proper segmentation 25 | # when labels may be inaccurate 26 | self.transform_offset = nn.Sequential( 27 | kornia.augmentation.RandomAffine(degrees=(-15, 15), 28 | translate=(0.0625, 0.0625), 29 | p=0.5)) 30 | 31 | # Transform RGB 32 | self.transform_color = nn.Sequential( 33 | kornia.augmentation.RandomPlasmaBrightness(roughness=(0.1, 0.5), 34 | intensity=(0.1, 0.3)), 35 | kornia.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1)) 36 | 37 | # Transform NIR 38 | self.transform_nir = nn.Sequential(ColorJitterNoHueSat(0.1, 0.1)) 39 | 40 | @torch.no_grad() 41 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 42 | 43 | # Apply position transform to image + masks 44 | x_aug = self.transform_flip(x) 45 | 46 | # Break down what all of the channels are 47 | mask_c = slice(4, 5) if self.has_nir else slice(3, 4) 48 | probmask_c = slice(5, 6) if self.has_nir else slice(4, 5) 49 | 50 | # RGB color transformation 51 | im_aug = self.transform_color(x_aug[:, 0:3, ...]) 52 | 53 | if self.has_nir: 54 | # NIR color transformation 55 | im_nir_aug = self.transform_nir(x_aug[:, 3:4, ...]) 56 | 57 | # Combine NIR with RGB image 58 | im_aug = torch.concat((im_aug, im_nir_aug), dim=1) 59 | 60 | # Apply offset transform to mask 61 | mask_aug = self.transform_offset(x_aug[:, mask_c, ...]) 62 | 63 | # Combine into color image + location mask 64 | im_aug = torch.concat((im_aug, mask_aug), dim=1) 65 | 66 | # Extract probmask for training 67 | pm_aug = x_aug[:, probmask_c, :, :] 68 | 69 | return im_aug, pm_aug 70 | -------------------------------------------------------------------------------- /rsc/train/dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | from typing import Any 4 | import pandas as pd 5 | import numpy as np 6 | import PIL.Image 7 | 8 | import torch 9 | from 
torch.utils.data import Dataset 10 | 11 | from torchvision.transforms import CenterCrop 12 | 13 | import torch.multiprocessing 14 | 15 | # Prevents file descriptor errors when doing multiprocessing fetches 16 | torch.multiprocessing.set_sharing_strategy('file_system') 17 | 18 | 19 | class RoadSurfaceDataset(Dataset): 20 | 21 | def __init__(self, 22 | df_path, 23 | transform, 24 | chip_size=224, 25 | n_channels=4, 26 | limit=-1): 27 | 28 | # Load dataframe, interpret length 29 | self.df = pd.read_csv(df_path) 30 | self.n_idxs = len(self.df) if limit == -1 else min(limit, len(self.df)) 31 | 32 | # Number of channels in the image 33 | self.n_channels = n_channels 34 | 35 | # Number of classes 36 | self.n_classes = self.df['class_num'].max() + 1 37 | 38 | # Chip size 39 | self.set_chip_size(chip_size) 40 | 41 | # Transformation object 42 | self.transform = transform 43 | 44 | def set_chip_size(self, chip_size: int): 45 | self.cc = CenterCrop(chip_size) 46 | 47 | def __len__(self): 48 | return self.n_idxs 49 | 50 | def __getitem__(self, idx): 51 | 52 | # Get row 53 | row: Any = self.df.iloc[idx] # type: ignore 54 | 55 | # Image 56 | with PIL.Image.open(row.chip_path) as pim: 57 | im = np.array(self.cc(pim)) 58 | tn = im.shape[2] if im.ndim == 3 else 1 59 | if tn > 1: 60 | # Adds support for loading imagery > 61 | # n_channels. In this case we just grab 62 | # the required number of channels 63 | im = im[..., :self.n_channels] 64 | tn = im.shape[2] 65 | if tn > self.n_channels: 66 | print( 67 | f'WARNING: Got {im.shape[2]} channel image but model only has {self.n_channels} dimensions!' 68 | ) 69 | im = im[:, :, :self.n_channels] 70 | 71 | # Mask 72 | with PIL.Image.open(row.mask_path) as pmask: 73 | mask = np.array(self.cc(pmask)) 74 | if mask.ndim == 2: 75 | mask = mask[:, :, np.newaxis] 76 | tn = mask.shape[2] 77 | if tn > 1: 78 | mask = mask[:, :, 0][:, :, np.newaxis] 79 | 80 | # Prob mask 81 | with PIL.Image.open(row.probmask_path) as pmask: 82 | probmask = np.array(self.cc(pmask)) 83 | if probmask.ndim == 2: 84 | probmask = probmask[:, :, np.newaxis] 85 | tn = probmask.shape[2] 86 | if tn > 1: 87 | probmask = probmask[:, :, 0][:, :, np.newaxis] 88 | 89 | # Label (add one to number of classes to account for obscurations) 90 | lbl = [0] * (self.n_classes + 1) 91 | 92 | # Get class idx 93 | c = int(row.class_num) 94 | 95 | # Compute obscuration estimate 96 | obsc = 1 - (mask * (probmask > 127)).sum() / mask.sum() 97 | 98 | # Set labels accordingly 99 | # NOTE: no longer fuzzy 100 | lbl[c] = 1 101 | lbl[-1] = obsc 102 | 103 | # Concat image and masks for output 104 | x = self.transform(np.concatenate((im, mask, probmask), axis=2)) 105 | 106 | return x, torch.Tensor(lbl) 107 | -------------------------------------------------------------------------------- /rsc/train/mcnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as functional 6 | import torchvision 7 | 8 | # NOTE: Patch for MaxUnpool2d needs to be applied 9 | # when exporting to ONNX 10 | from torch.nn import MaxUnpool2d 11 | #from patch import MaxUnpool2d 12 | 13 | 14 | class Freezable: 15 | 16 | def parameters(self): 17 | # This is overridden 18 | return [] 19 | 20 | def freeze(self): 21 | self._set_freeze(True) 22 | 23 | def unfreeze(self): 24 | self._set_freeze(False) 25 | 26 | def _set_freeze(self, v): 27 | req_grad = not bool(v) 28 | for param in self.parameters(): 29 |
param.requires_grad = req_grad 30 | 31 | 32 | class FreezableModule(nn.Module, Freezable): 33 | pass 34 | 35 | 36 | class FreezableModuleList(nn.ModuleList, Freezable): 37 | pass 38 | 39 | 40 | class FreezableAdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, Freezable): 41 | pass 42 | 43 | 44 | class FreezableLinear(nn.Linear, Freezable): 45 | pass 46 | 47 | 48 | class DecoderBlock(FreezableModule): 49 | 50 | def __init__(self, in_channels, out_channels, kernel_size=3): 51 | super().__init__() 52 | 53 | self.upconv = nn.ConvTranspose2d(in_channels, out_channels, 2, 2) 54 | self.conv1 = nn.Conv2d(out_channels * 2, out_channels, kernel_size) 55 | self.relu = nn.ReLU(inplace=True) 56 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size) 57 | 58 | def forward(self, x, encoder_feats=None): 59 | 60 | x = self.upconv(x) 61 | 62 | enc_ftrs = self.crop(encoder_feats, x) 63 | x = torch.concat((x, enc_ftrs), dim=1) 64 | 65 | x = self.conv1(x) 66 | x = self.relu(x) 67 | x = self.conv2(x) 68 | x = self.relu(x) 69 | return x 70 | 71 | def crop(self, enc_ftrs, x): 72 | _, _, h, w = x.shape 73 | enc_ftrs = torchvision.transforms.CenterCrop([h, w])(enc_ftrs) 74 | return enc_ftrs 75 | 76 | 77 | class Resnet18Encoder(FreezableModule): 78 | 79 | def __init__(self, in_channels=3): 80 | super().__init__() 81 | 82 | # Get Resnet18 w/ default weights 83 | self.rnet = torchvision.models.resnet18( 84 | weights=torchvision.models.ResNet18_Weights.DEFAULT) 85 | 86 | if in_channels != 3: 87 | self.rnet.conv1 = nn.Conv2d(in_channels, 88 | 64, 89 | kernel_size=(7, 7), 90 | stride=(2, 2), 91 | padding=(3, 3), 92 | bias=False) 93 | 94 | self.rnet.maxpool.return_indices = True 95 | 96 | # Delete final avg pool / linear layer 97 | del self.rnet.avgpool 98 | del self.rnet.fc 99 | 100 | # Features for cross-connections to decoder 101 | self.feats = [ 102 | torch.Tensor([]), 103 | torch.Tensor([]), 104 | torch.Tensor([]), 105 | torch.Tensor([]) 106 | ] 107 | self.maxpool_idxs = None 108 | 109 | def forward(self, x): 110 | # Building up to stuff 111 | x = self.rnet.conv1(x) 112 | x = self.rnet.bn1(x) 113 | x = self.rnet.relu(x) 114 | x, self.maxpool_idxs = self.rnet.maxpool(x) 115 | 116 | # Now the main layers 117 | self.feats[0] = x 118 | x = self.rnet.layer1(x) 119 | 120 | self.feats[1] = x 121 | x = self.rnet.layer2(x) 122 | 123 | self.feats[2] = x 124 | x = self.rnet.layer3(x) 125 | 126 | self.feats[3] = x 127 | x = self.rnet.layer4(x) 128 | 129 | return x 130 | 131 | 132 | class Resnet18Decoder(FreezableModuleList): 133 | 134 | def __init__(self): 135 | super().__init__() 136 | 137 | self.layer_1 = DecoderBlock(512, 256) 138 | self.layer_2 = DecoderBlock(256, 128) 139 | self.layer_3 = DecoderBlock(128, 64) 140 | self.layer_4 = DecoderBlock(64, 64, kernel_size=1) 141 | 142 | self.unpool = MaxUnpool2d(kernel_size=3, stride=2, padding=1) 143 | self.final_1 = nn.Conv2d(64, 2, 3) 144 | self.relu = nn.ReLU(inplace=True) 145 | self.final_2 = nn.Conv2d(2, 2, 3) 146 | self.softmax = nn.Softmax(dim=1) 147 | 148 | def forward(self, x, encoder_feats, maxpool_idxs): 149 | 150 | x = self.layer_1(x, encoder_feats[3]) 151 | x = self.layer_2(x, encoder_feats[2]) 152 | x = self.layer_3(x, encoder_feats[1]) 153 | x = self.layer_4(x, encoder_feats[0]) 154 | x = self.unpool(x, maxpool_idxs, output_size=(112, 112)) 155 | 156 | x = self.final_1(x) 157 | x = self.relu(x) 158 | x = self.final_2(x) 159 | x = functional.interpolate(x, (224, 224)) 160 | x = self.softmax(x) 161 | 162 | return x 163 | 164 | 165 | class MaskCNN(nn.Module): 166 | 167 | 
def __init__(self, num_classes=2, num_channels=5): 168 | super().__init__() 169 | 170 | # Segmentation stage 171 | self.encoder = Resnet18Encoder(in_channels=num_channels) 172 | self.decoder = Resnet18Decoder() 173 | 174 | # Classification stage 175 | self.encoder2 = Resnet18Encoder(in_channels=num_channels) 176 | self.avgpool = FreezableAdaptiveAvgPool2d(output_size=(1, 1)) 177 | self.fc = FreezableLinear(512, num_classes, bias=True) 178 | 179 | def forward(self, x): 180 | 181 | # Image -> Features 182 | y = self.encoder(x) 183 | 184 | # Features -> Mask 185 | y = self.decoder(y, self.encoder.feats, self.encoder.maxpool_idxs) 186 | y = y[:, 0:1, ...] 187 | 188 | # Adjust the image using the predicted mask: drop the input mask channel and append the model's own segmentation output 189 | # NOTE: - RGB: image is (0, 1, 2). Input mask is (3,), 190 | # - RGB + NIR image is (0, 1, 2, 3). Input mask is (4,), 191 | # - In both cases the image is channels :-1 and the mask is channel -1 192 | # NOTE: the paper recommends multiplication, but in this case 193 | # concat-ing the segmentation mask seems to produce better results 194 | # x = torch.multiply(x[:, :-1, ...], y) 195 | x = torch.concat((x[:, :-1, ...], y), dim=1) 196 | 197 | # Updated Image -> Features 198 | x = self.encoder2(x) 199 | x = self.avgpool(x) 200 | x = torch.flatten(x, 1) 201 | z = self.fc(x) 202 | 203 | return y, z -------------------------------------------------------------------------------- /rsc/train/mcnn_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as functional 7 | 8 | from .rsc_hxe_loss import RSCHXELoss 9 | 10 | 11 | class MCNNLoss(nn.Module): 12 | """ Combined MaskCNN loss function """ 13 | 14 | def __init__(self, top_lv_map, class_weights, seg_k, ob_k): 15 | super().__init__() 16 | 17 | # Inputs 18 | self.top_lv_map = torch.IntTensor(top_lv_map).cuda() 19 | self.class_weights = torch.Tensor(class_weights).float().cuda() 20 | self.seg_k = seg_k 21 | self.ob_k = ob_k 22 | 23 | # Loss functions 24 | self.hxe_loss = RSCHXELoss(self.top_lv_map, self.class_weights) 25 | self.o_loss = torch.nn.BCEWithLogitsLoss(reduction='mean') 26 | 27 | self.seg_loss = 0. 28 | self.cl_loss = 0. 29 | self.ob_loss = 0. 30 | self.stage = 0 31 | 32 | def forward(self, y_hat: torch.Tensor, y: torch.Tensor, 33 | z_hat: torch.Tensor, z: torch.Tensor) -> torch.Tensor: 34 | """ 35 | Compute the loss given the model's segmentation and classification results. 36 | 37 | Segmentation loss is DICE + BCE loss, inspired by: 38 | https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch/notebook#BCE-Dice-Loss 39 | 40 | Args: 41 | y_hat (torch.Tensor): Segmentation model result 42 | y (torch.Tensor): Segmentation truth 43 | z_hat (torch.Tensor): Classification model result 44 | z (torch.Tensor): Classification truth 45 | 46 | Returns: 47 | torch.Tensor: Computed loss for the model 48 | """ 49 | 50 | # Loss 1: Dice BCE loss for segmentation (seg_loss) 51 | if self.stage in (0, 1): 52 | intersection = (y_hat * y).sum() 53 | smooth = 1 # Dice BCE smoothing parameter, hardcoded to 1 does just fine 54 | dice_loss = 1 - (2. * intersection + 1) / ( 55 | y_hat.sum() + y.sum() + smooth) 56 | binary_ce = functional.binary_cross_entropy(y_hat, 57 | y, 58 | reduction='mean') 59 | self.seg_loss = binary_ce + dice_loss 60 | else: 61 | self.seg_loss = 0.
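# Sanity check on the Dice term above: for a perfect binary prediction (y_hat == y), intersection == y.sum(), so with smooth == 1 the ratio is (2 * y.sum() + 1) / (2 * y.sum() + 1) == 1 and dice_loss == 0; with zero overlap the ratio approaches 0 and dice_loss approaches 1.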
62 | 63 | # Loss 2: BCE Loss for estimating road obscuration (ob_loss) 64 | # This operates on the last logit produced by the model 65 | if self.stage in (0, 2): 66 | self.ob_loss = self.o_loss(z_hat[:, -1], z[:, -1]) 67 | else: 68 | self.ob_loss = 0. 69 | 70 | # Loss 3: Cross entropy for classification result (cl_loss) 71 | # This operates on all but the last model logit 72 | if self.stage in (0, 2): 73 | self.cl_loss = self.hxe_loss(z_hat[:, :-1], z[:, :-1]) 74 | else: 75 | self.cl_loss = 0. 76 | 77 | # Combine and return the total loss 78 | loss = self.seg_k * self.seg_loss + self.ob_k * self.ob_loss + self.cl_loss 79 | return loss -------------------------------------------------------------------------------- /rsc/train/patch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.autograd import Function 6 | from torch.nn.modules.pooling import _MaxUnpoolNd 7 | from torch.nn.modules.utils import _pair 8 | 9 | 10 | class MaxUnpool2dop(Function): 11 | """We wrap the `torch.nn.functional.max_unpool2d` 12 | with an extra `symbolic` method, which is needed while exporting to ONNX. 13 | Users should not call this function directly. 14 | """ 15 | 16 | @staticmethod 17 | def forward(ctx, input, indices, kernel_size, stride, padding, 18 | output_size): 19 | """Forward function of MaxUnpool2dop. 20 | Args: 21 | input (Tensor): Tensor needed to upsample. 22 | indices (Tensor): Indices output of the previous MaxPool. 23 | kernel_size (Tuple): Size of the max pooling window. 24 | stride (Tuple): Stride of the max pooling window. 25 | padding (Tuple): Padding that was added to the input. 26 | output_size (List or Tuple): The shape of output tensor. 27 | Returns: 28 | Tensor: Output tensor.
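Note: the forward pass simply defers to F.max_unpool2d; the `symbolic` method below is what matters for ONNX export, where the flat MaxPool indices are offset per batch and channel to match the indexing convention of the ONNX MaxUnpool operator.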
29 | """ 30 | return F.max_unpool2d(input, indices, kernel_size, stride, padding, 31 | output_size) 32 | 33 | @staticmethod 34 | def symbolic(g, input, indices, kernel_size, stride, padding, output_size): 35 | # get shape 36 | input_shape = g.op('Shape', input) 37 | const_0 = g.op('Constant', value_t=torch.tensor(0)) 38 | const_1 = g.op('Constant', value_t=torch.tensor(1)) 39 | batch_size = g.op('Gather', input_shape, const_0, axis_i=0) 40 | channel = g.op('Gather', input_shape, const_1, axis_i=0) 41 | 42 | # height = (height - 1) * stride + kernel_size 43 | height = g.op('Gather', 44 | input_shape, 45 | g.op('Constant', value_t=torch.tensor(2)), 46 | axis_i=0) 47 | height = g.op('Sub', height, const_1) 48 | height = g.op('Mul', height, 49 | g.op('Constant', value_t=torch.tensor(stride[1]))) 50 | height = g.op('Add', height, 51 | g.op('Constant', value_t=torch.tensor(kernel_size[1]))) 52 | 53 | # width = (width - 1) * stride + kernel_size 54 | width = g.op('Gather', 55 | input_shape, 56 | g.op('Constant', value_t=torch.tensor(3)), 57 | axis_i=0) 58 | width = g.op('Sub', width, const_1) 59 | width = g.op('Mul', width, 60 | g.op('Constant', value_t=torch.tensor(stride[0]))) 61 | width = g.op('Add', width, 62 | g.op('Constant', value_t=torch.tensor(kernel_size[0]))) 63 | 64 | # step of channel 65 | channel_step = g.op('Mul', height, width) 66 | # step of batch 67 | batch_step = g.op('Mul', channel_step, channel) 68 | 69 | # channel offset 70 | range_channel = g.op('Range', const_0, channel, const_1) 71 | range_channel = g.op( 72 | 'Reshape', range_channel, 73 | g.op('Constant', value_t=torch.tensor([1, -1, 1, 1]))) 74 | range_channel = g.op('Mul', range_channel, channel_step) 75 | range_channel = g.op('Cast', range_channel, to_i=7) # 7 is int64 76 | 77 | # batch offset 78 | range_batch = g.op('Range', const_0, batch_size, const_1) 79 | range_batch = g.op( 80 | 'Reshape', range_batch, 81 | g.op('Constant', value_t=torch.tensor([-1, 1, 1, 1]))) 82 | range_batch = g.op('Mul', range_batch, batch_step) 83 | range_batch = g.op('Cast', range_batch, to_i=7) # 7 is int64 84 | 85 | # update indices 86 | indices = g.op('Add', indices, range_channel) 87 | indices = g.op('Add', indices, range_batch) 88 | 89 | return g.op('MaxUnpool', 90 | input, 91 | indices, 92 | kernel_shape_i=kernel_size, 93 | strides_i=stride) 94 | 95 | 96 | class MaxUnpool2d(_MaxUnpoolNd): 97 | """This module is modified from Pytorch `MaxUnpool2d` module. 98 | Args: 99 | kernel_size (int or tuple): Size of the max pooling window. 100 | stride (int or tuple): Stride of the max pooling window. 101 | Default: None (It is set to `kernel_size` by default). 102 | padding (int or tuple): Padding that is added to the input. 103 | Default: 0. 104 | """ 105 | 106 | def __init__(self, kernel_size, stride=None, padding=0): 107 | super(MaxUnpool2d, self).__init__() 108 | self.kernel_size = _pair(kernel_size) 109 | self.stride = _pair(stride or kernel_size) 110 | self.padding = _pair(padding) 111 | 112 | def forward(self, input, indices, output_size=None): 113 | """Forward function of MaxUnpool2d. 114 | Args: 115 | input (Tensor): Tensor needed to upsample. 116 | indices (Tensor): Indices output of the previous MaxPool. 117 | output_size (List or Tuple): The shape of output tensor. 118 | Default: None. 119 | Returns: 120 | Tensor: Output tensor. 
121 | """ 122 | return MaxUnpool2dop.apply(input, indices, self.kernel_size, 123 | self.stride, self.padding, output_size) -------------------------------------------------------------------------------- /rsc/train/plmcnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch 5 | import optuna 6 | import pytorch_lightning as pl 7 | 8 | from .data_augmentation import DataAugmentation 9 | 10 | from .mcnn import MaskCNN 11 | from .mcnn_loss import MCNNLoss 12 | 13 | 14 | class PLMaskCNN(pl.LightningModule): 15 | """ PyTorch Lightning wrapper for training the 16 | RSC MaskCNN """ 17 | 18 | def __init__(self, 19 | labels, 20 | top_level_map, 21 | weights, 22 | learning_rate: float = 1e-4, 23 | seg_k: float = 1.0, 24 | ob_k: float = 1.0, 25 | nc: int = 4): 26 | super().__init__() 27 | 28 | # Hyperparameters 29 | self.learning_rate = learning_rate 30 | self.seg_k = seg_k 31 | self.ob_k = ob_k 32 | self.labels = labels 33 | self.top_level_map = top_level_map 34 | self.weights = weights 35 | self.save_hyperparameters() 36 | 37 | # Number of channels 38 | self.nc = nc 39 | 40 | # Optuna trial (use set_optuna_trial) 41 | self.trial: optuna.trial.Trial | None = None 42 | 43 | # Stateful min val_loss_cl 44 | self.min_val_loss = float('inf') 45 | self.min_val_loss_im = float('inf') 46 | self.min_val_loss_cl = float('inf') 47 | self.min_val_loss_ob = float('inf') 48 | 49 | # Stateful learning rate 50 | self._lr = learning_rate 51 | 52 | self.transform = DataAugmentation(has_nir=(nc == 4)) 53 | self.loss = MCNNLoss(self.top_level_map, self.weights, 54 | self.seg_k, self.ob_k) 55 | 56 | # Labels: add 1 for "obscuartion" 57 | # Channels: add 1 for "mask" (e.g. RGB + mask, RGB + NIR + mask) 58 | self.model = MaskCNN(num_classes=len(self.labels) + 1, 59 | num_channels=nc + 1) 60 | 61 | def set_optuna_trial(self, trial: optuna.trial.Trial | None): 62 | self.trial = trial 63 | 64 | def set_stage(self, v, lr): 65 | first_stage = (self.model.encoder, self.model.decoder) 66 | second_stage = (self.model.encoder2, self.model.avgpool, self.model.fc) 67 | 68 | # Freeze / unfreeze components based on stage 69 | if v == 0: 70 | [e.unfreeze() for e in first_stage] 71 | [e.unfreeze() for e in second_stage] 72 | elif v == 1: 73 | [e.unfreeze() for e in first_stage] 74 | [e.freeze() for e in second_stage] 75 | elif v == 2: 76 | [e.freeze() for e in first_stage] 77 | [e.unfreeze() for e in second_stage] 78 | else: 79 | raise ValueError(f'Unknown v: {repr(v):s}') 80 | 81 | # Loss function requires stage 82 | self.loss.stage = v 83 | 84 | # Learning rate depends on stage 85 | self._lr = lr 86 | 87 | # Set stage 88 | self.stage = v 89 | 90 | def forward(self, x): 91 | return self.model(x) 92 | 93 | def training_step(self, batch, batch_idx): 94 | x, z = batch 95 | x, xpm = self.transform(x) 96 | y_hat, z_hat = self.forward(x) 97 | loss = self.loss(y_hat, xpm, z_hat, z) 98 | self.log_dict( 99 | { 100 | 'train_loss_im': self.loss.seg_loss, 101 | 'train_loss_cl': self.loss.cl_loss, 102 | 'train_loss_ob': self.loss.ob_loss, 103 | 'train_loss': loss, 104 | }, 105 | on_step=False, 106 | on_epoch=True) 107 | return loss 108 | 109 | def validation_step(self, batch, batch_idx): 110 | x, z = batch 111 | 112 | img_mask_c = slice(0, self.nc + 1) 113 | probmask_c = slice(self.nc + 1, self.nc + 2) 114 | 115 | # Create probmask, image + mask (be careful, order matters!) 
116 | y = x[:, probmask_c, :, :] 117 | x = x[:, img_mask_c, :, :] 118 | 119 | y_hat, z_hat = self.forward(x) 120 | loss = self.loss(y_hat, y, z_hat, z) 121 | self.log_dict( 122 | { 123 | 'val_loss_im': self.loss.seg_loss, 124 | 'val_loss_cl': self.loss.cl_loss, 125 | 'val_loss_ob': self.loss.ob_loss, 126 | 'val_loss': loss, 127 | }, 128 | on_step=False, 129 | on_epoch=True) 130 | return loss 131 | 132 | def on_validation_epoch_end(self): 133 | 134 | metrics = self.trainer.logged_metrics 135 | this_val_loss = float(metrics['val_loss']) 136 | this_val_loss_im = float(metrics['val_loss_im']) 137 | this_val_loss_cl = float(metrics['val_loss_cl']) 138 | this_val_loss_ob = float(metrics['val_loss_ob']) 139 | 140 | if this_val_loss < self.min_val_loss: 141 | self.min_val_loss = this_val_loss 142 | self.min_val_loss_im = this_val_loss_im 143 | self.min_val_loss_cl = this_val_loss_cl 144 | self.min_val_loss_ob = this_val_loss_ob 145 | 146 | self.log_dict({ 147 | 'min_val_loss_im': self.min_val_loss_im, 148 | 'min_val_loss_cl': self.min_val_loss_cl, 149 | 'min_val_loss_ob': self.min_val_loss_ob, 150 | 'min_val_loss': self.min_val_loss, 151 | }) 152 | 153 | if self.trial is not None: 154 | self.trial.report(this_val_loss_cl, self.current_epoch) 155 | if self.trial.should_prune(): 156 | raise optuna.exceptions.TrialPruned() 157 | 158 | def configure_optimizers(self): 159 | return torch.optim.Adam(self.parameters(), lr=self._lr) 160 | -------------------------------------------------------------------------------- /rsc/train/preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import kornia 7 | 8 | 9 | class PreProcess(nn.Module): 10 | 11 | def __init__(self): 12 | super().__init__() 13 | self.resize = kornia.augmentation.Resize(224, keepdim=True) 14 | 15 | @torch.no_grad() 16 | def forward(self, x) -> torch.Tensor: 17 | 18 | # NOTE: applies to full 6-channel input (image + mask + probmask) 19 | 20 | # Convert to tensor 21 | x = kornia.utils.image_to_tensor(x, keepdim=True).float() 22 | x = self.resize(x) 23 | 24 | # Normalize between 0 and 1 25 | x = torch.divide(x, 255.) 26 | 27 | return x -------------------------------------------------------------------------------- /rsc/train/rsc_hxe_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ Hierarchical Cross Entropy Loss """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class RSCHXELoss(nn.Module): 11 | """ Hierarchical Cross Entropy Loss (tailored to RSC model to handle obscuration estimates)""" 12 | 13 | def __init__(self, lv_b_a: torch.Tensor, lv_b_w: torch.Tensor): 14 | """ 15 | Init the loss function, describing a hierarchy 16 | 17 | (A1) (A2) <---- "Level A" | i.e. top level classes; e.g. 'paved' vs 'unpaved') | 18 | // \\ // \\ 19 | (B1) (B2) (B3) (B4) <---- "Level B" | i.e. prediction classes | 20 | | e.g. 'asphalt' vs 'concrete' vs 'dirt' vs 'gravel') | 21 | 22 | Args: 23 | lv_b_a (torch.Tensor): Mapping for each label to index of top-level (level A) hierarchy index 24 | e.g. 
from above example: (0, 0, 1, 1) 25 | lv_b_w (torch.Tensor): Weights for prediction (level B) classes, independent of hierarchy 26 | """ 27 | 28 | super().__init__() 29 | 30 | # Sanity check 31 | assert len(lv_b_a) == len(lv_b_w) 32 | device = lv_b_a.device 33 | 34 | # Persist Level B -> Level A mapping 35 | self.lv_b_a = lv_b_a 36 | 37 | # List of indices between Lv A elements and Lv B elements 38 | # NOTE: assumes we have a lv B node attached to every lv A node 39 | self.lv_a_idx = [torch.where(lv_b_a == e)[0] for e in range(int(lv_b_a.max() + 1))] 40 | 41 | # Proportion of Level A elements among level B 42 | lv_a_p = torch.tensor([(1 / len(lv_b_w) / lv_b_w[idx]).sum() for idx in self.lv_a_idx]) 43 | 44 | # Level A weights 45 | self.lv_a_w = ((1 / len(lv_a_p)) / lv_a_p).to(device) 46 | 47 | # Level B weights for each node in level A 48 | self.lv_b_w_a = torch.concat([(1 / lv_b_w[idx]).sum() * lv_b_w[idx] / len(idx) for idx in self.lv_a_idx]) 49 | 50 | def forward(self, logits: torch.Tensor, truth: torch.Tensor) -> torch.Tensor: 51 | """ 52 | Compute the loss given the model's classification results. 53 | 54 | Args: 55 | logits (torch.Tensor): Classification model output 56 | truth (torch.Tensor): Truth classification (one-hot encoded) Shape: (N, num_lv_b) 57 | 58 | Returns: 59 | torch.Tensor: Computed loss for the model 60 | """ 61 | 62 | # Compute the logarithm of the predicted probabilities 63 | log_y_pred = F.log_softmax(logits, dim=1) 64 | 65 | # Initialize the loss 66 | loss_lv_b = torch.Tensor((0,)).to(logits.device) 67 | loss_lv_a = torch.Tensor((0,)).to(logits.device) 68 | 69 | # Iterate over the categories 70 | for i in range(truth.shape[1]): 71 | # Compute the weight for the category 72 | w = self.lv_b_w_a[i] 73 | 74 | # Compute the weight for the upper layer 75 | w2 = self.lv_a_w[self.lv_b_a[i]] 76 | 77 | # Compute the cross entropy loss for the category 78 | cross_entropy = -torch.sum(truth[:, i] * log_y_pred[:, i]) 79 | 80 | # Add the weighted cross entropy loss to the total loss 81 | loss_lv_b += w * w2 * cross_entropy 82 | 83 | loss_lv_b /= truth.shape[1] 84 | 85 | # Compute level A logits 86 | logits_lv_a = torch.stack([torch.sum(logits[:, e], 1) for e in self.lv_a_idx], -1) 87 | log_y_pred_lv_a = F.log_softmax(logits_lv_a, dim=1) 88 | truth_lv_a = torch.stack([torch.sum(truth[:, e], 1) for e in self.lv_a_idx], -1) 89 | 90 | # Iterate over the categories 91 | for i in range(truth_lv_a.shape[1]): 92 | # Compute the weight for the category 93 | w = self.lv_a_w[i] 94 | 95 | # Compute the cross entropy loss for the category 96 | cross_entropy = -torch.sum(truth_lv_a[:, i] * log_y_pred_lv_a[:, i]) 97 | 98 | # Add the weighted cross entropy loss to the total loss 99 | loss_lv_a += w * cross_entropy 100 | 101 | loss_lv_a /= truth_lv_a.shape[1] 102 | 103 | return (loss_lv_b + loss_lv_a) / truth.shape[0] -------------------------------------------------------------------------------- /scripts/evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # type: ignore 4 | """ Road Surface Classifier Inference Example 5 | This is a quick script to perform inference with our trained model, so we can see how well it works! 
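Assumes a trained checkpoint plus the pre-inference SQLite dataset produced by scripts/pre_inference.py (see the hardcoded paths below).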
6 | """ 7 | 8 | # %% 9 | 10 | import pathlib 11 | from tqdm import tqdm 12 | 13 | import torch 14 | from torch.utils.data import DataLoader 15 | 16 | from rsc.inference.mass_inference_dataset import MassInferenceDataset 17 | from rsc.model.preprocess import PreProcess 18 | 19 | # Input model and checkpoint paths (checkpoint contains the weights for inference) 20 | # Since these files are so large, they are not in source control. 21 | # Reach out to me if you'd like them. 22 | ckpt_path = pathlib.Path( 23 | '/data/road_surface_classifier/results/20230107_042006Z/model-0-epoch=10-val_loss=0.39906.ckpt' 24 | ) 25 | assert ckpt_path.exists() 26 | results_name = ckpt_path.parent.name 27 | ds_path = pathlib.Path('/nfs/taranis/naip/BOULDER_COUNTY_NAIP_2019.sqlite3') 28 | assert ds_path.exists() 29 | 30 | # %% Setup processing chain for inference 31 | from rsc.model.plmcnn import PLMaskCNN 32 | 33 | # Load model and set to eval mode 34 | model = PLMaskCNN.load_from_checkpoint(ckpt_path) 35 | model.eval() 36 | 37 | # Get label array (it's built into model) 38 | labels = model.__dict__.get('labels') 39 | 40 | # Import dataset 41 | preprocess = PreProcess() 42 | val_ds = MassInferenceDataset(ds_path, transform=preprocess) 43 | batch_size = 64 44 | val_dl = DataLoader(val_ds, 45 | num_workers=16, 46 | batch_size=batch_size, 47 | shuffle=True) 48 | 49 | #%% Iterate over the dataloader, and fetch predictions 50 | 51 | # Output data array 52 | output = [] 53 | 54 | for i, (osm_id, x) in tqdm(enumerate(iter(val_dl)), 55 | total=len(val_ds) // batch_size): 56 | 57 | # Get size of this batch 58 | sz = x.shape[0] 59 | 60 | # Predict with the model 61 | _, y_pred = model(x) 62 | 63 | # Compute argmax 64 | y_pred_am = torch.argmax(y_pred[:, 0:-1], dim=1) 65 | y_pred_am = y_pred_am.detach().numpy() # type: ignore 66 | 67 | # Compute softmax 68 | y_pred_sm = torch.softmax(y_pred, dim=1) 69 | y_pred_sm = y_pred_sm.detach().numpy() # type: ignore 70 | 71 | for j in range(sz): 72 | output.append((int(osm_id[j]), labels[y_pred_am[j]], *y_pred_sm[j, :])) 73 | 74 | #%% Save output as CSV 75 | import pandas as pd 76 | 77 | columns = ['osm_id', 'pred_label', *['pred_%s' % label for label in labels]] 78 | 79 | df = pd.DataFrame(output, columns=columns).set_index('osm_id') 80 | df.to_csv( 81 | f'/nfs/taranis/naip/BOULDER_COUNTY_NAIP_2019_results_{results_name}.csv') 82 | 83 | #%% Load in CSV output, and merge with original dataset SQLite file 84 | import sqlite3 85 | import pandas as pd 86 | 87 | df = pd.read_csv( 88 | f'/nfs/taranis/naip/BOULDER_COUNTY_NAIP_2019_results_{results_name}.csv' 89 | ).set_index('osm_id') 90 | 91 | with sqlite3.connect( 92 | 'file:/nfs/taranis/naip/BOULDER_COUNTY_NAIP_2019.sqlite3?mode=ro', 93 | uri=True) as con: 94 | df2 = pd.read_sql('SELECT * FROM features;', con).set_index('osm_id') 95 | 96 | df = df.join(df2) 97 | df 98 | 99 | #%% Save results into GPKG file for import to QGIS 100 | from osgeo import gdal, ogr, osr 101 | 102 | gdal.UseExceptions() 103 | ogr.UseExceptions() 104 | 105 | # Create SRS (EPSG:4326: WGS-84 decimal degrees) 106 | srs = osr.SpatialReference() 107 | srs.ImportFromEPSG(4326) 108 | 109 | driver: ogr.Driver = ogr.GetDriverByName('GPKG') 110 | ds: ogr.DataSource = driver.CreateDataSource( 111 | f'/data/road_surface_classifier/BOULDER_COUNTY_NAIP_2019_results_{results_name}.gpkg' 112 | ) 113 | layer: ogr.Layer = ds.CreateLayer( 114 | 'data', srs=srs, geom_type=ogr.wkbLineString) # type: ignore 115 | 116 | osm_id_field = ogr.FieldDefn('osm_id', ogr.OFTInteger64) 117 | 
highway_field = ogr.FieldDefn('highway', ogr.OFTString) 118 | surface_true_field = ogr.FieldDefn('surface_t', ogr.OFTString) 119 | surface_pred_field = ogr.FieldDefn('surface_p', ogr.OFTString) 120 | paved_conf = ogr.FieldDefn('paved_c', ogr.OFTReal) 121 | unpaved_conf = ogr.FieldDefn('unpaved_c', ogr.OFTReal) 122 | obsc_conf = ogr.FieldDefn('obsc_c', ogr.OFTReal) 123 | 124 | layer.CreateField(osm_id_field) 125 | layer.CreateField(highway_field) 126 | layer.CreateField(surface_true_field) 127 | layer.CreateField(surface_pred_field) 128 | layer.CreateField(paved_conf) 129 | layer.CreateField(unpaved_conf) 130 | layer.CreateField(obsc_conf) 131 | 132 | feature_defn = layer.GetLayerDefn() 133 | 134 | for osm_id, row in df.iterrows(): 135 | 136 | poly = ogr.CreateGeometryFromWkt(row['wkt']) 137 | 138 | feat = ogr.Feature(feature_defn) 139 | 140 | feat.SetGeometry(poly) 141 | feat.SetField('osm_id', row.name) 142 | feat.SetField('highway', row['highway_tag']) 143 | feat.SetField('surface_t', row['surface_tag']) 144 | feat.SetField('surface_p', row['pred_label']) 145 | feat.SetField('paved_c', row['pred_paved']) 146 | feat.SetField('unpaved_c', row['pred_unpaved']) 147 | feat.SetField('obsc_c', row['pred_Obscured']) 148 | 149 | layer.CreateFeature(feat) 150 | poly = None 151 | feat = None 152 | 153 | layer = None # type: ignore 154 | ds = None # type: ignore 155 | 156 | #%% Convert the CSV file into another QGIS GPKG, but pruned to ways with 157 | # known surface types, such that we can evaluate the model's accuracy more easily 158 | import sqlite3 159 | import pandas as pd 160 | from osgeo import gdal, ogr, osr 161 | 162 | gdal.UseExceptions() 163 | ogr.UseExceptions() 164 | 165 | # Read CSV, merge with dataset, extract the labels we want 166 | df = pd.read_csv( 167 | f'/nfs/taranis/naip/BOULDER_COUNTY_NAIP_2019_results_{results_name}.csv' 168 | ).set_index('osm_id') 169 | with sqlite3.connect( 170 | 'file:/nfs/taranis/naip/BOULDER_COUNTY_NAIP_2019.sqlite3?mode=ro', 171 | uri=True) as con: 172 | df2 = pd.read_sql('SELECT * FROM features;', con).set_index('osm_id') 173 | df = df.join(df2) 174 | df = df[['wkt', 'pred_label', 'surface_tag', 'highway_tag', 'pred_Obscured']] 175 | 176 | # Mapping of OSM surface types to simple paved / unpaved 177 | class_map = { 178 | 'asphalt': 'paved', 179 | 'bricks': 'paved', 180 | 'compacted': 'unpaved', 181 | 'concrete': 'paved', 182 | 'concrete:plates': 'paved', 183 | 'dirt': 'unpaved', 184 | 'gravel': 'unpaved', 185 | 'ground': 'unpaved', 186 | 'paved': 'paved', 187 | 'paving_stones': 'paved', 188 | 'unpaved': 'unpaved', 189 | } 190 | 191 | # Trim dataset + determine "correctness" 192 | df = df[df['surface_tag'] != ''] 193 | df['surface_tag'] = df['surface_tag'].apply(class_map.get) 194 | df['correct'] = df['surface_tag'] == df['pred_label'] 195 | 196 | # Create SRS (EPSG:4326: WGS-84 decimal degrees) 197 | srs = osr.SpatialReference() 198 | srs.ImportFromEPSG(4326) 199 | 200 | # Put into GPKG format for QGIS 201 | driver: ogr.Driver = ogr.GetDriverByName('GPKG') 202 | ds: ogr.DataSource = driver.CreateDataSource( 203 | f'/data/road_surface_classifier/BOULDER_COUNTY_NAIP_2019_results_eval_{results_name}.gpkg' 204 | ) 205 | layer: ogr.Layer = ds.CreateLayer('data', srs=srs, geom_type=ogr.wkbLineString) 206 | 207 | osm_id_field = ogr.FieldDefn('osm_id', ogr.OFTInteger64) 208 | highway_field = ogr.FieldDefn('highway', ogr.OFTString) 209 | surface_true_field = ogr.FieldDefn('surface_t', ogr.OFTString) 210 | surface_pred_field = ogr.FieldDefn('surface_p',
ogr.OFTString) 211 | correct_field = ogr.FieldDefn('correct', ogr.OFTString) 212 | obsc_field = ogr.FieldDefn('obsc', ogr.OFTReal) 213 | 214 | layer.CreateField(osm_id_field) 215 | layer.CreateField(highway_field) 216 | layer.CreateField(surface_true_field) 217 | layer.CreateField(surface_pred_field) 218 | layer.CreateField(correct_field) 219 | layer.CreateField(obsc_field) 220 | 221 | feature_defn = layer.GetLayerDefn() 222 | 223 | for osm_id, row in df.iterrows(): 224 | 225 | poly = ogr.CreateGeometryFromWkt(row['wkt']) 226 | 227 | feat = ogr.Feature(feature_defn) 228 | 229 | feat.SetGeometry(poly) 230 | feat.SetField('osm_id', row.name) 231 | feat.SetField('highway', row['highway_tag']) 232 | feat.SetField('surface_t', row['surface_tag']) 233 | feat.SetField('surface_p', row['pred_label']) 234 | feat.SetField('correct', str(row['correct'])) 235 | feat.SetField('obsc', row['pred_Obscured']) 236 | 237 | layer.CreateFeature(feat) 238 | poly = None 239 | feat = None 240 | 241 | layer = None # type: ignore 242 | ds = None # type: ignore -------------------------------------------------------------------------------- /scripts/parse_features.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ Helper script used to inspect all available surface types within our dataset """ 4 | import pickle 5 | import pathlib 6 | 7 | from tqdm import tqdm 8 | from osgeo import gdal 9 | 10 | if __name__ == '__main__': 11 | 12 | feature_data = [] 13 | feature_data_pickle_path = pathlib.Path( 14 | '/data/road_surface_classifier/feature_data.pkl') 15 | 16 | if not feature_data_pickle_path.exists(): 17 | # Load road surface data from file 18 | ds = gdal.OpenEx( 19 | '/data/gis/us_road_surface/us_w_road_surface_filtered.gpkg') 20 | layer = ds.GetLayer() 21 | feature_count = layer.GetFeatureCount() 22 | 23 | # Extract features we care about 24 | print('Extracting features...') 25 | for idx in tqdm(range(feature_count)): 26 | feat = layer.GetNextFeature() 27 | wkt_str = feat.GetGeometryRef().ExportToWkt() 28 | osm_id = feat.GetField(0) 29 | highway = feat.GetField(1) 30 | surface = feat.GetField(2) 31 | feature_data.append((osm_id, wkt_str, highway, surface)) 32 | layer = None 33 | ds = None 34 | 35 | # Pickle the data so we don't have to process again 36 | with open(feature_data_pickle_path, 'wb') as f: 37 | pickle.dump(feature_data, f) 38 | 39 | else: 40 | # Load the data from file 41 | with open(feature_data_pickle_path, 'rb') as f: 42 | feature_data = pickle.load(f) 43 | 44 | # Some data cleanup 45 | # Get all surface types with more than 1000 labels 46 | surface_types = list(set([e[3].lower() for e in feature_data])) 47 | count_dict = {k: 0 for k in surface_types} 48 | for _, _, _, surface in feature_data: 49 | count_dict[surface.lower()] += 1 50 | del_keys = [k for k, v in count_dict.items() if v < 1000] 51 | [count_dict.pop(k) for k in del_keys] -------------------------------------------------------------------------------- /scripts/perform_query.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ OSM Overpass query to find all drivable roads with a labeled 'surface' tag """ 4 | from __future__ import annotations 5 | import pathlib 6 | 7 | from rsc.osm.overpass_api import OSMOverpassQuery, OSMOverpassResult 8 | 9 | 10 | class OSMCustomOverpassQuery(OSMOverpassQuery): 11 | """ Custom OSM Overpass API query for (hopefully) 
drivable road networks """ 12 | 13 | __slots__ = ['_highway_tags'] 14 | 15 | DEFAULT_HIGHWAY_TAGS = [ 16 | 'motorway', 'motorway_link', 'motorway_junction', 'trunk', 17 | 'trunk_link', 'primary', 'primary_link', 'secondary', 'secondary_link', 18 | 'tertiary', 'tertiary_link', 'unclassified', 'residential' 19 | ] 20 | 21 | def __init__(self, **kwargs): 22 | super().__init__(**kwargs) 23 | self._highway_tags = kwargs.get('highway_tags', 24 | self.DEFAULT_HIGHWAY_TAGS) 25 | 26 | def perform_query(self) -> OSMCustomOverpassResult: 27 | """ Perform an OSM Overpass API Request! """ 28 | return OSMCustomOverpassResult(self._perform_query()) 29 | 30 | @property 31 | def _query_str(self) -> str: 32 | # NOTE: OSMnx instead uses 33 | # ["highway"!~"abandoned|bridleway|bus_guideway|construction|corridor|cycleway|elevator|escalator|footway|path|pedestrian|planned|platform|proposed|raceway|service|steps|track"] 34 | 35 | return f""" 36 | [out:{self._format}] 37 | [timeout:{self._timeout}] 38 | [maxsize:2147483648]; 39 | (way["highway"] 40 | ["area"!~"yes"] 41 | ["access"!~"private"] 42 | ["highway"~"{'|'.join(self._highway_tags)}"] 43 | ["motor_vehicle"!~"no"] 44 | ["motorcar"!~"no"] 45 | ["surface"!~""] 46 | ["service"!~"alley|driveway|emergency_access|parking|parking_aisle|private"] 47 | (poly:'{self._poly_query_str}'); 48 | >; 49 | ); 50 | out; 51 | """ 52 | 53 | 54 | class OSMCustomOverpassResult(OSMOverpassResult): 55 | """ Container class for OSM Overpass result data for (hopefully) drivable road networks """ 56 | 57 | def to_file(self) -> None: 58 | """ Save the result data to file """ 59 | tmp_file = pathlib.Path('/data/gis/result') 60 | this_format = self.format 61 | if 'json' in this_format: 62 | tmp_file = tmp_file.with_suffix('.json') 63 | elif 'xml' in this_format: 64 | tmp_file = tmp_file.with_suffix('.osm') 65 | else: 66 | raise ValueError( 67 | 'Unknown format: \'%s\'! Cannot save to file.' % this_format) 68 | 69 | # Save the response content to file 70 | tmp_file.write_bytes(self._result.content) 71 | 72 | 73 | if __name__ == '__main__': 74 | 75 | # Setup custom query to local interpreter 76 | q = OSMCustomOverpassQuery(format='xml', timeout=24 * 60 * 60) 77 | q.set_endpoint('http://localhost:12345/api/interpreter') 78 | 79 | # Use rough USA bounds for query 80 | with open('gis/us_wkt.txt', 'r') as f: 81 | us_wkt = f.read() 82 | q.set_poly_from_wkt(us_wkt) 83 | 84 | # Perform query and save! This will take a long time. 85 | print('Performing query...') 86 | result = q.perform_query() 87 | print('Saving to file...') 88 | result.to_file() -------------------------------------------------------------------------------- /scripts/pre_inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ Pre-mass inference script. Loop through images and find chips of drivable ways.
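For each image, queries the local Overpass instance for drivable ways, centers a 256x256 chip on each way's midpoint, and writes one row per (way, image) pair to a SQLite 'features' table for the mass-inference dataset.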
""" 4 | #%% Imports and inputs 5 | import sqlite3 6 | 7 | import pathlib 8 | from osgeo import gdal, ogr, osr 9 | import pandas as pd 10 | 11 | from rsc.common.utils import imread_geometry, imread_dims, map_to_pix 12 | from rsc.osm.overpass_api.road_network import OSMRoadNetworkOverpassQuery 13 | 14 | gdal.UseExceptions() 15 | ogr.UseExceptions() 16 | 17 | OVERPASS_INTERPRETER_URL = "http://localhost:12345/api/interpreter" 18 | TEST_CASE_PATH = pathlib.Path("/nfs/taranis/naip/BOULDER_COUNTY_NAIP_2019") 19 | assert TEST_CASE_PATH.is_dir() 20 | OUTPUT_PATH = pathlib.Path( 21 | '/nfs/taranis/naip/BOULDER_COUNTY_NAIP_2019.sqlite3') 22 | assert not OUTPUT_PATH.exists() 23 | 24 | # Setup custom query to local interpreter 25 | q = OSMRoadNetworkOverpassQuery(format='xml', timeout=24 * 60 * 60) 26 | q.set_endpoint('http://localhost:12345/api/interpreter') 27 | 28 | # Get EPSG:4326 SRS 29 | srs = osr.SpatialReference() 30 | srs.ImportFromEPSG(4326) 31 | srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER) 32 | 33 | #%% Loop through imagery and find drivable ways 34 | 35 | # Rows to convert to Pandas dataframe 36 | df_rows = [] 37 | 38 | for img_path in TEST_CASE_PATH.glob('*.mrf'): 39 | print('Processing %s...' % img_path.name) 40 | 41 | # Get geometry for image 42 | img_geom = imread_geometry(img_path, wgs84=True) 43 | h, w = imread_dims(img_path) 44 | 45 | # Perform overpass query for geometry 46 | q.set_poly_from_wkt(img_geom) 47 | 48 | # Perform the query 49 | result = q.perform_query() 50 | 51 | # Convert to road network 52 | network = result.to_network() 53 | 54 | # If no roads show up, skip for now 55 | if not network.num_ways: 56 | continue 57 | 58 | # Compute image-specific coordinate transformation 59 | ds: gdal.Dataset = gdal.Open(str(img_path), gdal.GA_ReadOnly) 60 | im_h, im_w = ds.RasterYSize, ds.RasterXSize 61 | g_xform = ds.GetGeoTransform() 62 | srs_ds: osr.SpatialReference = ds.GetSpatialRef() 63 | srs_ds.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER) 64 | ds = None # type: ignore 65 | c_xform = osr.CoordinateTransformation(srs, srs_ds) 66 | 67 | # Loop through all available ways 68 | for way in network.get_ways(): 69 | osm_id = way.id 70 | 71 | # Get nodes 72 | nodes = [network.get_node_by_id(id) 73 | for id in way.nodes] # type: ignore 74 | 75 | # Get linestring geometry 76 | linestr = ogr.Geometry(ogr.wkbLineString) 77 | [linestr.AddPoint_2D(node.lon, node.lat) for node in nodes] 78 | wkt = linestr.ExportToWkt() 79 | 80 | # Get midpoint lon, lat 81 | m_lon, m_lat = linestr.GetPoint_2D(linestr.GetPointCount() // 2) 82 | 83 | # Compute center point in image spatial coordinate system 84 | pt = ogr.Geometry(ogr.wkbPoint) 85 | pt.AddPoint_2D(m_lon, m_lat) 86 | pt.Transform(c_xform) 87 | 88 | # Compute upper-left corner for image chip 89 | x, y = [e[0].item() for e in map_to_pix(g_xform, pt.GetX(), pt.GetY())] 90 | if (x < 0 or x >= im_w) or (y < 0 or y >= im_h): 91 | # Center of way off of image, skipping. 
--------------------------------------------------------------------------------
/scripts/run_overpass_api.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # https://hub.docker.com/r/wiktorn/overpass-api
4 | # http://localhost:12345/api/interpreter
5 | docker run \
6 |     -e OVERPASS_META=yes \
7 |     -e OVERPASS_MODE=init \
8 |     -e OVERPASS_PLANET_URL=file:///data/gis/us-latest.osm.bz2 \
9 |     -e OVERPASS_RULES_LOAD=10 \
10 |     -e OVERPASS_SPACE=55000000000 \
11 |     -e OVERPASS_MAX_TIMEOUT=86400 \
12 |     -v /data/gis:/data/gis \
13 |     -v /data/gis/overpass_db:/db \
14 |     -p 12345:80 \
15 |     -i --name overpass_usa wiktorn/overpass-api:latest
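16 | 
17 | # NOTE: OVERPASS_MODE=init performs the one-time planet-file import into
18 | # /db, which is why the first run takes a long time. Once the import has
19 | # finished, the same container should be restartable without re-importing,
20 | # e.g.:
21 | #
22 | #   docker start -a overpass_usa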
--------------------------------------------------------------------------------
/scripts/sample_augmentations.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | 
7 | from torch.utils.data import DataLoader
8 | 
9 | 
10 | def sample_augmentations(ds, transform, grid=5):
11 | 
12 |     dl = DataLoader(ds, num_workers=1, batch_size=1, shuffle=True)
13 |     for x, _ in dl:
14 |         break
15 |     x = x[0, ...]  # type: ignore
16 |     x_np = x.numpy()
17 |     im_np = np.moveaxis(x_np[0:4, :, :] * 255., 0, -1).astype(np.uint8)
18 |     mask_np = np.moveaxis(x_np[4:5, :, :] * 255., 0, -1).astype(np.uint8)
19 |     pmask_np = np.moveaxis(x_np[5:6, :, :] * 255., 0, -1).astype(np.uint8)
20 | 
21 |     fig, ax = plt.subplots(grid,
22 |                            4,
23 |                            sharex=True,
24 |                            sharey=True,
25 |                            figsize=(3 * 4, 3 * grid))
26 |     ax = ax.flatten()  # type: ignore
27 | 
28 |     ax[0].imshow(im_np[..., (0, 1, 2)])  # type: ignore
29 |     ax[1].imshow(im_np[..., (3, 0, 1)])  # type: ignore
30 |     ax[2].imshow(mask_np)
31 |     ax[3].imshow(pmask_np)
32 | 
33 |     # Plot augmented samples
34 |     for idx in range(4, len(ax), 4):
35 |         # Do an augmentation
36 |         im_aug, pm_aug = transform(x)
37 |         m_aug = im_aug[:, 4:5, ...]
38 |         im_aug = im_aug[:, 0:4, ...]
39 | 
40 |         im_aug_np = (np.moveaxis(im_aug[0, ...].numpy(), 0, -1) * 255.).astype(
41 |             np.uint8)
42 |         m_aug_np = np.moveaxis(m_aug[0, ...].numpy() * 255., 0,
43 |                                -1).astype(np.uint8)
44 |         pm_aug_np = np.moveaxis(pm_aug[0, ...].numpy() * 255., 0,
45 |                                 -1).astype(np.uint8)
46 | 
47 |         # Plot it
48 |         ax[idx].imshow(im_aug_np[..., (0, 1, 2)])  # type: ignore
49 |         ax[idx + 1].imshow(im_aug_np[..., (3, 0, 1)])  # type: ignore
50 |         ax[idx + 2].imshow(m_aug_np)
51 |         ax[idx + 3].imshow(pm_aug_np)
52 | 
53 |     for _ax in ax:
54 |         _ax.get_xaxis().set_visible(False)
55 |         _ax.get_yaxis().set_visible(False)
56 |     fig.tight_layout()
57 | 
58 |     plt.show()
59 | 
60 | 
61 | if __name__ == '__main__':
62 | 
63 |     from rsc.train.dataset import RoadSurfaceDataset
64 |     from rsc.train.preprocess import PreProcess
65 |     from rsc.train.data_augmentation import DataAugmentation
66 | 
67 |     train_ds = RoadSurfaceDataset(
68 |         '/data/road_surface_classifier/dataset/dataset_train.csv',
69 |         transform=PreProcess(),
70 |         limit=50)
71 | 
72 |     sample_augmentations(train_ds, DataAugmentation())
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 | import os
5 | import sys
6 | from typing import Optional
7 | 
8 | sys.path.append(
9 |     os.getcwd())  # TODO: this is silly, would be fixed with pip install
10 | 
11 | # Set AWS profile (for use in MLFlow)
12 | os.environ["AWS_PROFILE"] = 'truenas'
13 | os.environ["MLFLOW_S3_ENDPOINT_URL"] = 'http://truenas.local:9807'
14 | 
15 | import pathlib
16 | from datetime import datetime
17 | 
18 | import pandas as pd
19 | 
20 | import torch
21 | from torch.utils.data import DataLoader
22 | import pytorch_lightning as pl
23 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, StochasticWeightAveraging
24 | from pytorch_lightning.loggers import MLFlowLogger
25 | 
26 | from rsc.train.plmcnn import PLMaskCNN
27 | from rsc.train.preprocess import PreProcess
28 | from rsc.train.dataset import RoadSurfaceDataset
29 | from rsc.artifacts.confusion_matrix_handler import ConfusionMatrixHandler
30 | 
31 | # Recommended for CUDA
32 | torch.set_float32_matmul_precision('medium')
33 | 
34 | # Whether or not to include the NIR channel,
35 | # since the input dataset includes it
36 | INCLUDE_NIR = True
37 | 
38 | # Quick test: train one epoch on a smaller-than-normal
39 | # set of images to check everything works
40 | QUICK_TEST = False
41 | 
42 | if __name__ == '__main__':
43 | 
44 |     # Create directory for results
45 |     results_dir = pathlib.Path(
46 |         '/data/road_surface_classifier/results').resolve()
47 |     assert results_dir.is_dir()
48 |     now = datetime.utcnow().strftime(
49 |         '%Y%m%d_%H%M%SZ')  # timestamp string used for traceability
50 |     save_dir = results_dir / now
51 |     save_dir.mkdir(parents=False, exist_ok=False)
52 | 
53 |     # Log epoch + validation loss to CSV
54 |     #logger = CSVLogger(str(save_dir))
55 | 
56 |     # Labels and weights
57 |     weights_df = pd.read_csv(
58 |         '/data/road_surface_classifier/dataset_multiclass/class_weights.csv')
59 |     labels = list(weights_df['class_name'])
60 |     top_level_map = list(weights_df['top_level'])
61 |     class_weights = list(weights_df['weight'])
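62 | 
63 |     # For reference, class_weights.csv is expected to provide the columns
64 |     # used above; something like this (values purely illustrative):
65 |     #
66 |     #   class_name,top_level,weight
67 |     #   asphalt,paved,0.8
68 |     #   gravel,unpaved,1.2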
69 |     # Get dataset
70 |     chip_size = 224
71 |     preprocess = PreProcess()
72 |     train_ds = RoadSurfaceDataset(
73 |         '/data/road_surface_classifier/dataset_multiclass/dataset_train.csv',
74 |         transform=preprocess,
75 |         chip_size=chip_size,
76 |         limit=-1 if not QUICK_TEST else 500,
77 |         n_channels=4 if INCLUDE_NIR else 3)
78 |     val_ds = RoadSurfaceDataset(
79 |         '/data/road_surface_classifier/dataset_multiclass/dataset_val.csv',
80 |         transform=preprocess,
81 |         chip_size=chip_size,
82 |         limit=-1 if not QUICK_TEST else 500,
83 |         n_channels=4 if INCLUDE_NIR else 3)
84 | 
85 |     # Create data loaders.
86 |     batch_size = 64
87 |     train_dl = DataLoader(train_ds,
88 |                           num_workers=16,
89 |                           batch_size=batch_size,
90 |                           shuffle=True)
91 |     val_dl = DataLoader(val_ds, num_workers=16, batch_size=batch_size)
92 | 
93 |     # Model
94 |     learning_rate = 0.00040984645874638675
95 |     model = PLMaskCNN(weights=class_weights,
96 |                       labels=labels,
97 |                       nc=4 if INCLUDE_NIR else 3,
98 |                       top_level_map=top_level_map,
99 |                       learning_rate=learning_rate,
100 |                       seg_k=0.7,
101 |                       ob_k=0.9)
102 | 
103 | 
104 |     # Save model to results directory
105 |     torch.save(model, save_dir / 'model.pth')
106 | 
107 |     # Attempt deserialization (b/c I've had problems with it before)
108 |     try:
109 |         torch.load(save_dir / 'model.pth')
110 |     except Exception:
111 |         import traceback
112 |         traceback.print_exc()
113 |         raise AssertionError('Torch model failed to deserialize!')
114 | 
115 |     # Train model in stages
116 |     best_model_path: Optional[str] = None
117 |     mlflow_logger: Optional[MLFlowLogger] = None
118 |     stage = 0
119 |     model.set_stage(stage, learning_rate)
120 | 
121 | 
122 |     # Logger
123 |     mlflow_logger = MLFlowLogger(experiment_name='road_surface_classifier',
124 |                                  run_name='run_%s_%d' % (now, stage),
125 |                                  tracking_uri='http://truenas.local:9809')
126 | 
127 |     # Upload base model
128 |     mlflow_logger.experiment.log_artifact(mlflow_logger.run_id,
129 |                                           str(save_dir / 'model.pth'))
130 | 
131 |     # Save checkpoints (model states for later)
132 |     checkpoint_callback = ModelCheckpoint(
133 |         dirpath=str(save_dir),
134 |         monitor='val_loss',
135 |         save_top_k=3,
136 |         filename='model-%d-{epoch:02d}-{val_loss:.5f}' % stage)
137 | 
138 |     # Setup early stopping based on validation loss
139 |     early_stopping_callback = EarlyStopping(monitor='val_loss',
140 |                                             mode='min',
141 |                                             patience=10)
142 | 
143 |     # Stochastic Weight Averaging
144 |     swa_callback = StochasticWeightAveraging(swa_lrs=0.6863423660749621)
145 | 
146 |     # Trainer
147 |     trainer = pl.Trainer(accelerator='gpu',
148 |                          devices=1,
149 |                          max_epochs=1000 if not QUICK_TEST else 1,
150 |                          callbacks=[
151 |                              checkpoint_callback, early_stopping_callback,
152 |                              swa_callback
153 |                          ],
154 |                          logger=mlflow_logger)
155 | 
156 |     # Do the thing!
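157 |     # (fit runs until EarlyStopping triggers on val_loss or max_epochs is
158 |     #  reached; the top-3 checkpoints land in save_dir via the callback)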
159 |     trainer.fit(model, train_dataloaders=train_dl,
160 |                 val_dataloaders=val_dl)  # type: ignore
161 | 
162 |     # Get best model path
163 |     best_model_path = checkpoint_callback.best_model_path
164 | 
165 |     # Upload best model
166 |     mlflow_logger.experiment.log_artifact(mlflow_logger.run_id,
167 |                                           best_model_path)
168 | 
169 |     # Load model at best checkpoint for next stage
170 |     del model  # just to be safe
171 |     model = PLMaskCNN.load_from_checkpoint(best_model_path)
172 | 
173 |     assert best_model_path is not None
174 |     assert mlflow_logger is not None
175 | 
176 |     # Generate artifacts
177 |     print('Generating artifacts...')
178 |     model.eval()
179 | 
180 |     # Generate artifacts from model
181 |     from rsc.artifacts import ArtifactGenerator
182 |     from rsc.artifacts.confusion_matrix_handler import ConfusionMatrixHandler
183 |     from rsc.artifacts.accuracy_obsc_handler import AccuracyObscHandler
184 |     from rsc.artifacts.obsc_compare_handler import ObscCompareHandler
185 |     from rsc.artifacts.samples_handler import SamplesHandler
186 |     artifacts_dir = save_dir / 'artifacts'
187 |     artifacts_dir.mkdir(parents=False, exist_ok=True)
188 |     generator = ArtifactGenerator(artifacts_dir, model, val_dl)
189 |     generator.add_handler(ConfusionMatrixHandler(simple=False))
190 |     generator.add_handler(ConfusionMatrixHandler(simple=True))
191 |     generator.add_handler(AccuracyObscHandler())
192 |     generator.add_handler(ObscCompareHandler())
193 |     generator.add_handler(SamplesHandler())
194 |     generator.run(raise_on_error=False)
195 |     mlflow_logger.experiment.log_artifacts(mlflow_logger.run_id,
196 |                                            str(artifacts_dir))
--------------------------------------------------------------------------------
/scripts/train_optuna.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 | import os
5 | import sys
6 | import traceback
7 | from typing import Optional
8 | 
9 | sys.path.append(
10 |     os.getcwd())  # TODO: this is silly, would be fixed with pip install
11 | 
12 | # Set AWS profile (for use in MLFlow)
13 | os.environ["AWS_PROFILE"] = 'truenas'
14 | os.environ["MLFLOW_S3_ENDPOINT_URL"] = 'http://truenas:9807'
15 | 
16 | import pathlib
17 | from datetime import datetime
18 | 
19 | import pandas as pd
20 | 
21 | import torch
22 | from torch.utils.data import DataLoader
23 | import pytorch_lightning as pl
24 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, StochasticWeightAveraging
25 | from pytorch_lightning.loggers import MLFlowLogger
26 | 
27 | import optuna
28 | 
29 | from rsc.train.plmcnn import PLMaskCNN
30 | from rsc.train.preprocess import PreProcess
31 | from rsc.train.dataset import RoadSurfaceDataset
32 | 
33 | # Recommended for CUDA
34 | torch.set_float32_matmul_precision('medium')
35 | 
36 | # Whether or not to include the NIR channel,
37 | # since the input dataset includes it
38 | INCLUDE_NIR = True
39 | 
40 | # Quick test: train one epoch on a smaller-than-normal
41 | # set of images to check everything works
42 | QUICK_TEST = False
43 | 
44 | if __name__ == '__main__':
45 | 
46 |     # Optuna config
47 |     num_trials = 50
48 |     now = datetime.utcnow().strftime(
49 |         '%Y%m%d_%H%M%SZ')  # timestamp string used for traceability
50 |     study_name = 'study_%s' % now
51 | 
52 |     # Labels and weights
53 |     weights_df = pd.read_csv(
54 |         '/data/road_surface_classifier/dataset_multiclass/class_weights.csv')
55 |     labels = list(weights_df['class_name'])
56 |     top_level_map = list(weights_df['top_level'])
57 |     class_weights = list(weights_df['weight'])
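58 | 
59 |     # NOTE: objective() below returns two values (the classification and
60 |     # obscuration validation losses), which is why the study created
61 |     # further down uses directions=['minimize', 'minimize'].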
62 | 
63 |     # Get dataset
64 |     preprocess = PreProcess()
65 |     train_ds = RoadSurfaceDataset(
66 |         '/data/road_surface_classifier/dataset_multiclass/dataset_train.csv',
67 |         transform=preprocess,
68 |         chip_size=224,
69 |         limit=-1 if not QUICK_TEST else 1500,
70 |         n_channels=4 if INCLUDE_NIR else 3)
71 |     val_ds = RoadSurfaceDataset(
72 |         '/data/road_surface_classifier/dataset_multiclass/dataset_val.csv',
73 |         transform=preprocess,
74 |         chip_size=224,
75 |         limit=-1 if not QUICK_TEST else 500,
76 |         n_channels=4 if INCLUDE_NIR else 3)
77 | 
78 |     def objective(trial: optuna.trial.Trial):
79 |         global train_ds, val_ds
80 | 
81 |         # Create data loaders.
82 |         batch_size = 64
83 |         train_dl = DataLoader(train_ds,
84 |                               num_workers=16,
85 |                               batch_size=batch_size,
86 |                               shuffle=True)
87 |         val_dl = DataLoader(val_ds, num_workers=16, batch_size=batch_size)
88 | 
89 |         # Hyperparameters
90 |         chip_size = 224
91 |         learning_rate = trial.suggest_float("learning_rate", 1e-6, 1e-3, log=True)
92 |         swa_lrs = trial.suggest_float("swa_lrs", 1e-4, 1e0, log=True)
93 |         seg_k = trial.suggest_float("seg_k", 0.1, 1, step=0.1)
94 |         ob_k = trial.suggest_float("ob_k", 0.1, 1, step=0.1)
95 | 
96 |         # Set dataset parameters
97 |         train_ds.set_chip_size(chip_size)
98 |         val_ds.set_chip_size(chip_size)
99 | 
100 |         # Set model parameters
101 |         model = PLMaskCNN(weights=class_weights,
102 |                           labels=labels,
103 |                           nc=4 if INCLUDE_NIR else 3,
104 |                           top_level_map=top_level_map,
105 |                           learning_rate=learning_rate,
106 |                           seg_k=seg_k,
107 |                           ob_k=ob_k)
108 | 
109 |         # Create directory for results
110 |         results_dir = pathlib.Path(
111 |             '/data/road_surface_classifier/results').resolve()
112 |         assert results_dir.is_dir()
113 |         now = datetime.utcnow().strftime(
114 |             '%Y%m%d_%H%M%SZ')  # timestamp string used for traceability
115 |         save_dir = results_dir / now
116 |         save_dir.mkdir(parents=False, exist_ok=False)
117 | 
118 |         # Save model to results directory
119 |         torch.save(model, save_dir / 'model.pth')
120 | 
121 |         # Attempt deserialization (b/c I've had problems with it before)
122 |         try:
123 |             torch.load(save_dir / 'model.pth')
124 |         except Exception:
125 |             traceback.print_exc()
126 |             raise AssertionError('Torch model failed to deserialize!')
127 | 
128 |         # Train model in stages
129 |         best_model_path: Optional[str] = None
130 |         mlflow_logger: Optional[MLFlowLogger] = None
131 |         stage = 0
132 | 
133 |         # Logger
134 |         mlflow_logger = MLFlowLogger(
135 |             experiment_name='road_surface_classifier',
136 |             run_name='run_%s_%d_trial_%d' % (now, stage, trial.number),
137 |             tracking_uri='http://truenas:9809')
138 |         mlflow_logger.log_hyperparams(dict(chip_size=chip_size, swa_lrs=swa_lrs))
139 | 
140 |         # Upload base model
141 |         mlflow_logger.experiment.log_artifact(mlflow_logger.run_id,
142 |                                               str(save_dir / 'model.pth'))
143 | 
144 |         # Save checkpoints (model states for later)
145 |         checkpoint_callback = ModelCheckpoint(
146 |             dirpath=str(save_dir),
147 |             monitor='val_loss',
148 |             save_top_k=1,
149 |             filename='model-%d-{epoch:02d}-{val_loss:.5f}' % stage)
150 | 
151 |         # Setup early stopping based on validation loss
152 |         early_stopping_callback = EarlyStopping(monitor='val_loss',
153 |                                                 mode='min',
154 |                                                 patience=10)
155 | 
156 |         # Stochastic Weight Averaging
157 |         swa_callback = StochasticWeightAveraging(swa_lrs=swa_lrs)
158 | 
159 |         # Trainer
160 |         trainer = pl.Trainer(
161 |             accelerator='gpu',
162 |             devices=1,
163 |             max_epochs=300,
164 |             callbacks=[checkpoint_callback, early_stopping_callback, swa_callback],
165 |             logger=mlflow_logger)
166 | 
167 |         # Do the thing!
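168 |         # (each trial trains until early stopping or max_epochs; no Optuna
169 |         #  pruner is wired up here, so trials are never cut short mid-run)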
170 |         trainer.fit(model, train_dataloaders=train_dl,
171 |                     val_dataloaders=val_dl)  # type: ignore
172 | 
173 |         # Get best model path
174 |         best_model_path = checkpoint_callback.best_model_path
175 | 
176 |         # Upload best model
177 |         mlflow_logger.experiment.log_artifact(mlflow_logger.run_id,
178 |                                               best_model_path)
179 | 
180 |         # Objective return values
181 |         ret = (model.min_val_loss_cl, model.min_val_loss_ob)
182 | 
183 |         # Load model at best checkpoint to generate artifacts
184 |         del model  # just to be safe
185 |         model = PLMaskCNN.load_from_checkpoint(best_model_path)
186 | 
187 |         assert best_model_path is not None
188 |         assert mlflow_logger is not None
189 | 
190 |         # Generate artifacts
191 |         print('Generating artifacts...')
192 |         model.eval()
193 | 
194 |         # Generate artifacts from model
195 |         from rsc.artifacts import ArtifactGenerator
196 |         from rsc.artifacts.confusion_matrix_handler import ConfusionMatrixHandler
197 |         from rsc.artifacts.accuracy_obsc_handler import AccuracyObscHandler
198 |         from rsc.artifacts.obsc_compare_handler import ObscCompareHandler
199 |         artifacts_dir = save_dir / 'artifacts'
200 |         artifacts_dir.mkdir(parents=False, exist_ok=True)
201 |         generator = ArtifactGenerator(artifacts_dir, model, val_dl)
202 |         generator.add_handler(ConfusionMatrixHandler(simple=False))
203 |         generator.add_handler(ConfusionMatrixHandler(simple=True))
204 |         generator.add_handler(AccuracyObscHandler())
205 |         generator.add_handler(ObscCompareHandler())
206 |         generator.run(raise_on_error=False)
207 |         mlflow_logger.experiment.log_artifacts(mlflow_logger.run_id,
208 |                                                str(artifacts_dir))
209 | 
210 |         return ret
211 | 
212 |     # Do the work!
213 |     for _ in range(3):
214 |         try:
215 |             study = optuna.create_study(
216 |                 directions=['minimize', 'minimize'],
217 |                 storage='sqlite:///./optuna_rsc.sqlite3',
218 |                 study_name=study_name,
219 |                 load_if_exists=True)
220 |             study.optimize(objective, n_trials=num_trials)
221 |         except Exception:
222 |             traceback.print_exc()
223 |             print('Restarting study...')
224 |         else:
225 |             break
226 | 
227 |     print('Done!')
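228 | 
229 |     # Afterwards the Pareto-optimal trials can be inspected offline; a
230 |     # minimal sketch (same storage and study name as above):
231 |     #
232 |     #   study = optuna.load_study(study_name=study_name,
233 |     #                             storage='sqlite:///./optuna_rsc.sqlite3')
234 |     #   for t in study.best_trials:
235 |     #       print(t.number, t.values, t.params)
--------------------------------------------------------------------------------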