├── .gitignore
├── LICENSE
├── README.md
├── figures
│   ├── confusion_matrix.png
│   ├── confusion_matrix_simplified.png
│   ├── naip_example.png
│   ├── naip_masks_example.png
│   ├── osm_logo.png
│   ├── rsc_diagram.drawio
│   ├── rsc_diagram.drawio.png
│   ├── rsc_road_small.svg
│   └── samples
│       ├── correct
│       │   ├── paved
│       │   │   ├── sample_00000.png
│       │   │   ├── sample_00001.png
│       │   │   ├── sample_00002.png
│       │   │   ├── sample_00003.png
│       │   │   └── sample_00004.png
│       │   └── unpaved
│       │       ├── sample_00000.png
│       │       ├── sample_00001.png
│       │       ├── sample_00002.png
│       │       ├── sample_00003.png
│       │       └── sample_00004.png
│       └── incorrect
│           ├── paved
│           │   ├── sample_00000.png
│           │   ├── sample_00001.png
│           │   ├── sample_00002.png
│           │   ├── sample_00003.png
│           │   └── sample_00004.png
│           └── unpaved
│               ├── sample_00000.png
│               ├── sample_00001.png
│               ├── sample_00002.png
│               ├── sample_00003.png
│               └── sample_00004.png
├── notebooks
│   ├── 00_create_naip_on_aws_gpkg.ipynb
│   ├── 01_explore_osm_surface.ipynb
│   ├── 02_data_prep.ipynb
│   ├── 03_dataset_to_gpkg.ipynb
│   ├── 04_road_color_analysis.ipynb
│   ├── 05_segmentation.ipynb
│   ├── 06_hierarchical_loss.ipynb
│   └── 07_model_calibration.ipynb
├── pyproject.toml
├── requirements.in
├── requirements.txt
├── rsc
│   ├── artifacts
│   │   ├── __init__.py
│   │   ├── __main__.py
│   │   ├── accuracy_obsc_handler.py
│   │   ├── auc_handler.py
│   │   ├── base.py
│   │   ├── confusion_matrix_handler.py
│   │   ├── obsc_compare_handler.py
│   │   └── samples_handler.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── aws_naip.py
│   │   ├── geometric_median.py
│   │   └── utils.py
│   ├── inference
│   │   ├── __init__.py
│   │   ├── fetch.py
│   │   └── mass_inference_dataset.py
│   ├── osm
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── osm_element.py
│   │   ├── osm_element_factory.py
│   │   ├── osm_network.py
│   │   ├── osm_overpass_api.py
│   │   └── overpass_api
│   │       ├── __init__.py
│   │       ├── osm_overpass_api.py
│   │       └── road_network.py
│   └── train
│       ├── __init__.py
│       ├── color_jitter_nohuesat.py
│       ├── data_augmentation.py
│       ├── dataset.py
│       ├── mcnn.py
│       ├── mcnn_loss.py
│       ├── patch.py
│       ├── plmcnn.py
│       ├── preprocess.py
│       └── rsc_hxe_loss.py
└── scripts
    ├── evaluate.py
    ├── parse_features.py
    ├── perform_query.py
    ├── pre_inference.py
    ├── run_overpass_api.sh
    ├── sample_augmentations.py
    ├── train.py
    └── train_optuna.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .style.yapf
2 | *_wip\.py
3 | cache/
4 | .vscode
5 | .devcontainer
6 | *\.pyc
7 | data/
8 | *.code-workspace
9 | *.tar.gz
10 | *.npz
11 | build/
12 | *.egg-info
13 | venv/
14 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Jonathan Dalrymple
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OpenStreetMap Road Surface Classifier
2 |
3 | ![OpenStreetMap logo](figures/osm_logo.png)
4 |
5 | This project leverages machine learning (ML) to tag "drivable ways" (roads) in OpenStreetMap **with > 95% accuracy**. The main focus at the moment is automated tagging of road surface type (paved vs. unpaved, though the model also shows skill for asphalt, concrete, gravel, etc.), with other helpful tags such as number of lanes to be added in the future.
6 |
7 | > :warning: **Much of this code is under active development. Breaking changes are to be expected.**
8 |
9 | ## Table of Contents
10 | - [Why Should I Care...](#why-should-i-care)
11 | - [Integration with the OSM Community](#integration-with-the-osm-community)
12 | - [Python Environment](#python-environment)
13 | - [Training Dataset](#training-dataset)
14 | - [Model Architecture](#model-architecture)
15 | - [Model Training](#model-training)
16 | - [Training Results](#training-results)
17 | - [Credits](#credits)
18 |
19 | ## Why Should I Care...
20 |
21 | ### ... About OpenStreetMap?
22 |
23 | Most geospatial data, like the data provided in Google Maps, are locked behind licensing and paywalls. Maps should not be proprietary. OpenStreetMap aims to change this by providing open geospatial data for everyone. [See more here](https://wiki.openstreetmap.org/wiki/FAQ#Why_OpenStreetMap?).
24 |
25 | ### ... About this Project?
26 |
27 | Road surface type is critical for routing applications to generate useful routes. For example, nominal driving speeds are much slower on unpaved roads vs. paved roads. For bicycles, unpaved routes may need to be avoided completely. In any case, lacking knowledge of road surface type can lead any OSM-based routing application to choose suboptimal routes [if the assumed default surface values are incorrect](https://wiki.openstreetmap.org/wiki/Key:surface#Default_values). Widespread labeling of road surface types can increase overall confidence in OSM-based routers as a viable routing solution for cars and bicycles alike.
28 |
29 | ## Integration with the OSM Community
30 |
31 | I foresee the following possible integrations with the OpenStreetMap community:
32 | - Provide a dataset that can augment OSM routing solutions, such as [Project OSRM](https://project-osrm.org/), [Valhalla](https://github.com/valhalla/valhalla), and [cycle.travel](https://cycle.travel/).
33 | - Integrate the above dataset into editors such as [JOSM](https://josm.openstreetmap.de/) or [Rapid](https://rapideditor.org/). I'm prototyping a plugin for JOSM at the moment.
34 |
35 | ## Python Environment
36 | I recommend [pip-tools](https://github.com/jazzband/pip-tools) to manage the environment for this project. The gist of setting up a Python environment for this repo is:
37 | ```bash
38 | $ cd /.../road_surface_classifier # cd into project directory
39 | $ python3 -m venv ./venv # create Python virtual env
40 | $ . venv/bin/activate # activate env
41 | (venv) $ python3 -m pip install pip-tools # install pip-tools
42 | (venv) $ pip-sync # automatically installs dependencies!
43 | ```
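
To update the pinned dependencies after editing `requirements.in`, regenerate the lock file (this is the same command recorded in the header of `requirements.txt`):
```bash
(venv) $ pip-compile --output-file=requirements.txt requirements.in
```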
44 |
45 | ## Training Dataset
46 | The dataset used in this project was prepared by the process outlined in [project notebooks](./notebooks). It is the aggregation of OpenStreetMap data w/ [National Agriculture Imagery Program (NAIP)](https://www.usgs.gov/centers/eros/science/usgs-eros-archive-aerial-photography-national-agriculture-imagery-program-naip) imagery, which is public domain. I additionally have a routine that generates pseudo-truth segmentation masks which the model learns to predict.
47 |
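For illustration, here is a minimal sketch (not the project's actual routine) of how a pseudo-truth mask can be rasterized from an OSM way with GDAL/OGR. The buffer width, chip size, and function name are assumptions:

```python
import numpy as np
from osgeo import gdal, ogr

def pseudo_truth_mask(way_wkt: str, geotransform, size=(256, 256), buffer_deg=1e-4):
    """Burn a buffered OSM way centerline into a binary mask (sketch)."""
    # Buffer the centerline so it has area, then burn it into a raster
    geom = ogr.CreateGeometryFromWkt(way_wkt).Buffer(buffer_deg)

    # Wrap the buffered geometry in an in-memory OGR layer
    mem_ds = ogr.GetDriverByName('Memory').CreateDataSource('')
    layer = mem_ds.CreateLayer('road', geom_type=ogr.wkbPolygon)
    feat = ogr.Feature(layer.GetLayerDefn())
    feat.SetGeometry(geom)
    layer.CreateFeature(feat)

    # Rasterize onto the same grid as the NAIP image chip
    target = gdal.GetDriverByName('MEM').Create('', size[0], size[1], 1, gdal.GDT_Byte)
    target.SetGeoTransform(geotransform)
    gdal.RasterizeLayer(target, [1], layer, burn_values=[1])
    return target.ReadAsArray().astype(np.uint8)
```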
48 | The dataset is generated automatically, but in some cases road visibility is limited by vegetation growth (the model is trained to predict this obscuration as well). Also of note: there is no guarantee the labels set in OSM are correct, so we must trust that the vast majority of them are. A big source of confusion, for example, is asphalt vs. concrete. It would not surprise me if there are many mislabeled examples of these within OSM.
49 |
50 |
51 | > :heavy_exclamation_mark: NAIP imagery is limited to the United States. While there are other public domain imagery sources that can be used, none have global coverage.
52 |
53 |
54 | ![Examples of NAIP imagery over roads](figures/naip_example.png)
55 | Examples of NAIP imagery over roads. These are Web Mercator tiles at zoom level 16 (~2.3 meters per pixel). However, the model is trained on raw NAIP data which includes NIR and is 1 meter per pixel. (source: USGS National Agriculture Imagery Program)
56 |
57 |
58 | To support the MaskCNN architecture (_see below_), binary masks were also generated in order to tell the model "where to look":
59 |
60 |
61 | ![NAIP imagery with OSM labels and generated binary masks](figures/naip_masks_example.png)
62 | Examples of NAIP imagery over roads with OSM labels (paved vs. unpaved) and generated binary masks from OSM data. (source: USGS National Agriculture Imagery Program [imagery]; OpenStreetMap [labels])
63 |
64 |
65 | ## Model Architecture
66 |
67 | I'm currently using a MaskCNN model largely based on [Liu et al.: _Masked convolutional neural network for supervised learning problems_](https://par.nsf.gov/servlets/purl/10183705).
68 | - Instead of multiplication, I concatenate the predicted mask into the classifier backbone.
69 | - I'm using a Resnet-18 backbone for the encoder, decoder, and classifier.
70 | - By using such a small encoder, this model can run inference on a CPU!
71 |
72 | ![MaskCNN architecture diagram](figures/rsc_diagram.drawio.png)
73 | Quick diagram of the MaskCNN architecture used here. The NAIP imagery gets combined with a mask created from OSM vector data, which in turn is used to generate the segmentation mask. The image and segmentation mask are then fed into the classifier model.
74 |
75 |
76 | The benefit of this model over a plain Resnet is the ability to tell the model what the mask should look like. This tells the classifier "where to look" (i.e. I care about _this_ road in the image, not _that_ one).
77 |
78 | The trick is not to force the appearance of this mask too strongly, because then (1) the model stops looking outside the mask after the concatenation step, and (2) the model cares more about the mask than the classification result!
79 |
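A minimal PyTorch sketch of the concatenation idea (shapes and names are assumptions; see `rsc/train/mcnn.py` for the actual model):

```python
import torch
import torch.nn as nn
import torchvision

class MaskConcatClassifier(nn.Module):
    """Sketch: concatenate a predicted mask onto the image, then classify."""

    def __init__(self, num_classes: int = 2, in_channels: int = 5):
        super().__init__()
        # Assumed: 4 NAIP bands (RGB + NIR) + 1 predicted mask channel = 5
        self.backbone = torchvision.models.resnet18(num_classes=num_classes)
        self.backbone.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7,
                                        stride=2, padding=3, bias=False)

    def forward(self, image: torch.Tensor, pred_mask: torch.Tensor) -> torch.Tensor:
        x = torch.cat([image, pred_mask], dim=1)   # concatenation, not multiplication
        return self.backbone(x)

# Example: one 256x256 chip with 4 bands plus its predicted mask
logits = MaskConcatClassifier()(torch.rand(1, 4, 256, 256), torch.rand(1, 1, 256, 256))
```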
80 | ## Model Training
81 |
82 | Training is currently done w/ [PyTorch Lightning](https://www.pytorchlightning.ai/); see [`train.py`](./scripts/train.py).
83 |
84 | > :heavy_exclamation_mark: I don't recommend training this model without a dedicated compute GPU configured with CUDA. I know some use [Google Colab](https://colab.research.google.com/), but I'm unfamiliar with it.
85 |
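A hypothetical minimal LightningModule for a setup like this (batch layout and names are assumptions; the real configuration lives in the training script):

```python
import pytorch_lightning as pl
import torch
import torch.nn.functional as F

class RoadSurfaceModule(pl.LightningModule):
    """Sketch of a Lightning wrapper around the classifier."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        image, mask, label = batch              # assumed batch layout
        logits = self.model(image, mask)
        loss = F.cross_entropy(logits, label)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=50)
# trainer.fit(RoadSurfaceModule(model), train_loader, val_loader)
```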
86 | ## Training Results
87 |
88 | **To read the confusion matrix**: for each box, read "when the true label is X, then __% of the time the model predicts Y".
89 |
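In other words, the percentages form a row-normalized confusion matrix. A sketch of how such a matrix can be computed with scikit-learn (already in `requirements.in`); the labels here are illustrative only:

```python
from sklearn.metrics import confusion_matrix

y_true = ['paved', 'paved', 'unpaved', 'unpaved', 'unpaved']
y_pred = ['paved', 'unpaved', 'unpaved', 'unpaved', 'paved']

# normalize='true' divides each row by the true-class count, so
# cm[i, j] = fraction of true-class-i samples predicted as class j
cm = confusion_matrix(y_true, y_pred, labels=['paved', 'unpaved'], normalize='true')
print(cm)
```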
90 | ### Paved vs. Unpaved
91 |
92 |
93 | ![Paved vs. unpaved confusion matrix](figures/confusion_matrix_simplified.png)
94 | Not bad! The model gets each category right over 95% of the time.
95 |
96 |
97 | ### Multiclass
98 |
99 | ![Multiclass confusion matrix](figures/confusion_matrix.png)
100 | Given the imagery resolution, frequent obscuration by vegetation, and often-incorrect truth labels, this is impressive. The model clearly shows skill in predicting a wider range of classes than just paved vs. unpaved.
101 |
102 |
103 | ## Credits
104 |
105 | #### MaskCNN Paper
106 | Liu, L. Y. F., Liu, Y., & Zhu, H. (2020). Masked convolutional neural network for supervised learning problems. Stat, 9(1), e290.
107 |
108 | ## License
109 | [MIT](https://choosealicense.com/licenses/mit/) © 2024 Jonathan Dalrymple
--------------------------------------------------------------------------------
/figures/confusion_matrix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/confusion_matrix.png
--------------------------------------------------------------------------------
/figures/confusion_matrix_simplified.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/confusion_matrix_simplified.png
--------------------------------------------------------------------------------
/figures/naip_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/naip_example.png
--------------------------------------------------------------------------------
/figures/naip_masks_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/naip_masks_example.png
--------------------------------------------------------------------------------
/figures/osm_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/osm_logo.png
--------------------------------------------------------------------------------
/figures/rsc_diagram.drawio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/rsc_diagram.drawio.png
--------------------------------------------------------------------------------
/figures/rsc_road_small.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/figures/samples/correct/paved/sample_00000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/correct/paved/sample_00000.png
--------------------------------------------------------------------------------
/figures/samples/correct/paved/sample_00001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/correct/paved/sample_00001.png
--------------------------------------------------------------------------------
/figures/samples/correct/paved/sample_00002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/correct/paved/sample_00002.png
--------------------------------------------------------------------------------
/figures/samples/correct/paved/sample_00003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/correct/paved/sample_00003.png
--------------------------------------------------------------------------------
/figures/samples/correct/paved/sample_00004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/correct/paved/sample_00004.png
--------------------------------------------------------------------------------
/figures/samples/correct/unpaved/sample_00000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/correct/unpaved/sample_00000.png
--------------------------------------------------------------------------------
/figures/samples/correct/unpaved/sample_00001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/correct/unpaved/sample_00001.png
--------------------------------------------------------------------------------
/figures/samples/correct/unpaved/sample_00002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/correct/unpaved/sample_00002.png
--------------------------------------------------------------------------------
/figures/samples/correct/unpaved/sample_00003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/correct/unpaved/sample_00003.png
--------------------------------------------------------------------------------
/figures/samples/correct/unpaved/sample_00004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/correct/unpaved/sample_00004.png
--------------------------------------------------------------------------------
/figures/samples/incorrect/paved/sample_00000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/incorrect/paved/sample_00000.png
--------------------------------------------------------------------------------
/figures/samples/incorrect/paved/sample_00001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/incorrect/paved/sample_00001.png
--------------------------------------------------------------------------------
/figures/samples/incorrect/paved/sample_00002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/incorrect/paved/sample_00002.png
--------------------------------------------------------------------------------
/figures/samples/incorrect/paved/sample_00003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/incorrect/paved/sample_00003.png
--------------------------------------------------------------------------------
/figures/samples/incorrect/paved/sample_00004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/incorrect/paved/sample_00004.png
--------------------------------------------------------------------------------
/figures/samples/incorrect/unpaved/sample_00000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/incorrect/unpaved/sample_00000.png
--------------------------------------------------------------------------------
/figures/samples/incorrect/unpaved/sample_00001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/incorrect/unpaved/sample_00001.png
--------------------------------------------------------------------------------
/figures/samples/incorrect/unpaved/sample_00002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/incorrect/unpaved/sample_00002.png
--------------------------------------------------------------------------------
/figures/samples/incorrect/unpaved/sample_00003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/incorrect/unpaved/sample_00003.png
--------------------------------------------------------------------------------
/figures/samples/incorrect/unpaved/sample_00004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/figures/samples/incorrect/unpaved/sample_00004.png
--------------------------------------------------------------------------------
/notebooks/00_create_naip_on_aws_gpkg.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "### NAIP On AWS\n",
9 | "\n",
10 | "This Jupyter notebook provides code to scrape the [NAIP on AWS](https://registry.opendata.aws/naip/) manifest and create a GPKG file that provides a geospatial footprint for all available imagery.\n",
11 | "\n",
12 | "Then, if we are interested in a given OSM ID, it's easy to look up which NAIP images intersect this ID and download them directly."
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 1,
18 | "metadata": {},
19 | "outputs": [
20 | {
21 | "name": "stderr",
22 | "output_type": "stream",
23 | "text": [
24 | "100%|██████████| 1150/1150 [00:00<00:00, 125232.36it/s]\n"
25 | ]
26 | }
27 | ],
28 | "source": [
29 | "import pathlib\n",
30 | "\n",
31 | "from tqdm import tqdm\n",
32 | "from osgeo import ogr, osr\n",
33 | "ogr.UseExceptions()\n",
34 | "\n",
35 | "from rsc.common import aws_naip\n",
36 | "from rsc.common.aws_naip import AWS_PATH\n",
37 | "\n",
38 | "# Get NAIP manifest\n",
39 | "manifest = aws_naip.get_naip_manifest()\n",
40 | "\n",
41 | "# Filter out shapefiles to download\n",
42 | "shp = [e for e in manifest if e.split('.')[-1].lower() in \\\n",
43 | " ('shp', 'dbf', 'shx', 'prj', 'sbn')]\n",
44 | "\n",
45 | "# Fetch all the shapefiles\n",
46 | "for object_name in tqdm(shp):\n",
47 | " aws_naip.get_naip_file(object_name)"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 2,
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "def conv(s: str) -> str:\n",
57 | " \"\"\" Convert object paths as seen in manifest to those that might be seen in the shapefiles.\n",
58 | " It's silly they don't match.\"\"\"\n",
59 | " return '%s.tif' % '_'.join(s.split('_')[:6])\n",
60 | "\n",
61 | "# Read the manifest, and convert the TIF files to those that might be seen in the shapefile metadata\n",
62 | "with open(AWS_PATH / 'manifest.txt', 'r') as f:\n",
63 | " mani = {conv(pathlib.Path(p).stem): p for p in (e.strip() for e in f.readlines()) if p.endswith('.tif')}"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 3,
69 | "metadata": {},
70 | "outputs": [],
71 | "source": [
72 | "# Do all the work!\n",
73 | "\n",
74 | "# Create SRS (EPSG:4326: WGS-84 decimal degrees)\n",
75 | "srs = osr.SpatialReference()\n",
76 | "srs.ImportFromEPSG(4326)\n",
77 | "srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)\n",
78 | "\n",
79 | "# Create GPKG file for writing\n",
80 | "driver: ogr.Driver = ogr.GetDriverByName('GPKG')\n",
81 | "ds_w: ogr.DataSource = driver.CreateDataSource(str(AWS_PATH / 'naip_on_aws.gpkg'))\n",
82 | "layer_w: ogr.Layer = ds_w.CreateLayer('footprints', srs=srs, geom_type=ogr.wkbPolygon)\n",
83 | "\n",
84 | "# Define output fields\n",
85 | "state_field = ogr.FieldDefn('STATE', ogr.OFTString)\n",
86 | "band_field = ogr.FieldDefn('BAND', ogr.OFTString)\n",
87 | "usgs_id_field = ogr.FieldDefn('USGSID', ogr.OFTString)\n",
88 | "src_img_date_field = ogr.FieldDefn('SRCIMGDATE', ogr.OFTString)\n",
89 | "filename_field = ogr.FieldDefn('FILENAME', ogr.OFTString)\n",
90 | "object_field = ogr.FieldDefn('OBJECT', ogr.OFTString)\n",
91 | "\n",
92 | "# Create output fields in layer\n",
93 | "layer_w.CreateField(state_field)\n",
94 | "layer_w.CreateField(band_field)\n",
95 | "layer_w.CreateField(usgs_id_field)\n",
96 | "layer_w.CreateField(src_img_date_field)\n",
97 | "layer_w.CreateField(filename_field)\n",
98 | "layer_w.CreateField(object_field)\n",
99 | "\n",
100 | "# Get layer feature definition to load in features\n",
101 | "feat_defn = layer_w.GetLayerDefn()\n",
102 | "\n",
103 | "# Loop through all fetched shapefiles\n",
104 | "for p in AWS_PATH.rglob('*.shp'):\n",
105 | "\n",
106 | " # Load them in OGR, get layer and spatial reference\n",
107 | " ds_r: ogr.DataSource = ogr.Open(str(p))\n",
108 | " layer_r: ogr.Layer = ds_r.GetLayer()\n",
109 | " srs_r: osr.SpatialReference = layer_r.GetSpatialRef()\n",
110 | " srs_r.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)\n",
111 | "\n",
112 | " # Loop throught the features in the layer\n",
113 | " for _ in range(layer_r.GetFeatureCount()):\n",
114 | " feat_r: ogr.Feature = layer_r.GetNextFeature()\n",
115 | "\n",
116 | " # Quickly crosscheck with manifest. Skip if not in there\n",
117 | " filename = feat_r.GetFieldAsString('FileName')\n",
118 | " filename_conv = conv(filename.split('.')[0])\n",
119 | " if not filename_conv in mani:\n",
120 | " continue\n",
121 | "\n",
122 | " # Parse remaining metadata\n",
123 | " try:\n",
124 | " state = feat_r.GetFieldAsString('ST')\n",
125 | " except RuntimeError:\n",
126 | " state = feat_r.GetFieldAsString('QUADST')\n",
127 | " band = feat_r.GetFieldAsString('Band')\n",
128 | " usgs_id = feat_r.GetFieldAsString('USGSID')\n",
129 | " src_img_date = feat_r.GetFieldAsString('SrcImgDate')\n",
130 | "\n",
131 | " # Fetch geometry and convert to desired spatial reference\n",
132 | " trans = osr.CoordinateTransformation(srs_r, srs)\n",
133 | " geom = ogr.CreateGeometryFromWkt(feat_r.GetGeometryRef().ExportToWkt())\n",
134 | " geom.Transform(trans)\n",
135 | "\n",
136 | " # Create our new feature\n",
137 | " feat_w = ogr.Feature(feat_defn)\n",
138 | " feat_w.SetGeometry(geom)\n",
139 | " feat_w.SetField('STATE', state)\n",
140 | " feat_w.SetField('BAND', band)\n",
141 | " feat_w.SetField('USGSID', usgs_id)\n",
142 | " feat_w.SetField('SRCIMGDATE', src_img_date)\n",
143 | " feat_w.SetField('FILENAME', filename)\n",
144 | " feat_w.SetField('OBJECT', mani[filename_conv])\n",
145 | "\n",
146 | " # Save!\n",
147 | " layer_w.CreateFeature(feat_w)\n",
148 | "\n",
149 | " # Cleanup features\n",
150 | " feat_w = None\n",
151 | " feat_r = None\n",
152 | "\n",
153 | " # Cleanup read dataset\n",
154 | " layer_r = None\n",
155 | " ds_r = None\n",
156 | "\n",
157 | "# Cleanup write dataset\n",
158 | "layer_w = None\n",
159 | "ds_w = None"
160 | ]
161 | }
162 | ],
163 | "metadata": {
164 | "kernelspec": {
165 | "display_name": "Python 3.10.8 64-bit",
166 | "language": "python",
167 | "name": "python3"
168 | },
169 | "language_info": {
170 | "codemirror_mode": {
171 | "name": "ipython",
172 | "version": 3
173 | },
174 | "file_extension": ".py",
175 | "mimetype": "text/x-python",
176 | "name": "python",
177 | "nbconvert_exporter": "python",
178 | "pygments_lexer": "ipython3",
179 | "version": "3.10.8"
180 | },
181 | "orig_nbformat": 4,
182 | "vscode": {
183 | "interpreter": {
184 | "hash": "4d102384ded633c24f1031e288c2ecf1ababc4ef37e402995ad37064232eefd1"
185 | }
186 | }
187 | },
188 | "nbformat": 4,
189 | "nbformat_minor": 2
190 | }
191 |
--------------------------------------------------------------------------------
/notebooks/01_explore_osm_surface.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "### Exploring OSM `surface` tags\n",
9 | "\n",
10 | "In this notebook, I perform a custom OSM Overpass API query for all \"driveable\" roads that have surface labels. My main curiosity is which `surface` tags appear the most, and if I can identify an easy set of `surface` tag values that will account for the overwhelming majority of roads in the US."
11 | ]
12 | },
13 | {
14 | "attachments": {},
15 | "cell_type": "markdown",
16 | "metadata": {},
17 | "source": [
18 | "First, I kick off a Docker container to host the OSM Overpass API. This query requires far too much data to use any Overpass API that is hosted online."
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 20,
24 | "metadata": {},
25 | "outputs": [
26 | {
27 | "name": "stdout",
28 | "output_type": "stream",
29 | "text": [
30 | "7b463d293ec20d6fab599eb3b2dd15f0463d89f6192df9ff25a9ec32d8317b6e\n"
31 | ]
32 | }
33 | ],
34 | "source": [
35 | "%%bash\n",
36 | "# https://hub.docker.com/r/wiktorn/overpass-api\n",
37 | "# http://localhost:12345/api/interpreter\n",
38 | "docker run \\\n",
39 | " -e OVERPASS_META=yes \\\n",
40 | " -e OVERPASS_MODE=init \\\n",
41 | " -e OVERPASS_PLANET_URL=file:///data/gis/us-latest.osm.bz2 \\\n",
42 | " -e OVERPASS_RULES_LOAD=10 \\\n",
43 | " -e OVERPASS_SPACE=55000000000 \\\n",
44 | " -e OVERPASS_MAX_TIMEOUT=86400 \\\n",
45 | " -v /data/gis:/data/gis \\\n",
46 | " -v /data/gis/overpass_db:/db \\\n",
47 | " -p 12345:80 \\\n",
48 | " -d --rm --name overpass_usa wiktorn/overpass-api:latest"
49 | ]
50 | },
51 | {
52 | "attachments": {},
53 | "cell_type": "markdown",
54 | "metadata": {},
55 | "source": [
56 | "Let's define a class that will allow us to perform an OSM Overpass API query for drivable road networks:"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": 14,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "from rsc.osm.overpass_api import OSMOverpassQuery, OSMOverpassResult\n",
66 | "\n",
67 | "\n",
68 | "class OSMCustomOverpassQuery(OSMOverpassQuery):\n",
69 | " \"\"\" Custom OSM Overpass API query for (hopefully) drivable road networks \"\"\"\n",
70 | "\n",
71 | " __slots__ = ['_highway_tags']\n",
72 | "\n",
73 | " DEFAULT_HIGHWAY_TAGS = [\n",
74 | " 'motorway', 'motorway_link', 'motorway_junction', 'trunk',\n",
75 | " 'trunk_link', 'primary', 'primary_link', 'secondary', 'secondary_link',\n",
76 | " 'tertiary', 'tertiary_link', 'unclassified', 'residential'\n",
77 | " ]\n",
78 | "\n",
79 | " def __init__(self, **kwargs):\n",
80 | " super().__init__(**kwargs)\n",
81 | " self._highway_tags = kwargs.get('highway_tags',\n",
82 | " self.DEFAULT_HIGHWAY_TAGS)\n",
83 | "\n",
84 | " def perform_query(self) -> 'OSMOverpassResult':\n",
85 | " \"\"\" Perform an OSM Overpass API Request! \"\"\"\n",
86 | " return OSMOverpassResult(self._perform_query())\n",
87 | "\n",
88 | " @property\n",
89 | " def _query_str(self) -> str:\n",
90 | " return f\"\"\"\n",
91 | " [out:{self._format}]\n",
92 | " [timeout:{self._timeout}]\n",
93 | " [maxsize:2147483648];\n",
94 | " (way[\"highway\"]\n",
95 | " [\"area\"!~\"yes\"]\n",
96 | " [\"access\"!~\"private\"]\n",
97 | " [\"highway\"~\"{'|'.join(self._highway_tags)}\"]\n",
98 | " [\"motor_vehicle\"!~\"no\"]\n",
99 | " [\"motorcar\"!~\"no\"]\n",
100 | " [\"surface\"!~\"\"]\n",
101 | " [\"service\"!~\"alley|driveway|emergency_access|parking|parking_aisle|private\"]\n",
102 | " (poly:'{self._poly_query_str}');\n",
103 | " >;\n",
104 | " );\n",
105 | " out;\n",
106 | " \"\"\""
107 | ]
108 | },
109 | {
110 | "attachments": {},
111 | "cell_type": "markdown",
112 | "metadata": {},
113 | "source": [
114 | "Now we can perform the query. This is a *very* broad query and therefore takes quite a bit of time."
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": 17,
120 | "metadata": {},
121 | "outputs": [
122 | {
123 | "name": "stdout",
124 | "output_type": "stream",
125 | "text": [
126 | "OSM data found at /data/gis/us_road_surface/us_w_road_surface.osm\n"
127 | ]
128 | }
129 | ],
130 | "source": [
131 | "import pathlib\n",
132 | "out_path = pathlib.Path('/data/gis/us_road_surface/us_w_road_surface.osm')\n",
133 | "\n",
134 | "if not out_path.is_file():\n",
135 | " # Setup custom query to local interpreter\n",
136 | " # Set a very long timeout\n",
137 | " q = OSMCustomOverpassQuery(format='xml', timeout=24 * 60 * 60)\n",
138 | " q.set_endpoint('http://localhost:12345/api/interpreter')\n",
139 | "\n",
140 | " # Use rough USA bounds for query\n",
141 | " with open('/data/gis/us_wkt.txt', 'r') as f:\n",
142 | " us_wkt = f.read()\n",
143 | " q.set_poly_from_wkt(us_wkt)\n",
144 | "\n",
145 | " # Perform query and save! This will take a long time.\n",
146 | " print('Performing query...')\n",
147 | " result = q.perform_query()\n",
148 | " print('Saving to file...')\n",
149 | " result.to_file(out_path)\n",
150 | "else:\n",
151 | " print('OSM data found at %s' % str(out_path))"
152 | ]
153 | },
154 | {
155 | "attachments": {},
156 | "cell_type": "markdown",
157 | "metadata": {},
158 | "source": [
159 | "At this point we can stop our Docker container hosting the Overpass API:"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": 21,
165 | "metadata": {},
166 | "outputs": [
167 | {
168 | "name": "stdout",
169 | "output_type": "stream",
170 | "text": [
171 | "overpass_usa\n"
172 | ]
173 | }
174 | ],
175 | "source": [
176 | "%%bash\n",
177 | "docker stop overpass_usa"
178 | ]
179 | },
180 | {
181 | "attachments": {},
182 | "cell_type": "markdown",
183 | "metadata": {},
184 | "source": [
185 | "And now we use ogr2ogr to convert the OSM file we downloaded to a CSV file for easier GIS processing (the OGR OSM driver is very limited). This can be done in Python too, but the command line tool is easier for simple file conversions."
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": 31,
191 | "metadata": {},
192 | "outputs": [
193 | {
194 | "name": "stdout",
195 | "output_type": "stream",
196 | "text": [
197 | "File found: /data/gis/us_road_surface/us_w_road_surface.gpkg\n"
198 | ]
199 | }
200 | ],
201 | "source": [
202 | "%%bash\n",
203 | "DATA_DIR=/data/gis/us_road_surface\n",
204 | "OSM_PATH=$DATA_DIR/us_w_road_surface.osm\n",
205 | "GPKG_PATH=$DATA_DIR/us_w_road_surface.gpkg\n",
206 | "if [ ! -f $GPKG_PATH ]; then\n",
207 | " echo \"Converting $OSM_PATH to $GPKG_PATH...\"\n",
208 | " ogr2ogr $GPKG_PATH $OSM_PATH lines\n",
209 | "else\n",
210 | " echo \"File found: $GPKG_PATH\"\n",
211 | "fi"
212 | ]
213 | },
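{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, a Python equivalent of the ogr2ogr call above (a sketch using `gdal.VectorTranslate`, which wraps ogr2ogr):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from osgeo import gdal\n",
"\n",
"# Sketch: export the 'lines' layer of the OSM file into a GPKG,\n",
"# equivalent to the ogr2ogr command above\n",
"ds = gdal.VectorTranslate('/data/gis/us_road_surface/us_w_road_surface.gpkg',\n",
"                          '/data/gis/us_road_surface/us_w_road_surface.osm',\n",
"                          layers=['lines'])\n",
"ds = None  # flush & close"
]
},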
214 | {
215 | "attachments": {},
216 | "cell_type": "markdown",
217 | "metadata": {},
218 | "source": [
219 | "Let's parse the GPKG file to understand what surface types we are dealing with:"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "execution_count": 2,
225 | "metadata": {},
226 | "outputs": [
227 | {
228 | "name": "stdout",
229 | "output_type": "stream",
230 | "text": [
231 | "Loading & filtering dataset features...\n"
232 | ]
233 | },
234 | {
235 | "name": "stderr",
236 | "output_type": "stream",
237 | "text": [
238 | "100%|██████████| 2458128/2458128 [02:15<00:00, 18099.85it/s]\n"
239 | ]
240 | }
241 | ],
242 | "source": [
243 | "from osgeo import gdal, osr\n",
244 | "from tqdm import tqdm\n",
245 | "\n",
246 | "# Variable to store feature data that we load\n",
247 | "feature_data = []\n",
248 | "\n",
249 | "# WGS84 Spatial Reference: all of OSM is EPSG:4326\n",
250 | "srs_wgs84 = osr.SpatialReference()\n",
251 | "srs_wgs84.ImportFromEPSG(4326)\n",
252 | "\n",
253 | "# Dataset is ogr OSM parsed file with \"lines\" layer exported\n",
254 | "ds = gdal.OpenEx('/data/gis/us_road_surface/us_w_road_surface.gpkg')\n",
255 | "layer = ds.GetLayer()\n",
256 | "feature_count = layer.GetFeatureCount()\n",
257 | "print('Loading & filtering dataset features...')\n",
258 | "for idx in tqdm(range(feature_count)):\n",
259 | " # Get geometry, OSM ID, highway, and surface tag from each way\n",
260 | " feature = layer.GetNextFeature()\n",
261 | " highway = str(feature.GetField(2))\n",
262 | " wkt_str = feature.GetGeometryRef().ExportToWkt()\n",
263 | " osm_id = int(feature.GetField(0))\n",
264 | " other_tags = str(feature.GetField(8))\n",
265 | "\n",
266 | " # NOTE: parsing the misc. tags field is messy. This is about\n",
267 | " # as good as it gets.\n",
268 | " tags_dict = dict([[f.replace('\"', '') for f in e.split('\"=>\"')]\n",
269 | " for e in other_tags.split('\",\"')])\n",
270 | " surface_type = tags_dict.get('surface', 'unknown')\n",
271 | "\n",
272 | " # Add to the feature data\n",
273 | " feature_data.append([osm_id, wkt_str, highway, surface_type])\n",
274 | "\n",
275 | "# Close dataset\n",
276 | "layer = None\n",
277 | "ds = None"
278 | ]
279 | },
280 | {
281 | "attachments": {},
282 | "cell_type": "markdown",
283 | "metadata": {},
284 | "source": [
285 | "Empirically, if I take any `surface` labels that appear more than 1000 times, I get a reasonable set of labels that covers the fast majority of cases. Nice!"
286 | ]
287 | },
288 | {
289 | "cell_type": "code",
290 | "execution_count": 3,
291 | "metadata": {},
292 | "outputs": [
293 | {
294 | "data": {
295 | "text/plain": [
296 | "asphalt 1639190\n",
297 | "unpaved 266035\n",
298 | "paved 230889\n",
299 | "concrete 153371\n",
300 | "gravel 105499\n",
301 | "dirt 28363\n",
302 | "concrete:plates 14310\n",
303 | "compacted 8585\n",
304 | "paving_stones 2562\n",
305 | "ground 2353\n",
306 | "bricks 1751\n",
307 | "Name: surface, dtype: int64"
308 | ]
309 | },
310 | "execution_count": 3,
311 | "metadata": {},
312 | "output_type": "execute_result"
313 | }
314 | ],
315 | "source": [
316 | "import pandas as pd\n",
317 | "\n",
318 | "df = pd.DataFrame(feature_data, columns=['osm_id', 'wkt', 'highway', 'surface']).set_index('osm_id')\n",
319 | "unique_surface = df['surface'].value_counts()\n",
320 | "unique_surface = unique_surface[unique_surface > 1000]\n",
321 | "unique_surface"
322 | ]
323 | },
324 | {
325 | "attachments": {},
326 | "cell_type": "markdown",
327 | "metadata": {},
328 | "source": [
329 | "Let's see what percentage of `surface`-tagged roads this set of labels covers:"
330 | ]
331 | },
332 | {
333 | "cell_type": "code",
334 | "execution_count": 4,
335 | "metadata": {},
336 | "outputs": [
337 | {
338 | "name": "stdout",
339 | "output_type": "stream",
340 | "text": [
341 | "Filtered surface tags account for 99.8% of all driveable ways.\n"
342 | ]
343 | }
344 | ],
345 | "source": [
346 | "pc = unique_surface.sum() / len(feature_data)\n",
347 | "print(f'Filtered surface tags account for {pc:.1%} of all driveable ways.')"
348 | ]
349 | },
350 | {
351 | "attachments": {},
352 | "cell_type": "markdown",
353 | "metadata": {},
354 | "source": [
355 | "I can live with that.\n",
356 | "\n",
357 | "Let's save these filtered tags into a GPKG file for our classifier dataset prep."
358 | ]
359 | },
360 | {
361 | "cell_type": "code",
362 | "execution_count": null,
363 | "metadata": {},
364 | "outputs": [],
365 | "source": [
366 | "# Start a new dataset\n",
367 | "print('Saving filtered features...')\n",
368 | "driver = ogr.GetDriverByName('GPKG')\n",
369 | "ds = driver.CreateDataSource(str(filtered_gpkg_path))\n",
370 | "layer = ds.CreateLayer('roads', srs=srs_wgs84, geom_type=ogr.wkbLineString)\n",
371 | "\n",
372 | "# Define fields\n",
373 | "id_field = ogr.FieldDefn('osmid', ogr.OFTInteger64)\n",
374 | "highway_field = ogr.FieldDefn('highway', ogr.OFTString)\n",
375 | "surface_field = ogr.FieldDefn('surface', ogr.OFTString)\n",
376 | "for field in (id_field, highway_field, surface_field):\n",
377 | " layer.CreateField(field)\n",
378 | "\n",
379 | "# Add features\n",
380 | "feature_defn = layer.GetLayerDefn()\n",
381 | "for idx, (osm_id, wkt_str, highway,\n",
382 | " surface_type) in tqdm(enumerate(feature_data)):\n",
383 | "\n",
384 | " # New feature\n",
385 | " feat = ogr.Feature(feature_defn)\n",
386 | "\n",
387 | " # Set geometry\n",
388 | " geom = ogr.CreateGeometryFromWkt(wkt_str)\n",
389 | " feat.SetGeometry(geom)\n",
390 | "\n",
391 | " # Set fields\n",
392 | " feat.SetField('osmid', osm_id)\n",
393 | " feat.SetField('highway', highway)\n",
394 | " feat.SetField('surface', surface_type)\n",
395 | "\n",
396 | " # Flush\n",
397 | " layer.CreateFeature(feat)\n",
398 | " feat = None\n",
399 | "\n",
400 | "# Close dataset\n",
401 | "layer = None\n",
402 | "ds = None"
403 | ]
404 | }
405 | ],
406 | "metadata": {
407 | "kernelspec": {
408 | "display_name": "venv",
409 | "language": "python",
410 | "name": "python3"
411 | },
412 | "language_info": {
413 | "codemirror_mode": {
414 | "name": "ipython",
415 | "version": 3
416 | },
417 | "file_extension": ".py",
418 | "mimetype": "text/x-python",
419 | "name": "python",
420 | "nbconvert_exporter": "python",
421 | "pygments_lexer": "ipython3",
422 | "version": "3.10.10"
423 | },
424 | "orig_nbformat": 4
425 | },
426 | "nbformat": 4,
427 | "nbformat_minor": 2
428 | }
429 |
--------------------------------------------------------------------------------
/notebooks/03_dataset_to_gpkg.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "This is a quick notebook to allow me to plot the locations of all the images used in my dataset in QGIS."
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 3,
14 | "metadata": {},
15 | "outputs": [
16 | {
17 | "data": {
18 | "text/html": [
19 | "
\n",
20 | "\n",
33 | "
\n",
34 | " \n",
35 | " \n",
36 | " | \n",
37 | " highway | \n",
38 | " surface | \n",
39 | " lon | \n",
40 | " lat | \n",
41 | " object | \n",
42 | "
\n",
43 | " \n",
44 | " osm_id | \n",
45 | " | \n",
46 | " | \n",
47 | " | \n",
48 | " | \n",
49 | " | \n",
50 | "
\n",
51 | " \n",
52 | " \n",
53 | " \n",
54 | " 391020748 | \n",
55 | " primary | \n",
56 | " asphalt | \n",
57 | " -117.658360 | \n",
58 | " 33.596795 | \n",
59 | " ca/2020/60cm/rgbir/33117/m_3311727_ne_11_060_2... | \n",
60 | "
\n",
61 | " \n",
62 | " 14290622 | \n",
63 | " primary_link | \n",
64 | " asphalt | \n",
65 | " -115.147473 | \n",
66 | " 36.092396 | \n",
67 | " nv/2019/60cm/rgbir/36115/m_3611563_ne_11_060_2... | \n",
68 | "
\n",
69 | " \n",
70 | " 240534097 | \n",
71 | " motorway | \n",
72 | " asphalt | \n",
73 | " -78.739247 | \n",
74 | " 42.929866 | \n",
75 | " ny/2019/60cm/rgbir/42078/m_4207803_sw_17_060_2... | \n",
76 | "
\n",
77 | " \n",
78 | " 684838122 | \n",
79 | " secondary | \n",
80 | " asphalt | \n",
81 | " -112.012969 | \n",
82 | " 33.444280 | \n",
83 | " az/2017/60cm/rgbir/33112/m_3311240_ne_12_h_201... | \n",
84 | "
\n",
85 | " \n",
86 | " 13567313 | \n",
87 | " unclassified | \n",
88 | " asphalt | \n",
89 | " -88.843204 | \n",
90 | " 32.081660 | \n",
91 | " ms/2020/60cm/rgbir/32088/m_3208858_nw_16_060_2... | \n",
92 | "
\n",
93 | " \n",
94 | " ... | \n",
95 | " ... | \n",
96 | " ... | \n",
97 | " ... | \n",
98 | " ... | \n",
99 | " ... | \n",
100 | "
\n",
101 | " \n",
102 | " 13853468 | \n",
103 | " unclassified | \n",
104 | " unpaved | \n",
105 | " -117.039719 | \n",
106 | " 47.376109 | \n",
107 | " wa/2019/60cm/rgbir/47117/m_4711740_se_11_060_2... | \n",
108 | "
\n",
109 | " \n",
110 | " 14121032 | \n",
111 | " unclassified | \n",
112 | " unpaved | \n",
113 | " -98.947994 | \n",
114 | " 42.185050 | \n",
115 | " ne/2020/60cm/rgbir/42098/m_4209849_sw_14_060_2... | \n",
116 | "
\n",
117 | " \n",
118 | " 8757986 | \n",
119 | " residential | \n",
120 | " unpaved | \n",
121 | " -74.787238 | \n",
122 | " 44.285550 | \n",
123 | " ny/2019/60cm/rgbir/44074/m_4407442_se_18_060_2... | \n",
124 | "
\n",
125 | " \n",
126 | " 19717227 | \n",
127 | " residential | \n",
128 | " unpaved | \n",
129 | " -72.003888 | \n",
130 | " 44.995127 | \n",
131 | " vt/2018/60cm/rgbir/44072/m_4407208_ne_18_060_2... | \n",
132 | "
\n",
133 | " \n",
134 | " 14125476 | \n",
135 | " unclassified | \n",
136 | " unpaved | \n",
137 | " -98.536517 | \n",
138 | " 41.241572 | \n",
139 | " ne/2020/60cm/rgbir/41098/m_4109852_ne_14_060_2... | \n",
140 | "
\n",
141 | " \n",
142 | "
\n",
143 | "
46491 rows × 5 columns
\n",
144 | "
"
145 | ],
146 | "text/plain": [
147 | " highway surface lon lat \\\n",
148 | "osm_id \n",
149 | "391020748 primary asphalt -117.658360 33.596795 \n",
150 | "14290622 primary_link asphalt -115.147473 36.092396 \n",
151 | "240534097 motorway asphalt -78.739247 42.929866 \n",
152 | "684838122 secondary asphalt -112.012969 33.444280 \n",
153 | "13567313 unclassified asphalt -88.843204 32.081660 \n",
154 | "... ... ... ... ... \n",
155 | "13853468 unclassified unpaved -117.039719 47.376109 \n",
156 | "14121032 unclassified unpaved -98.947994 42.185050 \n",
157 | "8757986 residential unpaved -74.787238 44.285550 \n",
158 | "19717227 residential unpaved -72.003888 44.995127 \n",
159 | "14125476 unclassified unpaved -98.536517 41.241572 \n",
160 | "\n",
161 | " object \n",
162 | "osm_id \n",
163 | "391020748 ca/2020/60cm/rgbir/33117/m_3311727_ne_11_060_2... \n",
164 | "14290622 nv/2019/60cm/rgbir/36115/m_3611563_ne_11_060_2... \n",
165 | "240534097 ny/2019/60cm/rgbir/42078/m_4207803_sw_17_060_2... \n",
166 | "684838122 az/2017/60cm/rgbir/33112/m_3311240_ne_12_h_201... \n",
167 | "13567313 ms/2020/60cm/rgbir/32088/m_3208858_nw_16_060_2... \n",
168 | "... ... \n",
169 | "13853468 wa/2019/60cm/rgbir/47117/m_4711740_se_11_060_2... \n",
170 | "14121032 ne/2020/60cm/rgbir/42098/m_4209849_sw_14_060_2... \n",
171 | "8757986 ny/2019/60cm/rgbir/44074/m_4407442_se_18_060_2... \n",
172 | "19717227 vt/2018/60cm/rgbir/44072/m_4407208_ne_18_060_2... \n",
173 | "14125476 ne/2020/60cm/rgbir/41098/m_4109852_ne_14_060_2... \n",
174 | "\n",
175 | "[46491 rows x 5 columns]"
176 | ]
177 | },
178 | "execution_count": 3,
179 | "metadata": {},
180 | "output_type": "execute_result"
181 | }
182 | ],
183 | "source": [
184 | "# Import data\n",
185 | "\n",
186 | "import pathlib\n",
187 | "\n",
188 | "import pandas as pd\n",
189 | "\n",
190 | "features_path = pathlib.Path('/data/road_surface_classifier/features.csv')\n",
191 | "df_feat = pd.read_csv(features_path).set_index('osm_id')\n",
192 | "df_feat = df_feat.drop(columns=['wkt', 'x', 'y', 'ix', 'iy', 'length'])\n",
193 | "df_feat"
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": 7,
199 | "metadata": {},
200 | "outputs": [],
201 | "source": [
202 | "# Create dataset of points\n",
203 | "from osgeo import ogr, osr\n",
204 | "\n",
205 | "# Create SRS (EPSG:4326: WGS-84 decimal degrees)\n",
206 | "srs = osr.SpatialReference()\n",
207 | "srs.ImportFromEPSG(4326)\n",
208 | "\n",
209 | "driver: ogr.Driver = ogr.GetDriverByName('GPKG')\n",
210 | "ds: ogr.DataSource = driver.CreateDataSource('/data/road_surface_classifier/features_pts.gpkg')\n",
211 | "layer: ogr.Layer = ds.CreateLayer('data', srs=srs, geom_type=ogr.wkbPoint)\n",
212 | "\n",
213 | "osm_id_field = ogr.FieldDefn('osm_id', ogr.OFTInteger64)\n",
214 | "highway_field = ogr.FieldDefn('highway', ogr.OFTString)\n",
215 | "surface_field = ogr.FieldDefn('surface', ogr.OFTString)\n",
216 | "\n",
217 | "layer.CreateField(osm_id_field)\n",
218 | "layer.CreateField(highway_field)\n",
219 | "layer.CreateField(surface_field)\n",
220 | "\n",
221 | "feature_defn = layer.GetLayerDefn()\n",
222 | "\n",
223 | "for _, row in df_feat.iterrows():\n",
224 | " feat = ogr.Feature(feature_defn)\n",
225 | "\n",
226 | " pt = ogr.Geometry(ogr.wkbPoint)\n",
227 | " pt.AddPoint_2D(row['lon'], row['lat'])\n",
228 | "\n",
229 | " feat.SetGeometry(pt)\n",
230 | " feat.SetField('osm_id', row.name)\n",
231 | " feat.SetField('highway', row['highway'])\n",
232 | " feat.SetField('surface', row['surface'])\n",
233 | " layer.CreateFeature(feat)\n",
234 | " pt = None\n",
235 | " feat = None\n",
236 | "\n",
237 | "layer = None # type: ignore\n",
238 | "ds = None # type: ignore"
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": 6,
244 | "metadata": {},
245 | "outputs": [
246 | {
247 | "name": "stderr",
248 | "output_type": "stream",
249 | "text": [
250 | "Exception ignored in: \n",
251 | "Traceback (most recent call last):\n",
252 | " File \"/tmp/ipykernel_57433/3662002279.py\", line 16, in \n",
253 | "RuntimeError: sqlite3_exec(CREATE TABLE gpkg_extensions (table_name TEXT,column_name TEXT,extension_name TEXT NOT NULL,definition TEXT NOT NULL,scope TEXT NOT NULL,CONSTRAINT ge_tce UNIQUE (table_name, column_name, extension_name))) failed: attempt to write a readonly database\n"
254 | ]
255 | }
256 | ],
257 | "source": [
258 | "# Create dataset of polygons from imagery\n",
259 | "import json\n",
260 | "from osgeo import gdal, ogr, osr\n",
261 | "\n",
262 | "gdal.UseExceptions()\n",
263 | "ogr.UseExceptions()\n",
264 | "\n",
265 | "naip_img_path = pathlib.Path('/data/road_surface_classifier/imagery')\n",
266 | "assert naip_img_path.is_dir()\n",
267 | "\n",
268 | "# Create SRS (EPSG:4326: WGS-84 decimal degrees)\n",
269 | "srs = osr.SpatialReference()\n",
270 | "srs.ImportFromEPSG(4326)\n",
271 | "\n",
272 | "driver: ogr.Driver = ogr.GetDriverByName('GPKG')\n",
273 | "ds: ogr.DataSource = driver.CreateDataSource('/data/road_surface_classifier/features_polys.gpkg')\n",
274 | "layer: ogr.Layer = ds.CreateLayer('data', srs=srs, geom_type=ogr.wkbPolygon)\n",
275 | "\n",
276 | "osm_id_field = ogr.FieldDefn('osm_id', ogr.OFTInteger64)\n",
277 | "highway_field = ogr.FieldDefn('highway', ogr.OFTString)\n",
278 | "surface_field = ogr.FieldDefn('surface', ogr.OFTString)\n",
279 | "\n",
280 | "layer.CreateField(osm_id_field)\n",
281 | "layer.CreateField(highway_field)\n",
282 | "layer.CreateField(surface_field)\n",
283 | "\n",
284 | "feature_defn = layer.GetLayerDefn()\n",
285 | "\n",
286 | "for osm_id, row in df_feat.iterrows():\n",
287 | " img_path = naip_img_path / str('%d.tif' % osm_id)\n",
288 | " assert img_path.exists()\n",
289 | "\n",
290 | " wgs84_extent = gdal.Info(str(img_path), format='json')['wgs84Extent']\n",
291 | " poly = ogr.CreateGeometryFromJson(json.dumps(wgs84_extent))\n",
292 | "\n",
293 | " feat = ogr.Feature(feature_defn)\n",
294 | "\n",
295 | " feat.SetGeometry(poly)\n",
296 | " feat.SetField('osm_id', row.name)\n",
297 | " feat.SetField('highway', row['highway'])\n",
298 | " feat.SetField('surface', row['surface'])\n",
299 | " layer.CreateFeature(feat)\n",
300 | " poly = None\n",
301 | " feat = None\n",
302 | "\n",
303 | "layer = None # type: ignore\n",
304 | "ds = None # type: ignore"
305 | ]
306 | }
307 | ],
308 | "metadata": {
309 | "kernelspec": {
310 | "display_name": "Python 3.10.6 ('rsd_env')",
311 | "language": "python",
312 | "name": "python3"
313 | },
314 | "language_info": {
315 | "codemirror_mode": {
316 | "name": "ipython",
317 | "version": 3
318 | },
319 | "file_extension": ".py",
320 | "mimetype": "text/x-python",
321 | "name": "python",
322 | "nbconvert_exporter": "python",
323 | "pygments_lexer": "ipython3",
324 | "version": "3.10.6"
325 | },
326 | "orig_nbformat": 4,
327 | "vscode": {
328 | "interpreter": {
329 | "hash": "3c3e3338979283bc5980811c64bc074b42c7e88e72bf9f1fd3a7107f9ec2dee1"
330 | }
331 | }
332 | },
333 | "nbformat": 4,
334 | "nbformat_minor": 2
335 | }
336 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "road_surface_classifer"
3 | version = "0.9.0"
4 | description = "Automated surface tagging of roads in OpenStreetMap"
5 | readme = "README.md"
6 | requires-python = ">=3.7"
7 | license = { file = "LICENSE" }
8 | keywords = ["openstreetmap", "machine learning", "classifier"]
9 | authors = [{ name = "Jon Dalrymple", email = "j_dalrym2@hotmail.com" }]
10 | maintainers = [{ name = "Jon Dalrymple", email = "j_dalrym2@hotmail.com" }]
11 |
12 | # Classifiers help users find your project by categorizing it.
13 | #
14 | # For a list of valid classifiers, see https://pypi.org/classifiers/
15 | classifiers = [
16 | "Development Status :: 3 - Alpha",
17 | "Environment :: GPU",
18 | "Environment :: GPU :: NVIDIA CUDA",
19 | "Environment :: GPU :: NVIDIA CUDA :: 11",
20 | "Framework :: Jupyter",
21 | "Intended Audience :: Developers",
22 | "Intended Audience :: Science/Research",
23 | "License :: OSI Approved :: MIT License",
24 | "Natural Language :: English",
25 | "Operating System :: POSIX :: Linux",
26 | "Programming Language :: Python :: 3",
27 | "Programming Language :: Python :: 3.7",
28 | "Programming Language :: Python :: 3.8",
29 | "Programming Language :: Python :: 3.9",
30 | "Programming Language :: Python :: 3.10",
31 | "Programming Language :: Python :: 3.11",
32 | "Programming Language :: Python :: 3 :: Only",
33 | "Topic :: Scientific/Engineering",
34 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
35 | "Topic :: Scientific/Engineering :: GIS",
36 | "Topic :: Scientific/Engineering :: Image Processing",
37 | "Topic :: Scientific/Engineering :: Image Recognition",
38 | ]
39 |
40 | # This field lists other packages that your project depends on to run.
41 | # Any package you put here will be installed by pip when your project is
42 | # installed, so they must be valid existing projects.
43 | #
44 | # For an analysis of this field vs pip's requirements files see:
45 | # https://packaging.python.org/discussions/install-requires-vs-requirements/
46 | # TODO: double-check these + split for rsc submodules
47 | # TODO: For now, use requirements.txt / requirements.in
48 | dependencies = []
49 |
50 | # List additional groups of dependencies here (e.g. development
51 | # dependencies). Users will be able to install these using the "extras"
52 | # syntax, for example:
53 | #
54 | # $ pip install sampleproject[dev]
55 | #
56 | # Similar to `dependencies` above, these must be valid existing
57 | # projects.
58 | [project.optional-dependencies]
59 | dev = ["check-manifest"]
60 | test = ["coverage"]
61 |
62 | [project.urls]
63 | "Homepage" = "https://github.com/jdalrym2/road_surface_classifier"
64 |
65 | [tool.setuptools]
66 | packages = ["rsc"]
67 |
68 | [build-system]
69 | # These are the assumed default build requirements from pip:
70 | # https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support
71 | requires = ["setuptools>=43.0.0", "wheel"]
72 | build-backend = "setuptools.build_meta"
73 |
--------------------------------------------------------------------------------
/requirements.in:
--------------------------------------------------------------------------------
1 | boto3
2 | numpy
3 | GDAL>3,<=3.8.3
4 | kornia
5 | matplotlib
6 | pandas
7 | Pillow
8 | pytorch_lightning
9 | optuna>=3.1.0
10 | mlflow
11 | requests
12 | scikit_learn
13 | scikit_image
14 | scipy
15 | torch
16 | torchvision
17 | tqdm
18 | ipykernel
19 | ipympl
20 | ipywidgets
21 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | #
2 | # This file is autogenerated by pip-compile with Python 3.10
3 | # by the following command:
4 | #
5 | # pip-compile --output-file=requirements.txt requirements.in
6 | #
7 | aiohttp==3.8.4
8 | # via fsspec
9 | aiosignal==1.3.1
10 | # via aiohttp
11 | alembic==1.10.4
12 | # via
13 | # mlflow
14 | # optuna
15 | asttokens==2.2.1
16 | # via stack-data
17 | async-timeout==4.0.2
18 | # via aiohttp
19 | attrs==23.1.0
20 | # via aiohttp
21 | backcall==0.2.0
22 | # via ipython
23 | blinker==1.6.2
24 | # via flask
25 | boto3==1.26.129
26 | # via -r requirements.in
27 | botocore==1.29.129
28 | # via
29 | # boto3
30 | # s3transfer
31 | certifi==2023.5.7
32 | # via requests
33 | charset-normalizer==3.1.0
34 | # via
35 | # aiohttp
36 | # requests
37 | click==8.1.3
38 | # via
39 | # databricks-cli
40 | # flask
41 | # mlflow
42 | cloudpickle==2.2.1
43 | # via mlflow
44 | cmaes==0.9.1
45 | # via optuna
46 | cmake==3.26.3
47 | # via triton
48 | colorlog==6.7.0
49 | # via optuna
50 | comm==0.1.3
51 | # via ipykernel
52 | contourpy==1.0.7
53 | # via matplotlib
54 | cycler==0.11.0
55 | # via matplotlib
56 | databricks-cli==0.17.6
57 | # via mlflow
58 | debugpy==1.6.7
59 | # via ipykernel
60 | decorator==5.1.1
61 | # via ipython
62 | docker==6.1.0
63 | # via mlflow
64 | entrypoints==0.4
65 | # via mlflow
66 | executing==1.2.0
67 | # via stack-data
68 | filelock==3.12.0
69 | # via
70 | # torch
71 | # triton
72 | flask==2.3.2
73 | # via mlflow
74 | fonttools==4.39.3
75 | # via matplotlib
76 | frozenlist==1.3.3
77 | # via
78 | # aiohttp
79 | # aiosignal
80 | fsspec[http]==2023.5.0
81 | # via pytorch-lightning
82 | gdal==3.8.3
83 | # via -r requirements.in
84 | gitdb==4.0.10
85 | # via gitpython
86 | gitpython==3.1.31
87 | # via mlflow
88 | greenlet==2.0.2
89 | # via sqlalchemy
90 | gunicorn==20.1.0
91 | # via mlflow
92 | idna==3.4
93 | # via
94 | # requests
95 | # yarl
96 | imageio==2.30.0
97 | # via scikit-image
98 | importlib-metadata==6.6.0
99 | # via mlflow
100 | ipykernel==6.22.0
101 | # via
102 | # -r requirements.in
103 | # ipywidgets
104 | ipympl==0.9.3
105 | # via -r requirements.in
106 | ipython==8.13.2
107 | # via
108 | # ipykernel
109 | # ipympl
110 | # ipywidgets
111 | ipython-genutils==0.2.0
112 | # via ipympl
113 | ipywidgets==8.0.6
114 | # via
115 | # -r requirements.in
116 | # ipympl
117 | itsdangerous==2.1.2
118 | # via flask
119 | jedi==0.18.2
120 | # via ipython
121 | jinja2==3.1.2
122 | # via
123 | # flask
124 | # mlflow
125 | # torch
126 | jmespath==1.0.1
127 | # via
128 | # boto3
129 | # botocore
130 | joblib==1.2.0
131 | # via scikit-learn
132 | jupyter-client==8.2.0
133 | # via ipykernel
134 | jupyter-core==5.3.0
135 | # via
136 | # ipykernel
137 | # jupyter-client
138 | jupyterlab-widgets==3.0.7
139 | # via ipywidgets
140 | kiwisolver==1.4.4
141 | # via matplotlib
142 | kornia==0.6.12
143 | # via -r requirements.in
144 | lazy-loader==0.2
145 | # via scikit-image
146 | lightning-utilities==0.8.0
147 | # via pytorch-lightning
148 | lit==16.0.3
149 | # via triton
150 | mako==1.2.4
151 | # via alembic
152 | markdown==3.4.3
153 | # via mlflow
154 | markupsafe==2.1.2
155 | # via
156 | # jinja2
157 | # mako
158 | # werkzeug
159 | matplotlib==3.7.1
160 | # via
161 | # -r requirements.in
162 | # ipympl
163 | # mlflow
164 | matplotlib-inline==0.1.6
165 | # via
166 | # ipykernel
167 | # ipython
168 | mlflow==2.3.1
169 | # via -r requirements.in
170 | mpmath==1.3.0
171 | # via sympy
172 | multidict==6.0.4
173 | # via
174 | # aiohttp
175 | # yarl
176 | nest-asyncio==1.5.6
177 | # via ipykernel
178 | networkx==3.1
179 | # via
180 | # scikit-image
181 | # torch
182 | numpy==1.24.3
183 | # via
184 | # -r requirements.in
185 | # cmaes
186 | # contourpy
187 | # imageio
188 | # ipympl
189 | # matplotlib
190 | # mlflow
191 | # optuna
192 | # pandas
193 | # pyarrow
194 | # pytorch-lightning
195 | # pywavelets
196 | # scikit-image
197 | # scikit-learn
198 | # scipy
199 | # tifffile
200 | # torchmetrics
201 | # torchvision
202 | nvidia-cublas-cu11==11.10.3.66
203 | # via
204 | # nvidia-cudnn-cu11
205 | # nvidia-cusolver-cu11
206 | # torch
207 | nvidia-cuda-cupti-cu11==11.7.101
208 | # via torch
209 | nvidia-cuda-nvrtc-cu11==11.7.99
210 | # via torch
211 | nvidia-cuda-runtime-cu11==11.7.99
212 | # via torch
213 | nvidia-cudnn-cu11==8.5.0.96
214 | # via torch
215 | nvidia-cufft-cu11==10.9.0.58
216 | # via torch
217 | nvidia-curand-cu11==10.2.10.91
218 | # via torch
219 | nvidia-cusolver-cu11==11.4.0.1
220 | # via torch
221 | nvidia-cusparse-cu11==11.7.4.91
222 | # via torch
223 | nvidia-nccl-cu11==2.14.3
224 | # via torch
225 | nvidia-nvtx-cu11==11.7.91
226 | # via torch
227 | oauthlib==3.2.2
228 | # via databricks-cli
229 | optuna==3.1.1
230 | # via -r requirements.in
231 | packaging==23.1
232 | # via
233 | # docker
234 | # ipykernel
235 | # kornia
236 | # lightning-utilities
237 | # matplotlib
238 | # mlflow
239 | # optuna
240 | # pytorch-lightning
241 | # scikit-image
242 | # torchmetrics
243 | pandas==2.0.1
244 | # via
245 | # -r requirements.in
246 | # mlflow
247 | parso==0.8.3
248 | # via jedi
249 | pexpect==4.8.0
250 | # via ipython
251 | pickleshare==0.7.5
252 | # via ipython
253 | pillow==9.5.0
254 | # via
255 | # -r requirements.in
256 | # imageio
257 | # ipympl
258 | # matplotlib
259 | # scikit-image
260 | # torchvision
261 | platformdirs==3.5.0
262 | # via jupyter-core
263 | prompt-toolkit==3.0.38
264 | # via ipython
265 | protobuf==4.22.4
266 | # via mlflow
267 | psutil==5.9.5
268 | # via ipykernel
269 | ptyprocess==0.7.0
270 | # via pexpect
271 | pure-eval==0.2.2
272 | # via stack-data
273 | pyarrow==11.0.0
274 | # via mlflow
275 | pygments==2.15.1
276 | # via ipython
277 | pyjwt==2.6.0
278 | # via databricks-cli
279 | pyparsing==3.0.9
280 | # via matplotlib
281 | python-dateutil==2.8.2
282 | # via
283 | # botocore
284 | # jupyter-client
285 | # matplotlib
286 | # pandas
287 | pytorch-lightning==2.0.2
288 | # via -r requirements.in
289 | pytz==2023.3
290 | # via
291 | # mlflow
292 | # pandas
293 | pywavelets==1.4.1
294 | # via scikit-image
295 | pyyaml==6.0
296 | # via
297 | # mlflow
298 | # optuna
299 | # pytorch-lightning
300 | pyzmq==25.0.2
301 | # via
302 | # ipykernel
303 | # jupyter-client
304 | querystring-parser==1.2.4
305 | # via mlflow
306 | requests==2.30.0
307 | # via
308 | # -r requirements.in
309 | # databricks-cli
310 | # docker
311 | # fsspec
312 | # mlflow
313 | # torchvision
314 | s3transfer==0.6.1
315 | # via boto3
316 | scikit-image==0.20.0
317 | # via -r requirements.in
318 | scikit-learn==1.2.2
319 | # via
320 | # -r requirements.in
321 | # mlflow
322 | scipy==1.10.1
323 | # via
324 | # -r requirements.in
325 | # mlflow
326 | # scikit-image
327 | # scikit-learn
328 | six==1.16.0
329 | # via
330 | # asttokens
331 | # databricks-cli
332 | # python-dateutil
333 | # querystring-parser
334 | smmap==5.0.0
335 | # via gitdb
336 | sqlalchemy==2.0.12
337 | # via
338 | # alembic
339 | # mlflow
340 | # optuna
341 | sqlparse==0.4.4
342 | # via mlflow
343 | stack-data==0.6.2
344 | # via ipython
345 | sympy==1.11.1
346 | # via torch
347 | tabulate==0.9.0
348 | # via databricks-cli
349 | threadpoolctl==3.1.0
350 | # via scikit-learn
351 | tifffile==2023.4.12
352 | # via scikit-image
353 | torch==2.0.0
354 | # via
355 | # -r requirements.in
356 | # kornia
357 | # pytorch-lightning
358 | # torchmetrics
359 | # torchvision
360 | # triton
361 | torchmetrics==0.11.4
362 | # via pytorch-lightning
363 | torchvision==0.15.1
364 | # via -r requirements.in
365 | tornado==6.3.1
366 | # via
367 | # ipykernel
368 | # jupyter-client
369 | tqdm==4.65.0
370 | # via
371 | # -r requirements.in
372 | # optuna
373 | # pytorch-lightning
374 | traitlets==5.9.0
375 | # via
376 | # comm
377 | # ipykernel
378 | # ipympl
379 | # ipython
380 | # ipywidgets
381 | # jupyter-client
382 | # jupyter-core
383 | # matplotlib-inline
384 | triton==2.0.0
385 | # via torch
386 | typing-extensions==4.5.0
387 | # via
388 | # alembic
389 | # lightning-utilities
390 | # pytorch-lightning
391 | # sqlalchemy
392 | # torch
393 | tzdata==2023.3
394 | # via pandas
395 | urllib3==1.26.15
396 | # via
397 | # botocore
398 | # docker
399 | # requests
400 | wcwidth==0.2.6
401 | # via prompt-toolkit
402 | websocket-client==1.5.1
403 | # via docker
404 | werkzeug==2.3.3
405 | # via flask
406 | wheel==0.40.0
407 | # via
408 | # nvidia-cublas-cu11
409 | # nvidia-cuda-cupti-cu11
410 | # nvidia-cuda-runtime-cu11
411 | # nvidia-curand-cu11
412 | # nvidia-cusparse-cu11
413 | # nvidia-nvtx-cu11
414 | widgetsnbextension==4.0.7
415 | # via ipywidgets
416 | yarl==1.9.2
417 | # via aiohttp
418 | zipp==3.15.0
419 | # via importlib-metadata
420 |
421 | # The following packages are considered to be unsafe in a requirements file:
422 | # setuptools
423 |
--------------------------------------------------------------------------------
/rsc/artifacts/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """ RSC submodule to simplify artifact generation during inference over a dataset """
4 | import pathlib
5 |
6 | import torch
7 |
8 | # Get PyTorch device to use
9 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
10 |
11 |
12 | def find_best_model(results_dir: pathlib.Path) -> pathlib.Path:
13 | """
14 | Find the best model checkpoint (by validation loss) in a results directory.
15 |
16 | Args:
17 | results_dir (pathlib.Path): Results directory
18 |
19 | Returns:
20 | pathlib.Path: Path to checkpoint with minimum validation loss.
21 | """
22 | # Parse checkpoints from results dir
23 | paths = list(results_dir.glob('*.ckpt'))
24 | path_stems = [e.stem for e in paths]
25 | path_metrics = [e.split('-') for e in path_stems]
26 | # Extract validation losses from filenames
27 | val_losses = []
28 | for idx, metrics in enumerate(path_metrics):
29 | for metric in metrics:
30 | if metric.startswith('val_loss'):
31 | val_loss = float(metric.split('=')[-1])
32 | val_losses.append((idx, val_loss))
33 |
34 | # Find path that has minimum validation loss
35 | # By reversing val losses, if there is a tie
36 | # we fetch the last one
37 | min_idx, _ = min(reversed(val_losses), key=lambda v: v[1])
38 |
39 | return paths[min_idx]
40 |
41 |
42 | from .base import ArtifactGenerator, ArtifactHandler
--------------------------------------------------------------------------------
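A minimal usage sketch of `find_best_model` (the directory and checkpoint names below are hypothetical, following the `...-val_loss=<x>.ckpt` naming the parser expects):

```python
import pathlib

from rsc.artifacts import find_best_model

# Hypothetical results directory holding Lightning checkpoints such as
# 'epoch=3-val_loss=0.1021.ckpt' and 'epoch=12-val_loss=0.0871.ckpt'
results_dir = pathlib.Path('results/run_0')

best_ckpt = find_best_model(results_dir)
print(best_ckpt.name)  # -> 'epoch=12-val_loss=0.0871.ckpt' (lowest val_loss)
```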
/rsc/artifacts/__main__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """ Generate artifacts for a model """
4 |
5 | import pathlib
6 | import argparse
7 |
8 | import torch
9 | from torch.utils.data import DataLoader
10 |
11 | from ..train.plmcnn import PLMaskCNN
12 | from ..train.dataset import RoadSurfaceDataset
13 | from ..train.preprocess import PreProcess
14 |
15 | from . import find_best_model
16 | from .base import ArtifactGenerator
17 | from .confusion_matrix_handler import ConfusionMatrixHandler
18 | from .accuracy_obsc_handler import AccuracyObscHandler
19 | from .obsc_compare_handler import ObscCompareHandler
20 | from .samples_handler import SamplesHandler
21 | from .auc_handler import AUCHandler
22 |
23 |
24 | def parse_args():
25 | parser = argparse.ArgumentParser(description=__doc__)
26 | parser.add_argument('-i',
27 | '--input-ckpt',
28 | type=str,
29 | required=True,
30 | help='Path to input checkpoint file. '
31 | 'If given a directory, attempts to '
32 | 'find the best model in the directory.')
33 | parser.add_argument('-d',
34 | '--dataset-csv',
35 | type=str,
36 | required=True,
37 | help='Path to dataset CSV')
38 | parser.add_argument('-o',
39 | '--output-path',
40 | type=str,
41 | required=True,
42 | help='Path to output directory for artifacts. '
43 | 'If it does not exist, one will be created.')
44 | parser.add_argument('-m',
45 | '--model-path',
46 | type=str,
47 | required=False,
48 | help='Path to model PTH file, '
49 | 'if PLMaskCNN is not the model')
50 | parser.add_argument('-c',
51 | '--count',
52 | type=int,
53 | required=False,
54 | default=-1,
55 | help='Maximum number of images to run '
56 | 'inference on. '
57 | '-1: run inference on all images')
58 | parser.add_argument('--batch-size',
59 | type=int,
60 | required=False,
61 | default=64,
62 | help='Set dataloader batch size.')
63 | parser.add_argument('--num-workers',
64 | type=int,
65 | required=False,
66 | default=16,
67 | help='Set dataloader worker count.')
68 | parser.add_argument('--no-shuffle',
69 | action='store_true',
70 | help='If specified, do not shuffle '
71 | 'the dataset when loading.')
72 | parser.add_argument('--raise-on-error',
73 | action='store_true',
74 | help='If specified, stop processing when '
75 | 'encountering an error.')
76 |
77 | return parser.parse_args()
78 |
79 |
80 | if __name__ == '__main__':
81 |
82 | pargs = parse_args()
83 |
84 | # Parse relevant inputs up front
85 | csv_path = pathlib.Path(pargs.dataset_csv)
86 | assert csv_path.is_file()
87 |
88 | # Determine model checkpoint
89 | ckpt_path = pathlib.Path(pargs.input_ckpt)
90 | if ckpt_path.is_file():
91 | # Load the checkpoint as-is
92 | pass
93 | elif ckpt_path.is_dir():
94 | # Find the best model from the directory
95 | ckpt_path = find_best_model(ckpt_path)
96 | print('Found best model: %s' % str(ckpt_path))
97 | else:
98 | raise ValueError(
99 | 'Could not find checkpoint path: %s. Must be a file or directory.'
100 | % str(ckpt_path))
101 |
102 | # Load model
103 | if pargs.model_path is None:
104 | model = PLMaskCNN.load_from_checkpoint(ckpt_path)
105 | else:
106 | model_path = pathlib.Path(pargs.model_path)
107 | assert model_path.is_file()
108 | model: PLMaskCNN = torch.load(model_path)
109 | model.load_state_dict(torch.load(ckpt_path)['state_dict'])  # load_from_checkpoint is a classmethod returning a new instance; load weights in-place instead
110 |
111 | # Put model in eval mode
112 | model.eval()
113 |
114 | # Construct dataset
115 | val_ds = RoadSurfaceDataset(csv_path,
116 | transform=PreProcess(),
117 | n_channels=model.nc,
118 | limit=pargs.count)
119 |
120 | # Construct dataloader
121 | val_dl = DataLoader(val_ds,
122 | num_workers=pargs.num_workers,
123 | batch_size=pargs.batch_size,
124 | shuffle=not pargs.no_shuffle)
125 |
126 | # Parse save directory, create if not there
127 | save_dir = pathlib.Path(pargs.output_path)
128 | if not save_dir.is_dir():
129 | save_dir.mkdir(parents=False)
130 |
131 | # Generate artifacts from model
132 | generator = ArtifactGenerator(save_dir, model, val_dl)
133 | generator.add_handler(ConfusionMatrixHandler(simple=True))
134 | generator.add_handler(ConfusionMatrixHandler(simple=False))
135 | generator.add_handler(AccuracyObscHandler())
136 | generator.add_handler(ObscCompareHandler())
137 | generator.add_handler(SamplesHandler())
138 | # generator.add_handler(AUCHandler()) # disabled for multiclass
139 | generator.run(raise_on_error=pargs.raise_on_error)
140 |
--------------------------------------------------------------------------------
/rsc/artifacts/accuracy_obsc_handler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import pathlib
5 | from typing import Any, Sequence
6 |
7 | import numpy as np
8 | import matplotlib
9 | import matplotlib.pyplot as plt
10 |
11 | from torch.utils.data import DataLoader
12 |
13 | from .base import ArtifactHandler
14 |
15 | # Use non-GUI backend
16 | matplotlib.use('Agg')
17 |
18 |
19 | class AccuracyObscHandler(ArtifactHandler):
20 | """ Handler class to plot accuracy vs obscuration """
21 |
22 | def __init__(self):
23 | super().__init__()
24 |
25 | self.labels: list[str] | None = None
26 | self.y_true_l = []
27 | self.acc_l = []
28 | self.y_true_obsc_l = []
29 |
30 | def start(self, model: Any, dataloader: DataLoader) -> None:
31 |
32 | # Try to get labels from model
33 | self.labels = model.__dict__.get('labels')
34 |
35 | def on_iter(self, dl_iter: Sequence, model_out: Sequence) -> None:
36 |
37 | _, features = dl_iter
38 | features = features.cpu().detach().numpy()
39 |
40 | _, pred = model_out
41 | pred = pred.cpu().detach().numpy()
42 |
43 | # Get predicted label as argmax
44 | this_y_pred = np.argmax(pred[..., :-1], axis=1)
45 | this_y_true = np.argmax(features[..., :-1], axis=1)
46 | self.y_true_l.append(this_y_true)
47 | self.acc_l.append((this_y_pred == this_y_true).astype(int))
48 | self.y_true_obsc_l.append(features[..., -1])
49 |
50 | @staticmethod
51 | def compute_scores(
52 | n_labels: int,
53 | acc: np.ndarray,
54 | y_true: np.ndarray) -> tuple[float, np.ndarray, np.ndarray]:
55 | """
56 | Compute accuracy scores and relevant data for plotting.
57 |
58 | Args:
59 | n_labels (int): Number of labels in the dataset
60 | acc (np.ndarray): Array of 0/1 values where 1 -> correct result
61 | y_true (np.ndarray): The true label array. Must be the same size as `acc`
62 |
63 | Returns:
64 | tuple[float, np.ndarray, np.ndarray]: Output accuracy scores and relevant data. All in range [0, 1].
65 | - `acc_all`: Accuracy over all data in `y_true`
66 | - `acc[label]`: Accuracy for [label] roads in `y_true`
67 | - `pc[label]`: Proportion of [label] in `y_true`
68 | """
69 | assert len(acc) == len(y_true)
70 |
71 | # Count total + num correct
72 | num_total = len(acc)
73 | num_correct = np.count_nonzero(acc)
74 |
75 | # Compute accuracy scores
76 | try:
77 | acc_all = num_correct / num_total
78 | except ZeroDivisionError:
79 | acc_all = np.nan
80 |
81 | # Pre-allocate arrays for accuracy and proportion
82 | acc_l = np.zeros((n_labels,))
83 | pc = np.zeros((n_labels,))
84 |
85 | for label in range(n_labels):
86 | # Find indices relevant to the label
87 | idx_label = np.where(y_true == label)[0]
88 |
89 | # Compute accuracy of this label
90 | try:
91 | acc_l[label] = np.count_nonzero(acc[idx_label]) / len(idx_label)
92 | except ZeroDivisionError:
93 | acc_l[label] = np.nan
94 |
95 | # Compute proportion of this label
96 | try:
97 | pc[label] = len(idx_label) / len(acc)
98 | except ZeroDivisionError:
99 | pc[label] = np.nan
100 |
101 | return acc_all, acc_l, pc
102 |
103 | def save(self, output_dir) -> tuple[pathlib.Path, pathlib.Path]:
104 |
105 | assert self.labels is not None
106 |
107 | # Aggregate and organize
108 | y_true = np.concatenate(self.y_true_l)
109 | acc = np.concatenate(self.acc_l)
110 | y_true_obsc = np.concatenate(self.y_true_obsc_l)
111 |
112 | # Bins for obscuration: 0 -> 1
113 | bins = np.linspace(0, 1, 20 + 1)
114 |
115 | # Compute accuracy scores in bins
116 | acc_plt_b, counts_b = [], [] # binned
117 | acc_plt_c, counts_c = [], [] # cumulative
118 | for obsc_min, obsc_max in zip(bins[:-1], bins[1:]):
119 | # Get indices of interest (indices in bin + cumulative)
120 | idx_b = np.where((obsc_min < y_true_obsc)
121 | & (y_true_obsc <= obsc_max))[0]
122 | idx_c = np.where(y_true_obsc <= obsc_max)[0]
123 |
124 | # Compute accuracy scores and add to plot data:
125 | # Binned
126 | acc_all, acc_l, pc = self.compute_scores(
127 | len(self.labels), acc[idx_b], y_true[idx_b])
128 | acc_plt_b.append(
129 | [e * 100 for e in (acc_all, *acc_l)])
130 | counts_b.append(pc * 100)
131 |
132 | # Cumulative
133 | acc_all, acc_l, pc = self.compute_scores(
134 | len(self.labels), acc[idx_c], y_true[idx_c])
135 | acc_plt_c.append(
136 | [e * 100 for e in (acc_all, *acc_l)])
137 | counts_c.append(pc * 100)
138 |
139 | # Create the plots!
140 | # Binned
141 | fig, ax = plt.subplots(2, 1, sharex=True, figsize=(12, 9))
142 | ax[0].plot(bins[1:] * 100, [e[0] for e in acc_plt_b],
143 | 'k-*',
144 | linewidth=2)
145 | for i in range(1, len(acc_plt_b[0])):
146 | ax[0].plot(bins[1:] * 100, [e[i] for e in acc_plt_b], '-*')
147 | ax[0].set_title('Prediction Accuracy vs. Obscuration (Binned)')
148 | ax[0].set_ylim(None, 100) # type: ignore
149 | ax[0].set_ylabel(r'Accuracy [%]')
150 | ax[0].grid()
151 | ax[1].plot(bins[1:] * 100, counts_b, '-*')
152 | ax[1].set_title('Label Proportion vs. Obscuration (Binned)')
153 | ax[1].set_ylim(0, 100)
154 | ax[1].grid()
155 | ax[1].set_ylabel(r'Proportion [%]')
156 | ax[1].set_xlabel(r'Obscuration [%]')
157 | fig.legend(('All', *[e.title() for e in self.labels]))
158 | binned_path = output_dir / 'acc_obsc_plot_binned.png'
159 | fig.savefig(str(binned_path))
160 | plt.close(fig)
161 |
162 | # Cumulative
163 | fig, ax = plt.subplots(2, 1, sharex=True, figsize=(12, 9))
164 | ax[0].plot(bins[1:] * 100, [e[0] for e in acc_plt_c],
165 | 'k-*',
166 | linewidth=2)
167 | for i in range(1, len(acc_plt_c[0])):
168 | ax[0].plot(bins[1:] * 100, [e[i] for e in acc_plt_c], '-*')
169 | ax[0].set_title('Prediction Accuracy vs. Obscuration (Cumulative)')
170 | ax[0].set_ylim(None, 100) # type: ignore
171 | ax[0].set_ylabel(r'Accuracy [%]')
172 | ax[0].grid()
173 | ax[1].plot(bins[1:] * 100, counts_c, '-*')
174 | ax[1].set_title('Label Proportion vs. Obscuration (Cumulative)')
175 | ax[1].set_ylim(0, 100)
176 | ax[1].grid()
177 | ax[1].set_ylabel(r'Proportion [%]')
178 | ax[1].set_xlabel(r'Obscuration [%]')
179 | fig.legend(('All', *[e.title() for e in self.labels]))
180 | cumul_path = output_dir / 'acc_obsc_plot_cumul.png'
181 | fig.savefig(str(cumul_path))
182 | plt.close(fig)
183 |
184 | return binned_path, cumul_path
185 |
--------------------------------------------------------------------------------
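Because `compute_scores` is a pure static method, it can be sanity-checked on toy arrays; a quick sketch:

```python
import numpy as np

from rsc.artifacts.accuracy_obsc_handler import AccuracyObscHandler

acc = np.array([1, 1, 0, 1])     # 1 -> prediction was correct
y_true = np.array([0, 0, 1, 1])  # true label per sample

acc_all, acc_l, pc = AccuracyObscHandler.compute_scores(2, acc, y_true)
print(acc_all)  # 0.75       -> overall accuracy
print(acc_l)    # [1.  0.5]  -> per-label accuracy
print(pc)       # [0.5 0.5]  -> proportion of each label
```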
/rsc/artifacts/auc_handler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import pathlib
5 | from typing import Any, Sequence
6 |
7 | import numpy as np
8 | import matplotlib
9 | import matplotlib.pyplot as plt
10 | import sklearn.metrics
11 |
12 | import torch
13 | from torch.utils.data import DataLoader
14 |
15 | from .base import ArtifactHandler
16 |
17 | # Use non-GUI backend
18 | matplotlib.use('Agg')
19 |
20 |
21 | class AUCHandler(ArtifactHandler):
22 | """ Handler class to plot accuracy vs obscuration """
23 |
24 | def __init__(self):
25 | super().__init__()
26 |
27 | self.y_true_l = [] # true labels
28 | self.y_pred_c = [] # confidences
29 | self.roc_auc = None
30 |
31 | def start(self, model: Any, dataloader: DataLoader) -> None:
32 | pass
33 |
34 | def on_iter(self, dl_iter: Sequence, model_out: Sequence) -> None:
35 |
36 | _, features = dl_iter
37 | features = features.cpu().detach().numpy()
38 |
39 | _, pred = model_out
40 | pred = torch.softmax(pred, dim=1)
41 | pred = pred.cpu().detach().numpy()
42 |
43 | # Get true label as argmax; keep per-class confidences for ROC
44 | this_y_true = np.argmax(features[..., :-1], axis=1)
45 | self.y_true_l.append(this_y_true)
46 | self.y_pred_c.append(pred[..., :-1])
47 |
48 | def save(self, output_dir) -> pathlib.Path:
49 |
50 | y_true_l = np.concatenate(self.y_true_l)
51 | # probability of the class with the *greater* / *positive* label
52 | y_pred_c = np.concatenate(self.y_pred_c)[:, 1]
53 |
54 | fig, ax = plt.subplots()
55 | ax.grid()
56 | disp = sklearn.metrics.RocCurveDisplay.from_predictions(y_true_l,
57 | y_pred_c,
58 | ax=ax)
59 | self.roc_auc = disp.roc_auc # type: ignore
60 | ax.set_title('ROC Curve')
61 |
62 | plt_path = output_dir / 'roc_curve.png'
63 | fig.savefig(str(plt_path))
64 | plt.close(fig)
65 |
66 | return plt_path
67 |
--------------------------------------------------------------------------------
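The handler passes the positive-class softmax scores to `RocCurveDisplay`; the AUC it records is the same value `sklearn.metrics.roc_auc_score` reports. A toy binary example:

```python
import numpy as np
import sklearn.metrics

y_true = np.array([0, 0, 1, 1])            # binary ground truth
y_score = np.array([0.1, 0.4, 0.35, 0.8])  # positive-class confidences

print(sklearn.metrics.roc_auc_score(y_true, y_score))  # 0.75
```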
/rsc/artifacts/base.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import pathlib
5 | import traceback
6 | from abc import ABC, abstractmethod
7 | from typing import Any, Sequence
8 |
9 | from tqdm import tqdm
10 |
11 | from torch.utils.data import DataLoader
12 |
13 | from . import device
14 |
15 |
16 | class ArtifactHandler(ABC):
17 | """ Abstract class used to generate an artifact from an inference result """
18 |
19 | def __init__(self):
20 | pass
21 |
22 | @abstractmethod
23 | def start(self, model: Any, dataloader: DataLoader) -> None:
24 | """
25 | Handler startup code. Do any initialization here.
26 |
27 | Args:
28 | model (Any): Model that will be used
29 | dataloader (DataLoader): Dataloader that will be used
30 | """
31 | pass
32 |
33 | @abstractmethod
34 | def on_iter(self, dl_iter: Sequence, model_out: Sequence) -> None:
35 | """
36 | Called each time a batch is inferenced. Collect data for
37 | the artifact here.
38 |
39 | Args:
40 | dl_iter (Sequence): Output iterable from the dataloader.
41 | model_out (Sequence): Output iterable from `model.predict`
42 | """
43 | pass
44 |
45 | @abstractmethod
46 | def save(self, output_dir: pathlib.Path) -> pathlib.Path:
47 | """
48 | Given the collected state from on_iter, generate and
49 | save the relevant artifact.
50 |
51 | Args:
52 | output_dir (pathlib.Path): Output directory for the artifact
53 |
54 | Returns:
55 | pathlib.Path: Path to saved artifact
56 | """
57 | pass
58 |
59 |
60 | class ArtifactGenerator:
61 | """ Top-level class to handle multiple artifact generation """
62 |
63 | def __init__(self, output_dir: str | pathlib.Path, model: Any,
64 | dataloader: DataLoader):
65 | """
66 | Create an artifact generator
67 |
68 | Args:
69 | output_dir (str | pathlib.Path): Output directory to dump artifacts
70 | model (Any): RSC model for inference
71 | dataloader (DataLoader): Dataloader to run inference over
72 | """
73 | self.handlers: list[ArtifactHandler] = []
74 | self.is_active: list[bool] = []
75 | self.output_dir = pathlib.Path(output_dir)
76 | self.output_dir.mkdir(parents=False, exist_ok=True)
77 | self.model = model
78 | self.model.to(device)
79 | self.dataloader = dataloader
80 |
81 | def add_handler(self, handler: ArtifactHandler):
82 | """
83 | Add a handler to the generator. All handlers should be added
84 | before `run` is called.
85 |
86 | Args:
87 | handler (ArtifactHandler): Handler to add to the generator.
88 |
89 | Raises:
90 | ValueError: If the handler is not a subclass of `ArtifactHandler`
91 | """
92 | if not issubclass(handler.__class__, ArtifactHandler):
93 | raise ValueError(
94 | f'Input handler must be a subclass of {ArtifactHandler.__name__}'
95 | )
96 | self.handlers.append(handler)
97 |
98 | def _exec(self, idx, func, raise_on_error: bool) -> Any:
99 | """ Wrapper call to work with the handler. Provides
100 | a failsafe in the case the handler throws an exception
101 | such that inference can continue """
102 | # Skip an inactive handler
103 | if not self.is_active[idx]:
104 | return
105 | try:
106 | # Attempt call
107 | return func()
108 | except Exception:
109 | # If we failed, raise if we want
110 | if raise_on_error:
111 | raise
112 | # Otherwise, print the exception and deactivate
113 | # the handler
114 | traceback.print_exc()
115 | print(f'Deactivating handler {idx:d}...')
116 | self.is_active[idx] = False
117 |
118 | def run(self, raise_on_error: bool = False):
119 | """
120 | Run inference and pass data to each handler.
121 |
122 | Args:
123 | raise_on_error (bool, optional): If any handler throws an error,
124 | whether or not to raise. Defaults to False.
125 | """
126 |
127 | # Ready?
128 | self.is_active = [True for _ in self.handlers]
129 |
130 | # Set...
131 | for idx, h in enumerate(self.handlers):
132 | self._exec(idx, lambda: h.start(self.model, self.dataloader),
133 | raise_on_error)
134 |
135 | # Go!
136 | for dl_iter in tqdm(iter(self.dataloader)):
137 | x, _ = dl_iter
138 |
139 | # Extract just the image + location mask for inference
140 | # i.e. strip the last (probmask) channel
141 | x = x[:, :-1, :, :].to(device)
142 |
143 | # Get prediction from model
144 | model_out = self.model(x)
145 |
146 | # Run the handlers
147 | for idx, h in enumerate(self.handlers):
148 | self._exec(idx, lambda: h.on_iter(dl_iter, model_out),
149 | raise_on_error)
150 |
151 | # Finalize
152 | for idx, h in enumerate(self.handlers):
153 | self._exec(idx, lambda: h.save(self.output_dir), raise_on_error)
154 |
--------------------------------------------------------------------------------
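To add a new artifact, subclass `ArtifactHandler` and implement the three hooks. A minimal illustrative handler (not part of the repository) that simply counts inferenced samples:

```python
import pathlib
from typing import Any, Sequence

from torch.utils.data import DataLoader

from rsc.artifacts.base import ArtifactHandler


class SampleCountHandler(ArtifactHandler):
    """ Toy handler: count samples seen during inference """

    def __init__(self):
        super().__init__()
        self.count = 0

    def start(self, model: Any, dataloader: DataLoader) -> None:
        self.count = 0

    def on_iter(self, dl_iter: Sequence, model_out: Sequence) -> None:
        x, _ = dl_iter
        self.count += x.shape[0]  # add this batch's size

    def save(self, output_dir: pathlib.Path) -> pathlib.Path:
        output_path = output_dir / 'sample_count.txt'
        output_path.write_text(f'{self.count:d}\n')
        return output_path
```

It would then be registered like any other handler via `generator.add_handler(SampleCountHandler())`.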
/rsc/artifacts/confusion_matrix_handler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import pathlib
5 | from typing import Any, Sequence
6 |
7 | import numpy as np
8 | import matplotlib
9 | import matplotlib.pyplot as plt
10 | import matplotlib.cm as cm
11 | from torch.utils.data import DataLoader
12 | from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
13 |
14 | from .base import ArtifactHandler
15 |
16 | # Use non-GUI backend
17 | matplotlib.use('Agg')
18 |
19 |
20 | class ConfusionMatrixHandler(ArtifactHandler):
21 | """ Handler class to generate a confusion matrix """
22 |
23 | def __init__(self, simple=False):
24 | super().__init__()
25 |
26 | # Adds "simple mode", where we are only
27 | # looking at paved / unpaved roads
28 | # I admit this short-circuits the
29 | # generalizability of this class, but it's
30 | # worth it
31 | self.simple = simple
32 |
33 | self.labels: list[str] | None = None
34 | self.y_true_l: list[Any] = []
35 | self.y_pred_l: list[Any] = []
36 |
37 | def start(self, model: Any, dataloader: DataLoader) -> None:
38 |
39 | # Try to get labels from model
40 | self.labels = model.__dict__.get('labels')
41 | if self.simple:
42 | self.labels = ['paved', 'unpaved']
43 |
44 | def on_iter(self, dl_iter: Sequence, model_out: Sequence) -> None:
45 |
46 | # Parse iterable
47 | _, features = dl_iter
48 |
49 | # Extract true class scores, dropping the trailing obscuration term
50 | y_true = features.numpy()[..., :-1]
51 | if self.simple:
52 | y_true = np.sum(np.stack((y_true[..., 0:4], y_true[..., 4:8]), axis=1), axis=-1)
53 | y_true = np.argmax(y_true, axis=1)
54 | self.y_true_l.append(y_true)
55 |
56 | # Get prediction from model
57 | _, pred = model_out
58 | pred = pred.cpu().detach().numpy()
59 |
60 | # Lazy-init labels if we couldn't get them from the model
61 | if self.labels is None:
62 | self.labels = [f'Class {n+1:d}' for n in range(pred.shape[-1] - 1)]  # last column is obscuration
63 |
64 | # Get predicted label as argmax
65 | if self.simple:
66 | y_pred = np.sum(np.stack((pred[..., 0:4], pred[..., 4:8]), axis=1), axis=-1)
67 | else:
68 | y_pred = pred[..., :-1]
69 | y_pred = np.argmax(y_pred, axis=1)
70 | self.y_pred_l.append(y_pred)
71 |
72 | def save(self, output_dir) -> pathlib.Path:
73 |
74 | # Aggregate and organize
75 | y_true = np.concatenate(self.y_true_l)
76 | y_pred = np.concatenate(self.y_pred_l)
77 |
78 | # Generate and save the confusion matrix
79 | c = ConfusionMatrixDisplay(confusion_matrix(y_true,
80 | y_pred,
81 | normalize='true'),
82 | display_labels=self.labels)
83 | if self.simple:
84 | output_path = output_dir / 'confusion_matrix_simple.png'
85 | else:
86 | output_path = output_dir / 'confusion_matrix.png'
87 | fig, ax = plt.subplots(figsize=(8, 8))
88 | fig.subplots_adjust(left=0.2, bottom=0.2)
89 | try:
90 | c.plot(ax=ax, cmap=cm.Blues, # type: ignore
91 | xticks_rotation='vertical')
92 | except ValueError:
93 | print('Detected issue with plot! This is likely due to a mismatch of class labels and true labels.')
94 | raise
95 | fig.savefig(str(output_path))
96 | plt.close(fig)
97 |
98 | return output_path
--------------------------------------------------------------------------------
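The "simple mode" reduction assumes the first four class scores are paved subclasses and the last four are unpaved subclasses; stacking and summing collapses them into two coarse scores. On a single toy row:

```python
import numpy as np

# Eight fine-grained class scores plus a trailing obscuration term
pred = np.array([[0.1, 0.4, 0.1, 0.1, 0.1, 0.1, 0.05, 0.05, 0.3]])

grouped = np.sum(np.stack((pred[..., 0:4], pred[..., 4:8]), axis=1), axis=-1)
print(grouped)                     # [[0.7 0.3]] -> paved vs. unpaved scores
print(np.argmax(grouped, axis=1))  # [0]         -> 'paved'
```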
/rsc/artifacts/obsc_compare_handler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import pathlib
5 | from typing import Any, Sequence
6 | import torch
7 |
8 | import numpy as np
9 | import matplotlib
10 | import matplotlib.pyplot as plt
11 |
12 | from torch.utils.data import DataLoader
13 |
14 | from .base import ArtifactHandler
15 |
16 | # Use non-GUI backend
17 | matplotlib.use('Agg')
18 |
19 |
20 | class ObscCompareHandler(ArtifactHandler):
21 | """ Handler class to plot predicted vs. actual obscuration """
22 |
23 | def __init__(self):
24 | super().__init__()
25 |
26 | self.y_pred_l = []
27 | self.y_true_l = []
28 |
29 | def start(self, model: Any, dataloader: DataLoader) -> None:
30 | pass
31 |
32 | def on_iter(self, dl_iter: Sequence, model_out: Sequence) -> None:
33 | _, features = dl_iter
34 | features = features.cpu().detach().numpy()
35 |
36 | # Get prediction from model
37 | _, pred = model_out
38 | pred = torch.sigmoid(pred[..., -1])
39 | pred = pred.cpu().detach().numpy()
40 |
41 | # Collect predicted and true obscuration values
42 | self.y_pred_l.append(pred)
43 | self.y_true_l.append(features[..., -1])
44 |
45 | def save(self, output_dir) -> pathlib.Path:
46 | y_pred = np.concatenate(self.y_pred_l) * 100.
47 | y_true = np.concatenate(self.y_true_l) * 100.
48 |
49 | # Create the plot!
50 | fig, ax = plt.subplots(figsize=(12, 12))
51 | ax.set_aspect('equal')
52 | ax.scatter(y_pred, y_true, 9)
53 | ax.set_xlabel(r'Predicted Obscuration [%]')
54 | ax.set_ylabel(r'Est. Obscuration by Logit Regression [%]')
55 | ax.set_title('Model Obscuration Prediction Accuracy')
56 | ax.grid()
57 | ax.plot((0, 100), (0, 100), '--k', linewidth=2)
58 | ax.legend(['Model Data', 'y = x'])
59 |
60 | output_path = output_dir / 'obsc_compare_plot.png'
61 | fig.savefig(str(output_path))
62 | plt.close(fig)
63 |
64 | return output_path
65 |
--------------------------------------------------------------------------------
/rsc/artifacts/samples_handler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | import pathlib
4 | from typing import Any, Sequence
5 |
6 | import numpy as np
7 | import matplotlib
8 | import matplotlib.pyplot as plt
9 | import torch
10 | from torch.utils.data import DataLoader
11 |
12 | from .base import ArtifactHandler
13 |
14 | # Use non-GUI backend
15 | matplotlib.use('Agg')
16 |
17 |
18 | class SamplesHandler(ArtifactHandler):
19 |
20 | def __init__(self, max_per_category: int = 50):
21 | super().__init__()
22 | self.max_per_category = max_per_category
23 | self.labels = []
24 | self.samples = {}
25 | self.maxed_out = {}
26 |
27 | def start(self, model: Any, dataloader: DataLoader) -> None:
28 | # Try to get labels from model
29 | self.labels = model.__dict__.get('labels')
30 | assert self.labels is not None
31 |
32 | # Setup samples dictionary
33 | for correct in (True, False):
34 | self.samples[correct] = {}
35 | self.maxed_out[correct] = {}
36 | for label in self.labels:
37 | self.samples[correct][label] = []
38 | self.maxed_out[correct][label] = False
39 |
40 | def on_iter(self, dl_iter: Sequence, model_out: Sequence) -> None:
41 |
42 | # Skip all this extra processing if we reached our maximum sample count anyway
43 | for label in self.labels:
44 | if any(not self.maxed_out[correct][label]
45 | for correct in (True, False)):
46 | break
47 | else:
48 | return # we did not break, so we must be maxed out
49 |
50 | x, features = dl_iter
51 |
52 | # Get mask and image
53 | xpm = x[:, (-1,), :, :]
54 | xm = x[:, (-2,), :, :]
55 | x = x[:, :-2, :, :]
56 |
57 | # Get true label
58 | y_true = features.cpu().detach().numpy()
59 |
60 | # Transform image and mask
61 | x_p = (np.moveaxis(x.numpy(), 1, -1) * 255.).astype(
62 | np.uint8)
63 | xm_p = np.moveaxis(xm.numpy(), 1, -1)
64 | xpm_p = np.moveaxis(xpm.numpy(), 1, -1)
65 |
66 | # Parse model prediction
67 | m, y_pred = model_out
68 | m_p = np.moveaxis(m.cpu().detach().numpy(), 1, -1)
69 |
70 | # True label (argmax)
71 | y_true_am: np.ndarray = np.argmax(y_true[:, 0:-1], 1)
72 |
73 | # Predicted label (argmax)
74 | y_pred_am = torch.argmax(y_pred[:, :-1], 1)
75 | y_pred_am = y_pred_am.cpu().detach().numpy() # type: ignore
76 |
77 | # Predicted obscuration (sigmoid)
78 | y_pred_obsc = torch.sigmoid(y_pred[:, -1]).cpu().detach().numpy() # type: ignore
79 |
80 | for i in range(x.shape[0]):
81 |
82 | # Are we correct?
83 | correct = y_true_am[i] == y_pred_am[i]
84 |
85 | # Get current true label
86 | this_label = self.labels[y_true_am[i]]
87 |
88 | # Skip if at the limit for this category
89 | if len(self.samples[correct][this_label]) >= self.max_per_category:
90 | self.maxed_out[correct][this_label] = True
91 | continue
92 |
93 | # If we got here, then take a sample
94 | self.samples[correct][this_label].append((
95 | x_p[i, ...],
96 | xm_p[i, ...],
97 | xpm_p[i, ...],
98 | m_p[i, ...],
99 | y_true[i, ...],
100 | y_true_am[i, ...],
101 | y_pred_am[i, ...],
102 | y_pred_obsc[i, ...],
103 | ))
104 |
105 | def save(self, output_dir: pathlib.Path) -> pathlib.Path:
106 |
107 | # Create directory for samples
108 | samples_dir = output_dir / 'samples'
109 | samples_dir.mkdir(parents=False, exist_ok=True)
110 |
111 | for correct in (True, False):
112 | # Create directory for "correctness"
113 | correct_dir = samples_dir / ('correct' if correct else 'incorrect')
114 | correct_dir.mkdir(parents=False, exist_ok=True)
115 |
116 | for label in self.labels:
117 | # Create directory for label
118 | label_dir = correct_dir / str(label)
119 | label_dir.mkdir(parents=False, exist_ok=True)
120 |
121 | # Loop through all the samples
122 | for idx, (x_p, xm_p, xpm_p, m_p, \
123 | y_true, y_true_am, y_pred_am, \
124 | y_pred_obsc) in enumerate(self.samples[correct][label]):
125 |
126 | # Get num-channels (3 - RGB, 4 - RGB + NIR)
127 | _, _, nc = x_p.shape
128 |
129 | # Create figure
130 | fig, ax = plt.subplots(1,
131 | nc + 1,
132 | sharex=True,
133 | sharey=True,
134 | figsize=(15, 3.5))
135 | for _ax in ax:
136 | _ax.xaxis.set_visible(False)
137 | _ax.yaxis.set_visible(False)
138 |
139 | # Allows us to support different
140 | # number of subplots in NIR case
141 | ax_idx = 0
142 |
143 | # Plot the result, highlighting the road with the mask
144 | ax[ax_idx].set_title(
145 | 'RGB\nTrue: %s; Pred: %s' %
146 | (self.labels[y_true_am], self.labels[y_pred_am]))
147 | ax[ax_idx].imshow(np.uint8(x_p[..., (0, 1, 2)]))
148 | ax_idx += 1
149 |
150 | # Plot color IR if relevant
151 | if nc > 3:
152 | ax[ax_idx].set_title('Color IR')
153 | ax[ax_idx].imshow(np.uint8(x_p[..., (3, 0, 1)]))
154 | ax_idx += 1
155 |
156 | ax[ax_idx].set_title('Combined Image + Mask\nObsc: %.1f%%; Pred: %.1f%%' % (y_true[-1] * 100, y_pred_obsc * 100))
157 | ax[ax_idx].imshow(
158 | np.uint8(x_p[..., (0, 1, 2)] * (0.33 + 0.67 * xm_p)))
159 | ax_idx += 1
160 |
161 | ax[ax_idx].set_title('Combined Image \n+ True ProbMask')
162 | ax[ax_idx].imshow(np.uint8(0.5 * x_p[..., (0, 1, 2)]))
163 | ax[ax_idx].imshow(xpm_p,
164 | cmap='magma',
165 | vmin=0,
166 | vmax=1,
167 | alpha=0.33)
168 | ax_idx += 1
169 |
170 | ax[ax_idx].set_title('Combined Image \n+ Pred ProbMask')
171 | ax[ax_idx].imshow(np.uint8(0.5 * x_p[..., (0, 1, 2)]))
172 | ax[ax_idx].imshow(m_p, cmap='magma', vmin=0, vmax=1, alpha=0.33)
173 |
174 | # Save the figure
175 | output_path = label_dir / f'sample_{idx:05d}.png'
176 | fig.savefig(str(output_path))
177 | plt.close(fig)
178 |
179 | return samples_dir
--------------------------------------------------------------------------------
/rsc/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/rsc/common/__init__.py
--------------------------------------------------------------------------------
/rsc/common/aws_naip.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """ Some helpful functions to download NAIP data on AWS """
4 |
5 | import pathlib
6 | from typing import List
7 |
8 | import boto3
9 |
10 | BASE_PATH = pathlib.Path('/data/road_surface_classifier')
11 | assert BASE_PATH.is_dir()
12 | AWS_PATH = BASE_PATH / 'naip_on_aws'
13 | AWS_PATH.mkdir(exist_ok=True)
14 |
15 |
16 | def naip_s3_fetch(bucket_name: str, object_name: str, output: pathlib.Path):
17 | """
18 | Fetch an S3 object from a bucket, ensuring that RequestPayer=requester is set
19 |
20 | Args:
21 | bucket_name (str): AWS S3 bucket name
22 | object_name (str): Object name in bucket
23 | output (pathlib.Path): Path to save the output file
24 |
25 | Raises:
26 | FileExistsError: If the output file already exists
27 | """
28 | # Prevent overwriting the output file
29 | if output.exists():
30 | raise FileExistsError('File already exists: %s' % str(output))
31 |
32 | # Use the boto3 s3 API to fetch the file object
33 | s3_client = boto3.client('s3')
34 | try:
35 | with open(output, 'wb') as f:
36 | s3_client.download_fileobj(bucket_name,
37 | object_name,
38 | f,
39 | ExtraArgs={'RequestPayer': 'requester'})
40 | except BaseException:  # include KeyboardInterrupt so partial downloads are removed
41 | # If cancelled or something bad happens,
42 | # unlink the file we are trying to download and raise
43 | output.unlink()
44 | raise
45 |
46 |
47 | def get_naip_manifest(bucket_name='naip-analytic') -> List[str]:
48 | """
49 | Read the bucket manifest for a given NAIP on AWS bucket. If the manifest
50 | doesn't yet exist, then fetch it from S3.
51 |
52 | Args:
53 | bucket_name (str, optional): NAIP on AWS bucket name. Defaults to 'naip-analytic'.
54 |
55 | Returns:
56 | List[str]: List of objects in manifest
57 | """
58 | manifest_path = AWS_PATH / 'manifest.txt'
59 |
60 | # Fetch the manifest if it doesn't exist.
61 | if not manifest_path.exists():
62 | naip_s3_fetch(bucket_name, 'manifest.txt', manifest_path)
63 |
64 | # Read it in!
65 | with open(manifest_path, 'r') as f:
66 | return [e.strip() for e in f.readlines()]
67 |
68 |
69 | def get_naip_file(object_name: str,
70 | bucket_name='naip-analytic') -> pathlib.Path:
71 | """
72 | Get a NAIP file by object name. If it already exists, this will not download again.
73 |
74 | Args:
75 | object_name (str): Object name
76 | bucket_name (str, optional): NAIP on AWS Bucket name. Defaults to 'naip-analytic'.
77 |
78 | Returns:
79 | pathlib.Path: Path to the downloaded file.
80 | """
81 | # Get output path, return immediately if exists
82 | output_path = AWS_PATH / object_name
83 | if output_path.exists():
84 | return output_path
85 |
86 | # Otherwise, create the parent directory, and fetch the file
87 | output_path.parent.mkdir(parents=True, exist_ok=True)
88 | naip_s3_fetch(bucket_name, object_name, output_path)
89 |
90 | return output_path
91 |
--------------------------------------------------------------------------------
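Typical usage, assuming the local `BASE_PATH` directory exists (note that NAIP on AWS is a requester-pays bucket, so downloads are billed to the caller):

```python
from rsc.common.aws_naip import get_naip_manifest, get_naip_file

# List every object in the bucket, then cache a single GeoTIFF locally
manifest = get_naip_manifest()
tifs = [obj for obj in manifest if obj.endswith('.tif')]
local_path = get_naip_file(tifs[0])  # no-op if already downloaded
print(local_path)
```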
/rsc/common/geometric_median.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """ Python implementation of geometric median as described in
4 | Yehuda Vardi and Cun-Hui Zhang's paper
5 | "The multivariate L1-median and associated data depth"
6 | """
7 | from typing import Optional
8 | import numpy as np
9 | from scipy.spatial.distance import cdist, euclidean
10 |
11 |
12 | def geometric_median(X: np.ndarray, eps: float = 1e-5) -> np.ndarray:
13 | """
14 | Python implementation of geometric median as described in
15 | Yehuda Vardi and Cun-Hui Zhang's paper
16 | "The multivariate L1-median and associated data depth"
17 |
18 | Taken from: https://stackoverflow.com/questions/30299267/geometric-median-of-multidimensional-points
19 |
20 | Released under zlib license.
21 |
22 | Args:
23 | X (np.ndarray): Input N-dimensional data
24 | eps (float, optional): Convergence tolerance. Defaults to 1e-5.
25 |
26 | Returns:
27 | np.ndarray: Geometric median value
28 | """
29 | y = np.mean(X, 0)
30 |
31 | while True:
32 | D = cdist(X, [y])
33 | nonzeros = (D != 0)[:, 0]
34 |
35 | Dinv = 1 / D[nonzeros]
36 | Dinvs = np.sum(Dinv)
37 | W = Dinv / Dinvs
38 | T = np.sum(W * X[nonzeros], 0)
39 |
40 | num_zeros = len(X) - np.sum(nonzeros)
41 | if num_zeros == 0:
42 | y1 = T
43 | elif num_zeros == len(X):
44 | return y
45 | else:
46 | R = (T - y) * Dinvs
47 | r = np.linalg.norm(R)
48 | rinv = 0 if r == 0 else num_zeros / r
49 | y1 = max(0, 1 - rinv) * T + min(1, rinv) * y
50 |
51 | if euclidean(y, y1) < eps:
52 | return y1
53 |
54 | y = y1
--------------------------------------------------------------------------------
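A small demonstration of the property that motivates using the geometric median instead of the mean: robustness to outliers.

```python
import numpy as np

from rsc.common.geometric_median import geometric_median

# Four points on the unit square plus one far outlier
X = np.array([[0., 0.], [1., 0.], [0., 1.], [1., 1.], [100., 100.]])

print(np.mean(X, 0))        # [20.4 20.4] -- dragged toward the outlier
print(geometric_median(X))  # ~[0.79 0.79] -- stays near the cluster
```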
/rsc/inference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/rsc/inference/__init__.py
--------------------------------------------------------------------------------
/rsc/inference/fetch.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | import numpy as np
4 | from osgeo import ogr, osr
5 | import PIL.Image
6 | import PIL.ImageDraw
7 |
8 | from rsc.common.utils import imread_geotransform, imread_srs, imread, map_to_pix
9 |
10 | ogr.UseExceptions()
11 |
12 | # Get EPSG:4326 reference SRS
13 | srs_ref = osr.SpatialReference()
14 | srs_ref.ImportFromEPSG(4326)
15 | srs_ref.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
16 |
17 |
18 | def fetch(img_path, x1, y1, x2, y2, wkt):
19 |
20 | # Fetch the tile we need
21 | h, w = y2 - y1, x2 - x1
22 | im = imread(str(img_path), x_off=x1, y_off=y1, w=w, h=h)
23 | xform = imread_geotransform(str(img_path), x_off=x1, y_off=y1)
24 |
25 | # Get the spatial reference (they change across the tiles)
26 | srs = osr.SpatialReference()
27 | srs.ImportFromProj4(imread_srs(str(img_path)))
28 | srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
29 | im_trans = osr.CoordinateTransformation(srs_ref, srs)
30 |
31 | # Get list of points from WKT
32 | geom: ogr.Geometry = ogr.CreateGeometryFromWkt(wkt)
33 | geom.Transform(im_trans)
34 | pts = np.array(
35 | [geom.GetPoint_2D(idx) for idx in range(geom.GetPointCount())])
36 |
37 | # Convert to image-space x, y
38 | ix, iy = map_to_pix(list(xform), pts[:, 0], pts[:, 1])
39 |
40 | # Create a new image of the same shape, and draw a line
41 | # to create a mask
42 | mask_pil = PIL.Image.new('L', (w, h), color=0)
43 | d = PIL.ImageDraw.Draw(mask_pil)
44 | d.line(
45 | [(x, y) for x, y in zip(ix, iy)], # type: ignore
46 | fill=255,
47 | width=2,
48 | joint="curve")
49 | mask = np.array(mask_pil)[:, :, np.newaxis]
50 |
51 | return np.concatenate((im, mask), axis=-1)
--------------------------------------------------------------------------------
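A hedged sketch of calling `fetch` (the tile name, pixel window, and WKT below are hypothetical); the returned array stacks the imagery bands with the rasterized centerline mask:

```python
from rsc.inference.fetch import fetch

# Hypothetical NAIP tile, pixel window, and EPSG:4326 road centerline
patch = fetch('m_4010506_ne_13_060_20190727.tif',
              x1=1024, y1=2048, x2=1280, y2=2304,
              wkt='LINESTRING (-105.25 40.01, -105.24 40.02)')
print(patch.shape)  # (256, 256, n_bands + 1): imagery + mask channel
```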
/rsc/inference/mass_inference_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | import sqlite3
4 | import pathlib
5 | from typing import Any
6 | import pandas as pd
7 | import numpy as np
8 |
9 | import torch
10 | from torch.utils.data import Dataset
11 |
12 | import torch.multiprocessing
13 |
14 | from .fetch import fetch
15 |
16 | torch.multiprocessing.set_sharing_strategy('file_system')
17 |
18 | IMAGERY_PATH = pathlib.Path('/nfs/taranis/naip/BOULDER_COUNTY_NAIP_2019')
19 | assert IMAGERY_PATH.is_dir()
20 |
21 |
22 | class MassInferenceDataset(Dataset):
23 | def __init__(self,
24 | sqlite_path: pathlib.Path,
25 | transform,
26 | n_channels=4,
27 | limit=-1):
28 |
29 | # Connect to the sqlite database (read-only) and load the full features table
30 | with sqlite3.connect('file:%s?mode=ro' % str(sqlite_path.resolve()),
31 | uri=True) as con:
32 | self.df = pd.read_sql('SELECT * FROM features;',
33 | con).set_index('osm_id')
34 |
35 | # Load dataframe, interpret length
36 | self.n_idxs = len(self.df) if limit == -1 else min(limit, len(self.df))
37 |
38 | # Number of channels in the image
39 | self.n_channels = n_channels
40 |
41 | # Transformation object
42 | self.transform = transform
43 |
44 | def __len__(self):
45 | return self.n_idxs
46 |
47 | def __getitem__(self, idx):
48 |
49 | # Get row
50 | row: Any = self.df.iloc[idx] # type: ignore
51 |
52 | x1, y1, x2, y2 = [row[e].item() for e in ('x1', 'y1', 'x2', 'y2')]
53 |
54 | # Mask
55 | im = fetch(IMAGERY_PATH / row['img'], x1, y1, x2, y2, row['wkt'])
56 |
57 | # Concat image and masks for output
58 | x = self.transform(im)
59 |
60 | return row.name, x
61 |
--------------------------------------------------------------------------------
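A usage sketch (the sqlite path is hypothetical, presumably produced by scripts/pre_inference.py; the module also expects the NAIP imagery directory configured at the top of the file to exist):

```python
import pathlib

from torch.utils.data import DataLoader

from rsc.inference.mass_inference_dataset import MassInferenceDataset
from rsc.train.preprocess import PreProcess

ds = MassInferenceDataset(pathlib.Path('data/features.sqlite3'),
                          transform=PreProcess(),
                          limit=1000)
dl = DataLoader(ds, batch_size=64, num_workers=8)

for osm_ids, x in dl:
    ...  # run the model on x; osm_ids map predictions back to OSM ways
```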
/rsc/osm/README.md:
--------------------------------------------------------------------------------
1 | # OSM Helper Module
2 |
3 | Helper module to aid with working with OpenStreetMap data, with a focus on drivable road networks.
4 |
5 | ## License
6 | [MIT](https://choosealicense.com/licenses/mit/) © 2024 Jonathan Dalrymple
--------------------------------------------------------------------------------
/rsc/osm/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import sys
5 | import logging
6 | import functools
7 |
8 | # Configure logger
9 | _logger = logging.getLogger(__name__)
10 | _logger.setLevel(logging.DEBUG) # configure log level here
11 | _logger_handlers = [logging.StreamHandler(sys.stdout)]
12 | _logger_formatter = logging.Formatter(
13 | r'%(asctime)-15s %(levelname)s [%(module)s] %(message)s')
14 | _logger.handlers.clear()
15 | for h in _logger_handlers:
16 | h.setFormatter(_logger_formatter)
17 | _logger.addHandler(h)
18 |
19 | def get_logger():
20 | """ Fetch the module logger """
21 | return _logger
22 |
23 | def gdal_required(is_available):
24 | """ Decorator function to check if GDAL was imported """
25 | def _decorator(func):
26 | @functools.wraps(func)
27 | def wrapper(*args, **kwargs):
28 | if not is_available:
29 | raise RuntimeError('OGR is required for this function / method.')
30 | return func(*args, **kwargs)
31 | return wrapper
32 | return _decorator
33 |
34 | # Exposed classes to user
35 | from .osm_element import OSMElement, OSMNode, OSMWay
36 | from .osm_element_factory import OSMElementFactory
37 | from .osm_network import OSMNetwork
38 | from . import overpass_api
--------------------------------------------------------------------------------
/rsc/osm/osm_element.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """ Overpass API Query Implementation """
4 |
5 | from __future__ import annotations
6 |
7 | import json
8 | from typing import List, Union, TypeVar, Type
9 | import warnings
10 | from abc import ABC, abstractmethod
11 | import xml.etree.ElementTree as ET
12 |
13 | _gdal_available = False
14 | try:
15 | from osgeo import ogr
16 | ogr.UseExceptions()
17 | _gdal_available = True
18 | except ModuleNotFoundError:
19 | warnings.warn('Could not find GDAL. Some methods may not be available.', ImportWarning)
20 |
21 | from . import gdal_required
22 |
23 | # Generic type for abstract classmethods of OSMElement
24 | T = TypeVar('T', bound='OSMElement')
25 |
26 | class OSMElement(ABC):
27 | """ Abstract class to describe an OSM element (node, way, etc.) """
28 | __slots__ = ['id', 'tags']
29 |
30 | TYPE = '' # holds OSM type string ('node', 'way', etc.)
31 |
32 | def __init__(self, **kwargs):
33 | """ Default constructor: sets attributes from kwargs """
34 | # Set id and tags (universal to every OSM element)
35 | # NOTE: it's okay if tags is empty, which is why it
36 | # doesn't get the fancy attribute wrapper
37 | self._set_attr_from_kwargs(kwargs, 'id', int, default=None)
38 | self.tags = kwargs.get('tags', {})
39 |
40 | def _set_attr_from_kwargs(self,
41 | kwargs: dict,
42 | attr_key: str,
43 | dt: type,
44 | default=None):
45 | """ Special way to set an attribute that throws an error if the default is used """
46 | try:
47 | self.__setattr__(attr_key, dt(kwargs[attr_key]))
48 | except (TypeError, ValueError):
49 | warnings.warn(
50 | f'Cannot cast input for \'{attr_key}\' to {dt}! Setting to {repr(default)}',
51 | RuntimeWarning)
52 | self.__setattr__(attr_key, default)
53 | except KeyError:
54 | warnings.warn(
55 | f'No {attr_key} specified in constructor! Setting to {repr(default)}',
56 | RuntimeWarning)
57 | self.__setattr__(attr_key, default)
58 |
59 | @staticmethod
60 | def _parse_xml_osm_tag(t: ET.Element) -> dict:
61 | """ Parse an XML OSM tag element into a dictionary """
62 | assert t.tag == 'tag'
63 | d = t.attrib
64 | if 'k' not in d or 'v' not in d:
65 | warnings.warn(f'Tag element empty! {d}', RuntimeWarning)
66 | return {}
67 | return {d['k']: d['v']}
68 |
69 | @staticmethod
70 | def _create_xml_osm_tag(t: dict) -> List[ET.Element]:
71 | """ Create a list of ETree Element objects from a tags dictionary """
72 | el_list = [] # type: List[ET.Element]
73 | for k, v in t.items():
74 | el = ET.Element('tag')
75 | el.attrib['k'] = str(k)
76 | el.attrib['v'] = str(v)
77 | el_list.append(el)
78 | return el_list
79 |
80 | def __str__(self):
81 | return f'<{self.__class__.__name__}; id: {self.id}; tags: {len(self.tags)}>'
82 |
83 | def __repr__(self):
84 | return self.__str__()
85 |
86 | @abstractmethod
87 | def to_json_dict(self) -> dict:
88 | """ Export this element to a JSON dict (Python dictionary) """
89 | return {'type': self.TYPE, 'id': self.id}
90 |
91 | @abstractmethod
92 | def to_xml(self) -> ET.Element:
93 | """ Export this element to an XML string """
94 | pass
95 |
96 | def to_xml_str(self) -> str:
97 | """ Export this element to an XML string """
98 | return ET.tostring(self.to_xml(), encoding='utf-8').decode('utf-8')
99 |
100 | @classmethod
101 | @abstractmethod
102 | def from_json_dict(cls: Type[T], json_dict: dict) -> T:
103 | """ Create this element from a JSON dict (Python dictionary) """
104 | pass
105 |
106 | def to_json(self) -> str:
107 | """ Export this element to an JSON string """
108 | return json.dumps(self.to_json_dict())
109 |
110 | @classmethod
111 | def from_json(cls: Type[T], json_str: str) -> T:
112 | """ Create this element from a JSON string """
113 | return cls.from_json_dict(json.loads(json_str))
114 |
115 | @classmethod
116 | @abstractmethod
117 | def from_xml(cls: Type[T], tree: ET.Element) -> T:
118 | """ Create this element from an XML Element """
119 | pass
120 |
121 | @classmethod
122 | def from_xml_str(cls: Type[T], xml_str: str) -> T:
123 | """ Create this element from an XML string """
124 | return cls.from_xml(ET.fromstring(xml_str))
125 |
126 |
127 | class OSMNode(OSMElement):
128 | """ Class to describe an OSM node object """
129 | __slots__ = ['lat', 'lon']
130 |
131 | TYPE = 'node'
132 |
133 | def __init__(self, **kwargs):
134 | super().__init__(**kwargs)
135 |
136 | # Set latitude and longitude
137 | self._set_attr_from_kwargs(kwargs, 'lat', float, default=0.0)
138 | self._set_attr_from_kwargs(kwargs, 'lon', float, default=0.0)
139 |
140 | def __str__(self):
141 | s = super().__str__()[:-1]
142 | s += f'; lat: {self.lat:.3f}; lon: {self.lon:.3f}>'
143 | return s
144 |
145 | @gdal_required(_gdal_available)
146 | def to_ogr_geom(self) -> ogr.Geometry:
147 | """
148 | Convert this node to an OGR geometry object
149 |
150 | Note:
151 | Requires GDAL / OGR
152 |
153 | Returns:
154 | ogr.Geometry: OGR Geometry for node.
155 | """
156 | pt = ogr.Geometry(ogr.wkbPoint)
157 | pt.AddPoint_2D(self.lon, self.lat)
158 | return pt
159 |
160 | @gdal_required(_gdal_available)
161 | def to_wkt(self) -> str:
162 | """
163 | Get the WKT representation for this node.
164 |
165 | Note:
166 | Requires GDAL / OGR
167 |
168 | Returns:
169 | str: WKT representation for point
170 | """
171 | return self.to_ogr_geom().ExportToWkt()
172 |
173 | def to_json_dict(self) -> dict:
174 | """ Export this element to a JSON dict (Python dictionary) """
175 | d = super().to_json_dict()
176 | d['lat'] = float(self.lat)
177 | d['lon'] = float(self.lon)
178 | if len(self.tags):
179 | d['tags'] = {}
180 | for k, v in self.tags.items():
181 | d['tags'][k] = str(v)
182 | return d
183 |
184 | def to_xml(self) -> ET.Element:
185 | """ Export this element to an XML string """
186 | # Create OSM XML element
187 | el = ET.Element('node')
188 |
189 | # Add id if it exists
190 | if self.id is not None:
191 | el.attrib['id'] = str(self.id)
192 |
193 | # Add any tags
194 | for tag_element in self._create_xml_osm_tag(self.tags):
195 | el.append(tag_element)
196 |
197 | # Add any remaining attributes
198 | el.attrib['lat'] = str(self.lat)
199 | el.attrib['lon'] = str(self.lon)
200 |
201 | # Return as XML Element
202 | return el
203 |
204 | @classmethod
205 | def from_json_dict(cls: Type[T], json_dict: dict) -> T:
206 | """ Create this element from a JSON dict (Python dictionary) """
207 | # Trick since the OSM keys match the attribute names exactly
208 | # i.e. 'id', 'lat', 'lon'
209 | # If this wasn't the case extra processing would be necessary
210 | return cls(**json_dict)
211 |
212 | @classmethod
213 | def from_xml(cls: Type[T], tree: ET.Element) -> T:
214 | """ Create this element from an XML string """
215 | # Assert OSM element type
216 | assert tree.tag == cls.TYPE
217 |
218 | # Trick since the OSM keys match the attribute names exactly
219 | # i.e. 'id', 'lat', 'lon'
220 | # If this wasn't the case extra processing would be necessary
221 | kwargs = tree.attrib.copy()
222 |
223 | # Parse any tags
224 | tags = {}
225 | for t in tree.findall('tag'):
226 | tags.update(cls._parse_xml_osm_tag(t))
227 | kwargs['tags'] = tags # type: ignore
228 |
229 | return cls(**kwargs)
230 |
231 |
232 | class OSMWay(OSMElement):
233 | """ Class to describe an OSM way object """
234 | __slots__ = ['nodes']
235 |
236 | TYPE = 'way'
237 |
238 | def __init__(self, **kwargs):
239 | super().__init__(**kwargs)
240 |
241 | # Set nodes
242 | self._set_attr_from_kwargs(kwargs, 'nodes', list, default=[])
243 |
244 | def __len__(self):
245 | return self.nodes.__len__()
246 |
247 | def __str__(self):
248 | s = super().__str__()[:-1]
249 | s += f'; nodes: {self.nodes.__len__()}>'
250 | return s
251 |
252 | @staticmethod
253 | def _parse_xml_osm_node(n: ET.Element) -> Union[int, None]:
254 | """ Parse an XML OSM tag element into a dictionary """
255 | assert n.tag == 'nd'
256 | d = n.attrib
257 | try:
258 | return int(d['ref'])
259 | except (TypeError, ValueError):
260 | warnings.warn(
261 | f'Node element id cannot be cast to int! Setting to None. {d["ref"]}',
262 | RuntimeWarning)
263 | return None
264 | except KeyError:
265 | warnings.warn(f'Node element empty! {d}', RuntimeWarning)
266 | return None
267 |
268 | @staticmethod
269 | def _create_xml_osm_node(n: int) -> ET.Element:
270 | """ Create a list of ETree Element objects from a node id """
271 | el = ET.Element('nd')
272 | el.attrib['ref'] = str(n)
273 | return el
274 |
275 | def to_json_dict(self) -> dict:
276 | """ Export this element to a JSON dict (Python dictionary) """
277 | d = super().to_json_dict()
278 | d['nodes'] = list(self.nodes)
279 | if len(self.tags):
280 | d['tags'] = {}
281 | for k, v in self.tags.items():
282 | d['tags'][k] = str(v)
283 | return d
284 |
285 | def to_xml(self) -> ET.Element:
286 | """ Export this element to an XML Element """
287 | # Create OSM XML element
288 | el = ET.Element('way')
289 |
290 | # Add id if it exists
291 | if self.id is not None:
292 | el.attrib['id'] = str(self.id)
293 |
294 | # Add any nodes
295 | for node in self.nodes: # type: ignore
296 | el.append(self._create_xml_osm_node(node))
297 |
298 | # Add any tags
299 | for tag_element in self._create_xml_osm_tag(self.tags):
300 | el.append(tag_element)
301 |
302 | # Return as an element
303 | return el
304 |
305 | @classmethod
306 | def from_json_dict(cls: Type[T], json_dict: dict) -> T:
307 | """ Create this element from a JSON dict (Python dictionary) """
308 | # Trick since the OSM keys match the attribute names exactly
309 | # If this wasn't the case extra processing would be necessary
310 | return cls(**json_dict)
311 |
312 | @classmethod
313 | def from_xml(cls, tree: ET.Element) -> OSMWay:
314 | """ Create this element from an XML Element """
315 | # Assert OSM element type
316 | assert tree.tag == cls.TYPE
317 |
318 | # Kwargs to pass to class constructor
319 | kwargs = {}
320 |
321 | # Parse id
322 | if 'id' in tree.attrib:
323 | kwargs['id'] = tree.attrib['id']
324 |
325 | # Parse any nodes
326 | nodes = []
327 | for nd in tree.findall('nd'):
328 | node_id = cls._parse_xml_osm_node(nd)
329 | if node_id is not None:
330 | nodes.append(node_id)
331 | kwargs['nodes'] = nodes # type: ignore
332 |
333 | # Parse any tags
334 | tags = {}
335 | for t in tree.findall('tag'):
336 | tags.update(cls._parse_xml_osm_tag(t))
337 | kwargs['tags'] = tags # type: ignore
338 |
339 | return cls(**kwargs)
340 |
--------------------------------------------------------------------------------
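
A minimal round-trip sketch for the OSMWay class above. The way id, node refs, and tag are illustrative, and the sketch assumes the tag helpers defined earlier in osm_element.py behave as their call sites here suggest:

    import xml.etree.ElementTree as ET

    from rsc.osm.osm_element import OSMWay

    # Parse a way from OSM XML
    xml_str = ('<way id="42">'
               '<nd ref="1"/><nd ref="2"/>'
               '<tag k="highway" v="residential"/>'
               '</way>')
    way = OSMWay.from_xml(ET.fromstring(xml_str))

    print(len(way))                   # 2 node refs
    print(ET.tostring(way.to_xml()))  # serialize back to OSM XML

--------------------------------------------------------------------------------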
/rsc/osm/osm_element_factory.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import xml.etree.ElementTree as ET
5 |
6 | from .osm_element import OSMElement, OSMNode, OSMWay
7 |
8 |
9 | class OSMElementFactory:
10 | """ Factory class for OSM elements """
11 |
12 | @staticmethod
13 | def from_json_dict(d: dict) -> OSMElement:
14 | # Read OSM 'type' key
15 | osm_type = d.get('type')
16 |
17 | # Create class
18 | if osm_type == 'node':
19 | return OSMNode.from_json_dict(d)
20 | elif osm_type == 'way':
21 | return OSMWay.from_json_dict(d)
22 | else:
23 | raise ValueError('Unknown OSM type: %s' % osm_type)
24 |
25 | @staticmethod
26 | def from_xml(el: ET.Element) -> OSMElement:
27 | # Read OSM 'type' key
28 | osm_type = el.tag
29 |
30 | # Create class
31 | if osm_type == 'node':
32 | return OSMNode.from_xml(el)
33 | elif osm_type == 'way':
34 | return OSMWay.from_xml(el)
35 | else:
36 | raise ValueError('Unknown XML tag: %s' % osm_type)
37 |
--------------------------------------------------------------------------------
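
A short sketch of the dispatch this factory provides. The element attributes are illustrative; the JSON path assumes the OSMElement constructor tolerates the extra 'type' key, as its **kwargs signature suggests:

    import xml.etree.ElementTree as ET

    from rsc.osm.osm_element_factory import OSMElementFactory

    # XML dispatch: tag 'node' -> OSMNode, 'way' -> OSMWay, else ValueError
    el = OSMElementFactory.from_xml(
        ET.fromstring('<node id="7" lat="40.0" lon="-105.3"/>'))

    # JSON dispatch keys off the 'type' entry instead
    el2 = OSMElementFactory.from_json_dict(
        {'type': 'node', 'id': 7, 'lat': 40.0, 'lon': -105.3})

--------------------------------------------------------------------------------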
/rsc/osm/osm_overpass_api.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """ Classes to aid in OSM Overpass API Queries and Results """
4 |
5 | from __future__ import annotations
6 |
7 | import json
8 | import pathlib
9 | import textwrap
10 | from abc import ABC, abstractmethod
11 | from typing import Iterable, Union
12 | import requests
13 |
14 | from osgeo import ogr
15 |
16 |
17 | class OSMOverpassQuery(ABC):
18 | """ Abstract class to perform an OSM Overpass API query """
19 |
20 | __slots__ = ['_endpoint', '_timeout', '_format', '_poly']
21 |
22 | def __init__(self, **kwargs):
23 | self._endpoint = kwargs.get(
24 | 'endpoint', 'https://lz4.overpass-api.de/api/interpreter')
25 | self._timeout = kwargs.get('timeout', 180)
26 | self._format = kwargs.get('format', 'json')
27 | self._poly = None
28 |
29 | def set_poly_from_list(self, poly_list: Iterable[tuple[float,
30 | float]]) -> None:
31 | """
32 | Set the polygon query boundary from a list of coordinates
33 |
34 | Args:
35 | poly_list (Iterable[tuple[float, float]]): Coordinate pairs in (lon, lat) order
36 | """
37 | self._poly = []
38 | for lon, lat in poly_list:
39 | self._poly.append((lon, lat))
40 |
41 | def set_poly_from_bbox(self, lat0: float, lon0: float, lat1: float,
42 | lon1: float) -> None:
43 | """
44 | Set Overpass Query Filter polygon by bounding box
45 |
46 | Args:
47 | lat0 (float): First latitude (can be min or max)
48 | lon0 (float): First longitude (can be min or max)
49 | lat1 (float): Second latitude (can be min or max)
50 | lon1 (float): Second longitude (can be min or max)
51 | """
52 | # Find max, min latitude and longitude
53 | lat_max = max((lat0, lat1))
54 | lat_min = min((lat0, lat1))
55 | lon_max = max((lon0, lon1))
56 | lon_min = min((lon0, lon1))
57 |
58 | # Set polygon
59 | self._poly = [(lon_min, lat_max), (lon_max, lat_max),
60 | (lon_max, lat_min), (lon_min, lat_min),
61 | (lon_min, lat_max)]
62 |
63 | def set_poly_from_wkt(self, wkt_str: str) -> None:
64 | """
65 | Set Overpass Query Filter polygon by WKT string
66 |
67 | Args:
68 | wkt_str (str): Input polygon in WKT format
69 |
70 | Raises:
71 | ValueError: If the input WKT string is not of a polygon
72 | """
73 | # Load WKT string into JSON dict
74 | geojson_dict = json.loads(
75 | ogr.CreateGeometryFromWkt(wkt_str).ExportToJson())
76 |
77 | # Sanity check geometry type
78 | if geojson_dict['type'] != 'Polygon':
79 | raise ValueError('Input WKT string must be a polygon! Got: %s' %
80 | geojson_dict['type'])
81 |
82 | # Set polygon (0 index is for linearRing inside polygon)
83 | self._poly = geojson_dict['coordinates'][0]
84 |
85 | @property
86 | def _poly_query_str(self) -> str:
87 | if self._poly is None:
88 | raise ValueError('Polygon not set!')
89 | return ' '.join([
90 | ' '.join((f'{lat:.6f}', f'{lon:.6f}')) for lon, lat in self._poly
91 | ])
92 |
93 | @property
94 | @abstractmethod
95 | def _query_str(self) -> str:
96 | return ''
97 |
98 | @property
99 | def endpoint(self) -> str:
100 | return self._endpoint
101 |
102 | def set_endpoint(self, endpoint: str) -> None:
103 | """ Set the Overpass API Endpoint """
104 | self._endpoint = endpoint
105 |
106 | def _perform_query(self) -> requests.models.Response:
107 | """ Perform an OSM Overpass API Request! """
108 | query_str = textwrap.dedent(self._query_str).replace('\n', '')
109 | result = requests.get(self.endpoint, params={'data': query_str})
110 | result.raise_for_status()
111 | return result
112 |
113 | def perform_query(self) -> OSMOverpassResult:
114 | """ Perform an OSM Overpass API Request! """
115 | return OSMOverpassResult(self._perform_query())
116 |
117 |
118 | class OSMOverpassResult:
119 | """ Container class for OSM Overpass result data. Will need to be subclassed to be useful. """
120 | __slots__ = ['_result']
121 |
122 | def __init__(self, result: requests.models.Response):
123 | self._result = result
124 |
125 | def to_file(self, output_path: Union[pathlib.Path, str]) -> None:
126 | """
127 | Write the query result to file. The file suffix must match the
128 | format of the query result.
129 |
130 | Args:
131 | output_path (Union[pathlib.Path, str]): Output file path
132 |
133 | Raises:
134 | ValueError: if format is not recognized
135 | """
136 | # Get output path, sanity check dir
137 | output_path = pathlib.Path(output_path)
138 | assert output_path.parent.exists()
139 |
140 | # Assert output path suffix is valid
141 | this_format = self.format
142 | if 'json' in this_format:
143 | assert output_path.suffix.lower() == '.json'
144 | elif 'xml' in this_format:
145 | assert output_path.suffix.lower() in ('.xml', '.osm')
146 | else:
147 | raise ValueError('Unknown format: %s. Cannot export to file.' %
148 | str(this_format))
149 |
150 | # Save to file!
151 | with open(output_path, 'wb') as f:
152 | f.write(self._result.content)
153 |
154 | @property
155 | def format(self) -> str:
156 | if 'Content-Type' in self._result.headers:
157 | return self._result.headers['Content-Type'].split('/')[-1]
158 | raise ValueError(
159 | 'Unknown format! Content-Type not found in result header.')
160 |
--------------------------------------------------------------------------------
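
OSMOverpassQuery is abstract, so a concrete query only needs to supply _query_str. A hedged sketch (not from the repo; the query body and bounding box are illustrative, and perform_query() issues a live HTTP request):

    from rsc.osm.osm_overpass_api import OSMOverpassQuery

    class AllWaysQuery(OSMOverpassQuery):
        """ Toy query: every way inside the polygon """

        @property
        def _query_str(self) -> str:
            return f"""
            [out:{self._format}][timeout:{self._timeout}];
            (way(poly:'{self._poly_query_str}'););
            out;
            """

    query = AllWaysQuery(timeout=60)
    query.set_poly_from_bbox(40.00, -105.30, 40.05, -105.25)
    result = query.perform_query()    # HTTP GET; raises on bad status
    result.to_file('ways.json')       # suffix must match the response format

--------------------------------------------------------------------------------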
/rsc/osm/overpass_api/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | from .osm_overpass_api import OSMOverpassQuery, OSMOverpassResult
--------------------------------------------------------------------------------
/rsc/osm/overpass_api/osm_overpass_api.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """ Classes to aid in OSM Overpass API Queries and Results """
4 |
5 | from __future__ import annotations
6 |
7 | import json
8 | import pathlib
9 | import textwrap
10 | from abc import ABC, abstractmethod
11 | from typing import Iterable, Union
12 | import requests
13 |
14 | from osgeo import ogr
15 |
16 |
17 | class OSMOverpassQuery(ABC):
18 | """ Abstract class to perform an OSM Overpass API query """
19 |
20 | __slots__ = ['_endpoint', '_timeout', '_format', '_poly']
21 |
22 | def __init__(self, **kwargs):
23 | self._endpoint = kwargs.get(
24 | 'endpoint', 'https://lz4.overpass-api.de/api/interpreter')
25 | self._timeout = kwargs.get('timeout', 180)
26 | self._format = kwargs.get('format', 'json')
27 | self._poly = None
28 |
29 | def set_poly_from_list(self, poly_list: Iterable[tuple[float,
30 | float]]) -> None:
31 | """
32 | Set the polygon query boundary from a list of coordinates
33 |
34 | Args:
35 | poly_list (Iterable[tuple[float, float]]): Coordinate pairs in (lon, lat) order
36 | """
37 | self._poly = []
38 | for lon, lat in poly_list:
39 | self._poly.append((lon, lat))
40 |
41 | def set_poly_from_bbox(self, lat0: float, lon0: float, lat1: float,
42 | lon1: float) -> None:
43 | """
44 | Set Overpass Query Filter polygon by bounding box
45 |
46 | Args:
47 | lat0 (float): First latitude (can be min or max)
48 | lon0 (float): First longitude (can be min or max)
49 | lat1 (float): Second latitude (can be min or max)
50 | lon1 (float): Second longitude (can be min or max)
51 | """
52 | # Find max, min latitude and longitude
53 | lat_max = max((lat0, lat1))
54 | lat_min = min((lat0, lat1))
55 | lon_max = max((lon0, lon1))
56 | lon_min = min((lon0, lon1))
57 |
58 | # Set polygon
59 | self._poly = [(lon_min, lat_max), (lon_max, lat_max),
60 | (lon_max, lat_min), (lon_min, lat_min),
61 | (lon_min, lat_max)]
62 |
63 | def set_poly_from_wkt(self, wkt_str: str) -> None:
64 | """
65 | Set Overpass Query Filter polygon by WKT string
66 |
67 | Args:
68 | wkt_str (str): Input polygon in WKT format
69 |
70 | Raises:
71 | ValueError: If the input WKT string is not of a polygon
72 | """
73 | # Load WKT string into JSON dict
74 | geojson_dict = json.loads(
75 | ogr.CreateGeometryFromWkt(wkt_str).ExportToJson())
76 |
77 | # Sanity check geometry type
78 | if geojson_dict['type'] != 'Polygon':
79 | raise ValueError('Input WKT string must be a polygon! Got: %s' %
80 | geojson_dict['type'])
81 |
82 | # Set polygon (0 index is for linearRing inside polygon)
83 | self._poly = geojson_dict['coordinates'][0]
84 |
85 | @property
86 | def _poly_query_str(self) -> str:
87 | if self._poly is None:
88 | raise ValueError('Polygon not set!')
89 | return ' '.join([
90 | ' '.join((f'{lat:.6f}', f'{lon:.6f}')) for lon, lat in self._poly
91 | ])
92 |
93 | @property
94 | @abstractmethod
95 | def _query_str(self) -> str:
96 | return ''
97 |
98 | @property
99 | def endpoint(self) -> str:
100 | return self._endpoint
101 |
102 | def set_endpoint(self, endpoint: str) -> None:
103 | """ Set the Overpass API Endpoint """
104 | self._endpoint = endpoint
105 |
106 | def _perform_query(self) -> requests.models.Response:
107 | """ Perform an OSM Overpass API Request! """
108 | query_str = textwrap.dedent(self._query_str).replace('\n', '')
109 | result = requests.get(self.endpoint, params={'data': query_str})
110 | result.raise_for_status()
111 | return result
112 |
113 | def perform_query(self) -> OSMOverpassResult:
114 | """ Perform an OSM Overpass API Request! """
115 | return OSMOverpassResult(self._perform_query())
116 |
117 |
118 | class OSMOverpassResult:
119 | """ Container class for OSM Overpass result data. Will need to be subclassed to be useful. """
120 | __slots__ = ['_result']
121 |
122 | def __init__(self, result: requests.models.Response):
123 | self._result = result
124 |
125 | def to_file(self, output_path: Union[pathlib.Path, str]) -> None:
126 | """
127 | Write the query result to file. The file suffix must match the
128 | format of the query result.
129 |
130 | Args:
131 | output_path (Union[pathlib.Path, str]): Output file path
132 |
133 | Raises:
134 | ValueError: if format is not recognized
135 | """
136 | # Get output path, sanity check dir
137 | output_path = pathlib.Path(output_path)
138 | assert output_path.parent.exists()
139 |
140 | # Assert output path suffix is valid
141 | this_format = self.format
142 | if 'json' in this_format:
143 | assert output_path.suffix.lower() == '.json'
144 | elif 'xml' in this_format:
145 | assert output_path.suffix.lower() in ('.xml', '.osm')
146 | else:
147 | raise ValueError('Unknown format: %s. Cannot export to file.' %
148 | str(this_format))
149 |
150 | # Save to file!
151 | with open(output_path, 'wb') as f:
152 | f.write(self._result.content)
153 |
154 | @property
155 | def format(self) -> str:
156 | if 'Content-Type' in self._result.headers:
157 | return self._result.headers['Content-Type'].split('/')[-1]
158 | raise ValueError(
159 | 'Unknown format! Content-Type not found in result header.')
160 |
--------------------------------------------------------------------------------
/rsc/osm/overpass_api/road_network.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | from __future__ import annotations
5 | import tempfile
6 | import pathlib
7 |
8 | from ..osm_network import OSMNetwork
9 | from .osm_overpass_api import OSMOverpassQuery, OSMOverpassResult
10 |
11 |
12 | class OSMRoadNetworkOverpassQuery(OSMOverpassQuery):
13 | """ Custom OSM Overpass API query for (hopefully) drivable road networks """
14 |
15 | __slots__ = ['_highway_tags']
16 |
17 | DEFAULT_HIGHWAY_TAGS = [
18 | 'motorway', 'motorway_link', 'motorway_junction', 'trunk',
19 | 'trunk_link', 'primary', 'primary_link', 'secondary', 'secondary_link',
20 | 'tertiary', 'tertiary_link', 'unclassified', 'residential'
21 | ]
22 |
23 | def __init__(self, **kwargs):
24 | super().__init__(**kwargs)
25 | self._highway_tags = kwargs.get('highway_tags',
26 | self.DEFAULT_HIGHWAY_TAGS)
27 |
28 | def perform_query(self) -> OSMRoadNetworkOverpassResult:
29 | """ Perform an OSM Overpass API Request! """
30 | return OSMRoadNetworkOverpassResult(self._perform_query())
31 |
32 | @property
33 | def _query_str(self) -> str:
34 | # NOTE: OSMNX uses instead
35 | # ["highway"!~"abandoned|bridleway|bus_guideway|construction|corridor|cycleway|elevator|escalator|footway|path|pedestrian|planned|platform|proposed|raceway|service|steps|track"]
36 |
37 | return f"""
38 | [out:{self._format}]
39 | [timeout:{self._timeout}];
40 | (way["highway"]
41 | ["area"!~"yes"]
42 | ["access"!~"private"]
43 | ["highway"~"{'|'.join(self._highway_tags)}"]
44 | ["motor_vehicle"!~"no"]
45 | ["motorcar"!~"no"]
46 | ["service"!~"alley|driveway|emergency_access|parking|parking_aisle|private"]
47 | (poly:'{self._poly_query_str}');
48 | >;
49 | );
50 | out;
51 | """
52 |
53 |
54 | class OSMRoadNetworkOverpassResult(OSMOverpassResult):
55 | """ Container class for OSM Overpass result data for (hopefully) drivable road networks """
56 |
57 | def to_network(self) -> OSMNetwork:
58 | """ Convert result data to an OSM network """
59 | with tempfile.TemporaryDirectory() as td:
60 | tmp_file = pathlib.Path(td, 'tmp')
61 | this_format = self.format
62 | if 'json' in this_format:
63 | tmp_file = tmp_file.with_suffix('.json')
64 | elif 'xml' in this_format:
65 | tmp_file = tmp_file.with_suffix('.osm')
66 | else:
67 | raise ValueError(
68 | 'Unknown format: \'%s\'! Cannot convert to OSM network.' % this_format)
69 |
70 | # Save the response content to file
71 | tmp_file.write_bytes(self._result.content)
72 |
73 | # Load the network!
74 | n = OSMNetwork.from_file(tmp_file)
75 |
76 | return n
77 |
--------------------------------------------------------------------------------
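
A sketch of the intended end-to-end use of the two classes above; the bounding box is illustrative and the call needs network access to the Overpass endpoint:

    from rsc.osm.overpass_api.road_network import OSMRoadNetworkOverpassQuery

    query = OSMRoadNetworkOverpassQuery(format='xml')
    query.set_poly_from_bbox(40.00, -105.30, 40.05, -105.25)
    network = query.perform_query().to_network()   # OSMNetwork instance

--------------------------------------------------------------------------------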
/rsc/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jdalrym2/road_surface_classifier/e922314786a19bfcf55c60db5316dba86ee3b917/rsc/train/__init__.py
--------------------------------------------------------------------------------
/rsc/train/color_jitter_nohuesat.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from typing import Any, Dict, List, Optional, Tuple, Union, cast
3 |
4 | from torch import Tensor
5 |
6 | from kornia.augmentation import random_generator as rg
7 | from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D
8 | from kornia.enhance import (adjust_brightness_accumulative,
9 | adjust_contrast_with_mean_subtraction)
10 |
11 |
12 | class ColorJitterNoHueSat(IntensityAugmentationBase2D):
13 | r"""Apply a random transformation to the brightness and contrast of a tensor image, leaving hue and saturation unchanged (e.g. for non-color channels such as NIR).
14 |
15 | This implementation aligns with PIL. Hence, the output is close to TorchVision.
16 |
17 |
18 |
19 | Args:
20 | p: probability of applying the transformation.
21 | brightness: The brightness factor to apply.
22 | contrast: The contrast factor to apply.
23 | silence_instantiation_warning: if True, silence the warning at instantiation.
24 | same_on_batch: apply the same transformation across the batch.
25 | keepdim: whether to keep the output shape the same as input (True) or broadcast it
26 | to the batch form (False).
27 | Shape:
28 | - Input: :math:`(C, H, W)` or :math:`(B, C, H, W)`, Optional: :math:`(B, 3, 3)`
29 | - Output: :math:`(B, C, H, W)`
30 | """
31 |
32 | def __init__(
33 | self,
34 | brightness: Union[Tensor, float, Tuple[float, float],
35 | List[float]] = 0.0,
36 | contrast: Union[Tensor, float, Tuple[float, float], List[float]] = 0.0,
37 | same_on_batch: bool = False,
38 | p: float = 1.0,
39 | keepdim: bool = False,
40 | silence_instantiation_warning: bool = False,
41 | ) -> None:
42 | super().__init__(p=p,
43 | same_on_batch=same_on_batch,
44 | keepdim=keepdim)
45 |
46 | if not silence_instantiation_warning:
47 | warnings.warn(
48 | "`ColorJitter` is now following Torchvision implementation. Old "
49 | "behavior can be retrieved by instantiating `ColorJiggle`.",
50 | category=DeprecationWarning,
51 | )
52 |
53 | self.brightness = brightness
54 | self.contrast = contrast
55 | self._param_generator = cast(
56 | rg.ColorJitterGenerator,
57 | rg.ColorJitterGenerator(brightness, contrast))
58 |
59 | def apply_transform(self,
60 | input: Tensor,
61 | params: Dict[str, Tensor],
62 | flags: Dict[str, Any],
63 | transform: Optional[Tensor] = None) -> Tensor:
64 |
65 | transforms = [
66 | lambda img: adjust_brightness_accumulative(
67 | img, params["brightness_factor"]),
68 | lambda img: adjust_contrast_with_mean_subtraction(
69 | img, params["contrast_factor"])
70 | ]
71 |
72 | jittered = input
73 | for idx in (0, 1):
74 | t = transforms[idx]
75 | jittered = t(jittered)
76 |
77 | return jittered
78 |
--------------------------------------------------------------------------------
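
A usage sketch for the single-channel case this class targets (e.g. a NIR band); the jitter ranges are illustrative:

    import torch

    from rsc.train.color_jitter_nohuesat import ColorJitterNoHueSat

    jitter = ColorJitterNoHueSat(brightness=0.1, contrast=0.1,
                                 silence_instantiation_warning=True)
    nir = torch.rand(8, 1, 224, 224)   # (B, C, H, W), values in [0, 1]
    out = jitter(nir)                  # brightness/contrast jitter only

--------------------------------------------------------------------------------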
/rsc/train/data_augmentation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | from typing import Tuple
4 | import torch
5 | import torch.nn as nn
6 | import kornia
7 |
8 | from .color_jitter_nohuesat import ColorJitterNoHueSat
9 |
10 |
11 | class DataAugmentation(nn.Module):
12 |
13 | def __init__(self, has_nir: bool=True):
14 | super().__init__()
15 |
16 | # Are we augmenting NIR as well?
17 | self.has_nir = has_nir
18 |
19 | # Random positional transformations
20 | self.transform_flip = nn.Sequential(
21 | kornia.augmentation.RandomHorizontalFlip(p=0.5),
22 | kornia.augmentation.RandomVerticalFlip(p=0.5))
23 |
24 | # Random offset of mask to reinforce proper segmentation
25 | # when labels may be inaccurate
26 | self.transform_offset = nn.Sequential(
27 | kornia.augmentation.RandomAffine(degrees=(-15, 15),
28 | translate=(0.0625, 0.0625),
29 | p=0.5))
30 |
31 | # Transform RGB
32 | self.transform_color = nn.Sequential(
33 | kornia.augmentation.RandomPlasmaBrightness(roughness=(0.1, 0.5),
34 | intensity=(0.1, 0.3)),
35 | kornia.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1))
36 |
37 | # Transform NIR
38 | self.transform_nir = nn.Sequential(ColorJitterNoHueSat(0.1, 0.1))
39 |
40 | @torch.no_grad()
41 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
42 |
43 | # Apply position transform to image + masks
44 | x_aug = self.transform_flip(x)
45 |
46 | # Break down what all of the channels are
47 | mask_c = slice(4, 5) if self.has_nir else slice(3, 4)
48 | probmask_c = slice(5, 6) if self.has_nir else slice(4, 5)
49 |
50 | # RGB color transformation
51 | im_aug = self.transform_color(x_aug[:, 0:3, ...])
52 |
53 | if self.has_nir:
54 | # NIR color transformation
55 | im_nir_aug = self.transform_nir(x_aug[:, 3:4, ...])
56 |
57 | # Combine NIR with RGB image
58 | im_aug = torch.concat((im_aug, im_nir_aug), dim=1)
59 |
60 | # Apply offset transform to mask
61 | mask_aug = self.transform_offset(x_aug[:, mask_c, ...])
62 |
63 | # Combine into color image + location mask
64 | im_aug = torch.concat((im_aug, mask_aug), dim=1)
65 |
66 | # Extract probmask for training
67 | pm_aug = x_aug[:, probmask_c, :, :]
68 |
69 | return im_aug, pm_aug
70 |
--------------------------------------------------------------------------------
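
A shape sketch of the augmentation contract above: a 6-channel batch (RGB + NIR + mask + probmask) is split into the 5-channel model input and the 1-channel probmask target:

    import torch

    from rsc.train.data_augmentation import DataAugmentation

    aug = DataAugmentation(has_nir=True)
    x = torch.rand(4, 6, 224, 224)     # RGB + NIR + mask + probmask
    im_aug, pm_aug = aug(x)
    print(im_aug.shape)                # torch.Size([4, 5, 224, 224])
    print(pm_aug.shape)                # torch.Size([4, 1, 224, 224])

--------------------------------------------------------------------------------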
/rsc/train/dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | from typing import Any
4 | import pandas as pd
5 | import numpy as np
6 | import PIL.Image
7 |
8 | import torch
9 | from torch.utils.data import Dataset
10 |
11 | from torchvision.transforms import CenterCrop
12 |
13 | import torch.multiprocessing
14 |
15 | # Prevents file descriptor errors when doing multiprocessing fetches
16 | torch.multiprocessing.set_sharing_strategy('file_system')
17 |
18 |
19 | class RoadSurfaceDataset(Dataset):
20 |
21 | def __init__(self,
22 | df_path,
23 | transform,
24 | chip_size=224,
25 | n_channels=4,
26 | limit=-1):
27 |
28 | # Load dataframe, interpret length
29 | self.df = pd.read_csv(df_path)
30 | self.n_idxs = len(self.df) if limit == -1 else min(limit, len(self.df))
31 |
32 | # Number of channels in the image
33 | self.n_channels = n_channels
34 |
35 | # Number of classes
36 | self.n_classes = self.df['class_num'].max() + 1
37 |
38 | # Chip size
39 | self.set_chip_size(chip_size)
40 |
41 | # Transformation object
42 | self.transform = transform
43 |
44 | def set_chip_size(self, chip_size: int):
45 | self.cc = CenterCrop(chip_size)
46 |
47 | def __len__(self):
48 | return self.n_idxs
49 |
50 | def __getitem__(self, idx):
51 |
52 | # Get row
53 | row: Any = self.df.iloc[idx] # type: ignore
54 |
55 | # Image
56 | with PIL.Image.open(row.chip_path) as pim:
57 | im = np.array(self.cc(pim))
58 | tn = im.shape[2] if im.ndim == 3 else 1
59 | if tn > 1:
60 | # Adds support for loading imagery with more
61 | # channels than n_channels: warn if so, then
62 | # grab only the required number of channels
63 | if tn > self.n_channels:
64 | print(
65 | f'WARNING: Got {tn} channel image but model '
66 | f'only has {self.n_channels} dimensions!'
67 | )
68 | im = im[..., :self.n_channels]
69 | tn = im.shape[2]
70 |
71 | # Mask
72 | with PIL.Image.open(row.mask_path) as pmask:
73 | mask = np.array(self.cc(pmask))
74 | if mask.ndim == 2:
75 | mask = mask[:, :, np.newaxis]
76 | tn = mask.shape[2]
77 | if tn > 1:
78 | mask = mask[:, :, 0][:, :, np.newaxis]
79 |
80 | # Prob mask
81 | with PIL.Image.open(row.probmask_path) as pmask:
82 | probmask = np.array(self.cc(pmask))
83 | if probmask.ndim == 2:
84 | probmask = probmask[:, :, np.newaxis]
85 | tn = probmask.shape[2]
86 | if tn > 1:
87 | probmask = probmask[:, :, 0][:, :, np.newaxis]
88 |
89 | # Label (add one to number of classes to account for obscurations)
90 | lbl = [0] * (self.n_classes + 1)
91 |
92 | # Get class idx
93 | c = int(row.class_num)
94 |
95 | # Compute obscuration estimate
96 | obsc = 1 - (mask * (probmask > 127)).sum() / mask.sum()
97 |
98 | # Set labels accordingly
99 | # NOTE: no longer fuzzy
100 | lbl[c] = 1
101 | lbl[-1] = obsc
102 |
103 | # Concat image and masks for output
104 | x = self.transform(np.concatenate((im, mask, probmask), axis=2))
105 |
106 | return x, torch.Tensor(lbl)
107 |
--------------------------------------------------------------------------------
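
A wiring sketch for the dataset. 'train.csv' is hypothetical and must carry the chip_path, mask_path, probmask_path, and class_num columns that __getitem__ reads; PreProcess (defined later in this package) is one suitable transform:

    from torch.utils.data import DataLoader

    from rsc.train.dataset import RoadSurfaceDataset
    from rsc.train.preprocess import PreProcess

    ds = RoadSurfaceDataset('train.csv', transform=PreProcess(), n_channels=4)
    dl = DataLoader(ds, batch_size=16, shuffle=True, num_workers=4)
    x, lbl = next(iter(dl))   # x: (16, 6, 224, 224); lbl: (16, n_classes + 1)

--------------------------------------------------------------------------------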
/rsc/train/mcnn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as functional
6 | import torchvision
7 |
8 | # NOTE: Patch for MaxUnpool2d needs to be applied
9 | # when exporting to ONNX
10 | from torch.nn import MaxUnpool2d
11 | #from patch import MaxUnpool2d
12 |
13 |
14 | class Freezable:
15 |
16 | def parameters(self):
17 | # This is overridden
18 | return []
19 |
20 | def freeze(self):
21 | self._set_freeze(True)
22 |
23 | def unfreeze(self):
24 | self._set_freeze(False)
25 |
26 | def _set_freeze(self, v):
27 | req_grad = not bool(v)
28 | for param in self.parameters():
29 | param.requires_grad = req_grad
30 |
31 |
32 | class FreezableModule(nn.Module, Freezable):
33 | pass
34 |
35 |
36 | class FreezableModuleList(nn.ModuleList, Freezable):
37 | pass
38 |
39 |
40 | class FreezableAdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, Freezable):
41 | pass
42 |
43 |
44 | class FreezableLinear(nn.Linear, Freezable):
45 | pass
46 |
47 |
48 | class DecoderBlock(FreezableModule):
49 |
50 | def __init__(self, in_channels, out_channels, kernel_size=3):
51 | super().__init__()
52 |
53 | self.upconv = nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
54 | self.conv1 = nn.Conv2d(out_channels * 2, out_channels, kernel_size)
55 | self.relu = nn.ReLU(inplace=True)
56 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size)
57 |
58 | def forward(self, x, encoder_feats=None):
59 |
60 | x = self.upconv(x)
61 |
62 | enc_ftrs = self.crop(encoder_feats, x)
63 | x = torch.concat((x, enc_ftrs), dim=1)
64 |
65 | x = self.conv1(x)
66 | x = self.relu(x)
67 | x = self.conv2(x)
68 | x = self.relu(x)
69 | return x
70 |
71 | def crop(self, enc_ftrs, x):
72 | _, _, h, w = x.shape
73 | enc_ftrs = torchvision.transforms.CenterCrop([h, w])(enc_ftrs)
74 | return enc_ftrs
75 |
76 |
77 | class Resnet18Encoder(FreezableModule):
78 |
79 | def __init__(self, in_channels=3):
80 | super().__init__()
81 |
82 | # Get Resnet18 w/ default weights
83 | self.rnet = torchvision.models.resnet18(
84 | weights=torchvision.models.ResNet18_Weights.DEFAULT)
85 |
86 | if in_channels != 3:
87 | self.rnet.conv1 = nn.Conv2d(in_channels,
88 | 64,
89 | kernel_size=(7, 7),
90 | stride=(2, 2),
91 | padding=(3, 3),
92 | bias=False)
93 |
94 | self.rnet.maxpool.return_indices = True
95 |
96 | # Delete final avg pool / linear layer
97 | del self.rnet.avgpool
98 | del self.rnet.fc
99 |
100 | # Features for cross-connections to decoder
101 | self.feats = [
102 | torch.Tensor([]),
103 | torch.Tensor([]),
104 | torch.Tensor([]),
105 | torch.Tensor([])
106 | ]
107 | self.maxpool_idxs = None
108 |
109 | def forward(self, x):
110 | # Initial stem: conv -> bn -> relu -> maxpool
111 | x = self.rnet.conv1(x)
112 | x = self.rnet.bn1(x)
113 | x = self.rnet.relu(x)
114 | x, self.maxpool_idxs = self.rnet.maxpool(x)
115 |
116 | # Now the main layers
117 | self.feats[0] = x
118 | x = self.rnet.layer1(x)
119 |
120 | self.feats[1] = x
121 | x = self.rnet.layer2(x)
122 |
123 | self.feats[2] = x
124 | x = self.rnet.layer3(x)
125 |
126 | self.feats[3] = x
127 | x = self.rnet.layer4(x)
128 |
129 | return x
130 |
131 |
132 | class Resnet18Decoder(FreezableModuleList):
133 |
134 | def __init__(self):
135 | super().__init__()
136 |
137 | self.layer_1 = DecoderBlock(512, 256)
138 | self.layer_2 = DecoderBlock(256, 128)
139 | self.layer_3 = DecoderBlock(128, 64)
140 | self.layer_4 = DecoderBlock(64, 64, kernel_size=1)
141 |
142 | self.unpool = MaxUnpool2d(kernel_size=3, stride=2, padding=1)
143 | self.final_1 = nn.Conv2d(64, 2, 3)
144 | self.relu = nn.ReLU(inplace=True)
145 | self.final_2 = nn.Conv2d(2, 2, 3)
146 | self.softmax = nn.Softmax(dim=1)
147 |
148 | def forward(self, x, encoder_feats, maxpool_idxs):
149 |
150 | x = self.layer_1(x, encoder_feats[3])
151 | x = self.layer_2(x, encoder_feats[2])
152 | x = self.layer_3(x, encoder_feats[1])
153 | x = self.layer_4(x, encoder_feats[0])
154 | x = self.unpool(x, maxpool_idxs, output_size=(112, 112))
155 |
156 | x = self.final_1(x)
157 | x = self.relu(x)
158 | x = self.final_2(x)
159 | x = functional.interpolate(x, (224, 224))
160 | x = self.softmax(x)
161 |
162 | return x
163 |
164 |
165 | class MaskCNN(nn.Module):
166 |
167 | def __init__(self, num_classes=2, num_channels=5):
168 | super().__init__()
169 |
170 | # Segmentation stage
171 | self.encoder = Resnet18Encoder(in_channels=num_channels)
172 | self.decoder = Resnet18Decoder()
173 |
174 | # Classification stage
175 | self.encoder2 = Resnet18Encoder(in_channels=num_channels)
176 | self.avgpool = FreezableAdaptiveAvgPool2d(output_size=(1, 1))
177 | self.fc = FreezableLinear(512, num_classes, bias=True)
178 |
179 | def forward(self, x):
180 |
181 | # Image -> Features
182 | y = self.encoder(x)
183 |
184 | # Features -> Mask
185 | y = self.decoder(y, self.encoder.feats, self.encoder.maxpool_idxs)
186 | y = y[:, 0:1, ...]
187 |
188 | # Adjust the image using the predicted mask; only the image channels are kept
189 | # NOTE: - RGB: image channels are (0, 1, 2), input mask is (3,)
190 | # - RGB + NIR: image channels are (0, 1, 2, 3), input mask is (4,)
191 | # - In both cases the image is channels :-1 and the mask is channel -1
192 | # NOTE: the paper recommends multiplication, but in this case
193 | # concat-ing the segmentation mask seems to produce better results
194 | # x = torch.multiply(x[:, :-1, ...], y)
195 | x = torch.concat((x[:, :-1, ...], y), dim=1)
196 |
197 | # Updated Image -> Features
198 | x = self.encoder2(x)
199 | x = self.avgpool(x)
200 | x = torch.flatten(x, 1)
201 | z = self.fc(x)
202 |
203 | return y, z
--------------------------------------------------------------------------------
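
A quick shape check of the two-stage forward pass above, assuming a 5-channel input (RGB + NIR + location mask) and an illustrative class count:

    import torch

    from rsc.train.mcnn import MaskCNN

    model = MaskCNN(num_classes=3, num_channels=5)
    x = torch.rand(2, 5, 224, 224)
    y, z = model(x)
    print(y.shape)   # torch.Size([2, 1, 224, 224]) - segmentation mask
    print(z.shape)   # torch.Size([2, 3])           - classification logits

--------------------------------------------------------------------------------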
/rsc/train/mcnn_loss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as functional
7 |
8 | from .rsc_hxe_loss import RSCHXELoss
9 |
10 |
11 | class MCNNLoss(nn.Module):
12 | """ Combined MaskCNN loss function """
13 |
14 | def __init__(self, top_lv_map, class_weights, seg_k, ob_k):
15 | super().__init__()
16 |
17 | # Inputs
18 | self.top_lv_map = torch.IntTensor(top_lv_map).cuda()
19 | self.class_weights = torch.Tensor(class_weights).float().cuda()
20 | self.seg_k = seg_k
21 | self.ob_k = ob_k
22 |
23 | # Loss functions
24 | self.hxe_loss = RSCHXELoss(self.top_lv_map, self.class_weights)
25 | self.o_loss = torch.nn.BCEWithLogitsLoss(reduction='mean')
26 |
27 | self.seg_loss = 0.
28 | self.cl_loss = 0.
29 | self.ob_loss = 0.
30 | self.stage = 0
31 |
32 | def forward(self, y_hat: torch.Tensor, y: torch.Tensor,
33 | z_hat: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
34 | """
35 | Compute the loss given the model's segmentation and classification results.
36 |
37 | Segmentation loss is DICE + BCE loss, inspired by:
38 | https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch/notebook#BCE-Dice-Loss
39 |
40 | Args:
41 | y_hat (torch.Tensor): Segmentation model result
42 | y (torch.Tensor): Segmentation truth
43 | z_hat (torch.Tensor): Classification model result
44 | z (torch.Tensor): Classification truth
45 |
46 | Returns:
47 | torch.Tensor: Computed loss for the model
48 | """
49 |
50 | # Loss 1: Dice BCE loss for segmentation (seg_loss)
51 | if self.stage in (0, 1):
52 | intersection = (y_hat * y).sum()
53 | smooth = 1 # Dice BCE smoothing parameter, hardcoded to 1 does just fine
54 | dice_loss = 1 - (2. * intersection + smooth) / (
55 | y_hat.sum() + y.sum() + smooth)
56 | binary_ce = functional.binary_cross_entropy(y_hat,
57 | y,
58 | reduction='mean')
59 | self.seg_loss = binary_ce + dice_loss
60 | else:
61 | self.seg_loss = 0.
62 |
63 | # Loss 2: BCE Loss for estimating road obscuration (ob_loss)
64 | # This operates on the last logit produced by the model
65 | if self.stage in (0, 2):
66 | self.ob_loss = self.o_loss(z_hat[:, -1], z[:, -1])
67 | else:
68 | self.ob_loss = 0.
69 |
70 | # Loss 3: Cross entropy for classification result (cl_loss)
71 | # This operates on all but the last model logit
72 | if self.stage in (0, 2):
73 | self.cl_loss = self.hxe_loss(z_hat[:, :-1], z[:, :-1])
74 | else:
75 | self.cl_loss = 0.
76 |
77 | # Combine and return the combined loss
78 | loss = self.seg_k * self.seg_loss + self.ob_k * self.ob_loss + self.cl_loss
79 | return loss
--------------------------------------------------------------------------------
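
An invocation sketch with illustrative shapes, hierarchy, and weights. Note that MCNNLoss moves its tensors to CUDA in __init__, so a GPU is assumed:

    import torch

    from rsc.train.mcnn_loss import MCNNLoss

    loss_fn = MCNNLoss(top_lv_map=[0, 0, 1, 1],        # 4 classes -> 2 top-level
                       class_weights=[1., 1., 1., 1.],
                       seg_k=1.0, ob_k=1.0)
    y_hat = torch.rand(2, 1, 224, 224).cuda()          # predicted mask in [0, 1]
    y = (torch.rand(2, 1, 224, 224) > 0.5).float().cuda()
    z_hat = torch.randn(2, 5).cuda()                   # 4 class logits + obscuration
    z = torch.tensor([[1., 0., 0., 0., 0.2],
                      [0., 0., 1., 0., 0.7]]).cuda()
    loss = loss_fn(y_hat, y, z_hat, z)

--------------------------------------------------------------------------------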
/rsc/train/patch.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | import torch
4 | import torch.nn.functional as F
5 | from torch.autograd import Function
6 | from torch.nn.modules.pooling import _MaxUnpoolNd
7 | from torch.nn.modules.utils import _pair
8 |
9 |
10 | class MaxUnpool2dop(Function):
11 | """We warp the `torch.nn.functional.max_unpool2d`
12 | with an extra `symbolic` method, which is needed while exporting to ONNX.
13 | Users should not call this function directly.
14 | """
15 |
16 | @staticmethod
17 | def forward(ctx, input, indices, kernel_size, stride, padding,
18 | output_size):
19 | """Forward function of MaxUnpool2dop.
20 | Args:
21 | input (Tensor): Tensor needed to upsample.
22 | indices (Tensor): Indices output of the previous MaxPool.
23 | kernel_size (Tuple): Size of the max pooling window.
24 | stride (Tuple): Stride of the max pooling window.
25 | padding (Tuple): Padding that was added to the input.
26 | output_size (List or Tuple): The shape of output tensor.
27 | Returns:
28 | Tensor: Output tensor.
29 | """
30 | return F.max_unpool2d(input, indices, kernel_size, stride, padding,
31 | output_size)
32 |
33 | @staticmethod
34 | def symbolic(g, input, indices, kernel_size, stride, padding, output_size):
35 | # get shape
36 | input_shape = g.op('Shape', input)
37 | const_0 = g.op('Constant', value_t=torch.tensor(0))
38 | const_1 = g.op('Constant', value_t=torch.tensor(1))
39 | batch_size = g.op('Gather', input_shape, const_0, axis_i=0)
40 | channel = g.op('Gather', input_shape, const_1, axis_i=0)
41 |
42 | # height = (height - 1) * stride + kernel_size
43 | height = g.op('Gather',
44 | input_shape,
45 | g.op('Constant', value_t=torch.tensor(2)),
46 | axis_i=0)
47 | height = g.op('Sub', height, const_1)
48 | height = g.op('Mul', height,
49 | g.op('Constant', value_t=torch.tensor(stride[1])))
50 | height = g.op('Add', height,
51 | g.op('Constant', value_t=torch.tensor(kernel_size[1])))
52 |
53 | # width = (width - 1) * stride + kernel_size
54 | width = g.op('Gather',
55 | input_shape,
56 | g.op('Constant', value_t=torch.tensor(3)),
57 | axis_i=0)
58 | width = g.op('Sub', width, const_1)
59 | width = g.op('Mul', width,
60 | g.op('Constant', value_t=torch.tensor(stride[0])))
61 | width = g.op('Add', width,
62 | g.op('Constant', value_t=torch.tensor(kernel_size[0])))
63 |
64 | # step of channel
65 | channel_step = g.op('Mul', height, width)
66 | # step of batch
67 | batch_step = g.op('Mul', channel_step, channel)
68 |
69 | # channel offset
70 | range_channel = g.op('Range', const_0, channel, const_1)
71 | range_channel = g.op(
72 | 'Reshape', range_channel,
73 | g.op('Constant', value_t=torch.tensor([1, -1, 1, 1])))
74 | range_channel = g.op('Mul', range_channel, channel_step)
75 | range_channel = g.op('Cast', range_channel, to_i=7) # 7 is int64
76 |
77 | # batch offset
78 | range_batch = g.op('Range', const_0, batch_size, const_1)
79 | range_batch = g.op(
80 | 'Reshape', range_batch,
81 | g.op('Constant', value_t=torch.tensor([-1, 1, 1, 1])))
82 | range_batch = g.op('Mul', range_batch, batch_step)
83 | range_batch = g.op('Cast', range_batch, to_i=7) # 7 is int64
84 |
85 | # update indices
86 | indices = g.op('Add', indices, range_channel)
87 | indices = g.op('Add', indices, range_batch)
88 |
89 | return g.op('MaxUnpool',
90 | input,
91 | indices,
92 | kernel_shape_i=kernel_size,
93 | strides_i=stride)
94 |
95 |
96 | class MaxUnpool2d(_MaxUnpoolNd):
97 | """This module is modified from Pytorch `MaxUnpool2d` module.
98 | Args:
99 | kernel_size (int or tuple): Size of the max pooling window.
100 | stride (int or tuple): Stride of the max pooling window.
101 | Default: None (It is set to `kernel_size` by default).
102 | padding (int or tuple): Padding that is added to the input.
103 | Default: 0.
104 | """
105 |
106 | def __init__(self, kernel_size, stride=None, padding=0):
107 | super(MaxUnpool2d, self).__init__()
108 | self.kernel_size = _pair(kernel_size)
109 | self.stride = _pair(stride or kernel_size)
110 | self.padding = _pair(padding)
111 |
112 | def forward(self, input, indices, output_size=None):
113 | """Forward function of MaxUnpool2d.
114 | Args:
115 | input (Tensor): Tensor needed to upsample.
116 | indices (Tensor): Indices output of the previous MaxPool.
117 | output_size (List or Tuple): The shape of output tensor.
118 | Default: None.
119 | Returns:
120 | Tensor: Output tensor.
121 | """
122 | return MaxUnpool2dop.apply(input, indices, self.kernel_size,
123 | self.stride, self.padding, output_size)
--------------------------------------------------------------------------------
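
A drop-in usage sketch for the patched MaxUnpool2d above; it keeps the call signature of torch.nn.MaxUnpool2d:

    import torch

    from rsc.train.patch import MaxUnpool2d

    pool = torch.nn.MaxPool2d(2, stride=2, return_indices=True)
    unpool = MaxUnpool2d(2, stride=2)

    x = torch.rand(1, 1, 4, 4)
    y, idx = pool(x)                           # (1, 1, 2, 2) + indices
    out = unpool(y, idx, output_size=x.shape)  # back to (1, 1, 4, 4)

--------------------------------------------------------------------------------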
/rsc/train/plmcnn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import torch
5 | import optuna
6 | import pytorch_lightning as pl
7 |
8 | from .data_augmentation import DataAugmentation
9 |
10 | from .mcnn import MaskCNN
11 | from .mcnn_loss import MCNNLoss
12 |
13 |
14 | class PLMaskCNN(pl.LightningModule):
15 | """ PyTorch Lightning wrapper for training the
16 | RSC MaskCNN """
17 |
18 | def __init__(self,
19 | labels,
20 | top_level_map,
21 | weights,
22 | learning_rate: float = 1e-4,
23 | seg_k: float = 1.0,
24 | ob_k: float = 1.0,
25 | nc: int = 4):
26 | super().__init__()
27 |
28 | # Hyperparameters
29 | self.learning_rate = learning_rate
30 | self.seg_k = seg_k
31 | self.ob_k = ob_k
32 | self.labels = labels
33 | self.top_level_map = top_level_map
34 | self.weights = weights
35 | self.save_hyperparameters()
36 |
37 | # Number of channels
38 | self.nc = nc
39 |
40 | # Optuna trial (use set_optuna_trial)
41 | self.trial: optuna.trial.Trial | None = None
42 |
43 | # Stateful min val_loss_cl
44 | self.min_val_loss = float('inf')
45 | self.min_val_loss_im = float('inf')
46 | self.min_val_loss_cl = float('inf')
47 | self.min_val_loss_ob = float('inf')
48 |
49 | # Stateful learning rate
50 | self._lr = learning_rate
51 |
52 | self.transform = DataAugmentation(has_nir=(nc == 4))
53 | self.loss = MCNNLoss(self.top_level_map, self.weights,
54 | self.seg_k, self.ob_k)
55 |
56 | # Labels: add 1 for "obscuration"
57 | # Channels: add 1 for "mask" (e.g. RGB + mask, RGB + NIR + mask)
58 | self.model = MaskCNN(num_classes=len(self.labels) + 1,
59 | num_channels=nc + 1)
60 |
61 | def set_optuna_trial(self, trial: optuna.trial.Trial | None):
62 | self.trial = trial
63 |
64 | def set_stage(self, v, lr):
65 | first_stage = (self.model.encoder, self.model.decoder)
66 | second_stage = (self.model.encoder2, self.model.avgpool, self.model.fc)
67 |
68 | # Freeze / unfreeze components based on stage
69 | if v == 0:
70 | [e.unfreeze() for e in first_stage]
71 | [e.unfreeze() for e in second_stage]
72 | elif v == 1:
73 | [e.unfreeze() for e in first_stage]
74 | [e.freeze() for e in second_stage]
75 | elif v == 2:
76 | [e.freeze() for e in first_stage]
77 | [e.unfreeze() for e in second_stage]
78 | else:
79 | raise ValueError(f'Unknown v: {repr(v):s}')
80 |
81 | # Loss function requires stage
82 | self.loss.stage = v
83 |
84 | # Learning rate depends on stage
85 | self._lr = lr
86 |
87 | # Set stage
88 | self.stage = v
89 |
90 | def forward(self, x):
91 | return self.model(x)
92 |
93 | def training_step(self, batch, batch_idx):
94 | x, z = batch
95 | x, xpm = self.transform(x)
96 | y_hat, z_hat = self.forward(x)
97 | loss = self.loss(y_hat, xpm, z_hat, z)
98 | self.log_dict(
99 | {
100 | 'train_loss_im': self.loss.seg_loss,
101 | 'train_loss_cl': self.loss.cl_loss,
102 | 'train_loss_ob': self.loss.ob_loss,
103 | 'train_loss': loss,
104 | },
105 | on_step=False,
106 | on_epoch=True)
107 | return loss
108 |
109 | def validation_step(self, batch, batch_idx):
110 | x, z = batch
111 |
112 | img_mask_c = slice(0, self.nc + 1)
113 | probmask_c = slice(self.nc + 1, self.nc + 2)
114 |
115 | # Create probmask, image + mask (be careful, order matters!)
116 | y = x[:, probmask_c, :, :]
117 | x = x[:, img_mask_c, :, :]
118 |
119 | y_hat, z_hat = self.forward(x)
120 | loss = self.loss(y_hat, y, z_hat, z)
121 | self.log_dict(
122 | {
123 | 'val_loss_im': self.loss.seg_loss,
124 | 'val_loss_cl': self.loss.cl_loss,
125 | 'val_loss_ob': self.loss.ob_loss,
126 | 'val_loss': loss,
127 | },
128 | on_step=False,
129 | on_epoch=True)
130 | return loss
131 |
132 | def on_validation_epoch_end(self):
133 |
134 | metrics = self.trainer.logged_metrics
135 | this_val_loss = float(metrics['val_loss'])
136 | this_val_loss_im = float(metrics['val_loss_im'])
137 | this_val_loss_cl = float(metrics['val_loss_cl'])
138 | this_val_loss_ob = float(metrics['val_loss_ob'])
139 |
140 | if this_val_loss < self.min_val_loss:
141 | self.min_val_loss = this_val_loss
142 | self.min_val_loss_im = this_val_loss_im
143 | self.min_val_loss_cl = this_val_loss_cl
144 | self.min_val_loss_ob = this_val_loss_ob
145 |
146 | self.log_dict({
147 | 'min_val_loss_im': self.min_val_loss_im,
148 | 'min_val_loss_cl': self.min_val_loss_cl,
149 | 'min_val_loss_ob': self.min_val_loss_ob,
150 | 'min_val_loss': self.min_val_loss,
151 | })
152 |
153 | if self.trial is not None:
154 | self.trial.report(this_val_loss_cl, self.current_epoch)
155 | if self.trial.should_prune():
156 | raise optuna.exceptions.TrialPruned()
157 |
158 | def configure_optimizers(self):
159 | return torch.optim.Adam(self.parameters(), lr=self._lr)
160 |
--------------------------------------------------------------------------------
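
A hypothetical training harness for PLMaskCNN. The labels, top-level map, weights, and data loaders are placeholders (see dataset.py above), and the underlying MCNNLoss assumes CUDA:

    import pytorch_lightning as pl

    from rsc.train.plmcnn import PLMaskCNN

    model = PLMaskCNN(labels=['paved', 'unpaved'],   # placeholder labels
                      top_level_map=[0, 1],
                      weights=[1.0, 1.0])
    model.set_stage(0, lr=1e-4)   # stage 0: train all components jointly

    trainer = pl.Trainer(max_epochs=10, accelerator='gpu', devices=1)
    # trainer.fit(model, train_dl, val_dl)   # DataLoaders as sketched above

--------------------------------------------------------------------------------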
/rsc/train/preprocess.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import torch
5 | import torch.nn as nn
6 | import kornia
7 |
8 |
9 | class PreProcess(nn.Module):
10 |
11 | def __init__(self):
12 | super().__init__()
13 | self.resize = kornia.augmentation.Resize(224, keepdim=True)
14 |
15 | @torch.no_grad()
16 | def forward(self, x) -> torch.Tensor:
17 |
18 | # NOTE: applies to full 6-channel input (image + mask + probmask)
19 |
20 | # Convert to tensor
21 | x = kornia.utils.image_to_tensor(x, keepdim=True).float()
22 | x = self.resize(x)
23 |
24 | # Normalize between 0 and 1
25 | x = torch.divide(x, 255.)
26 |
27 | return x
--------------------------------------------------------------------------------
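
A sketch of the tensor contract above: an (H, W, C) uint8 array becomes a float tensor, resized to 224 and scaled to [0, 1]:

    import numpy as np

    from rsc.train.preprocess import PreProcess

    pp = PreProcess()
    arr = np.random.randint(0, 256, (256, 256, 6), dtype=np.uint8)
    x = pp(arr)   # torch.Tensor, shape (6, 224, 224), values in [0, 1]

--------------------------------------------------------------------------------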
/rsc/train/rsc_hxe_loss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """ Hierarchical Cross Entropy Loss """
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class RSCHXELoss(nn.Module):
11 | """ Hierarchical Cross Entropy Loss (tailored to RSC model to handle obscuration estimates)"""
12 |
13 | def __init__(self, lv_b_a: torch.Tensor, lv_b_w: torch.Tensor):
14 | """
15 | Init the loss function, describing a hierarchy
16 |
17 | (A1) (A2) <---- "Level A" | i.e. top-level classes; e.g. 'paved' vs 'unpaved' |
18 | // \\ // \\
19 | (B1) (B2) (B3) (B4) <---- "Level B" | i.e. prediction classes |
20 | | e.g. 'asphalt' vs 'concrete' vs 'dirt' vs 'gravel' |
21 |
22 | Args:
23 | lv_b_a (torch.Tensor): Mapping for each label to index of top-level (level A) hierarchy index
24 | e.g. from above example: (0, 0, 1, 1)
25 | lv_b_w (torch.Tensor): Weights for prediction (level B) classes, independent of hierarchy
26 | """
27 |
28 | super().__init__()
29 |
30 | # Sanity check
31 | assert len(lv_b_a) == len(lv_b_w)
32 | device = lv_b_a.device
33 |
34 | # Persist Level B -> Level A mapping
35 | self.lv_b_a = lv_b_a
36 |
37 | # List of indices between Lv A elements and Lv B elements
38 | # NOTE: assumes we have a lv B node attached to every lv A node
39 | self.lv_a_idx = [torch.where(lv_b_a == e)[0] for e in range(int(lv_b_a.max() + 1))]
40 |
41 | # Proportion of Level A elements among level B
42 | lv_a_p = torch.tensor([(1 / len(lv_b_w) / lv_b_w[idx]).sum() for idx in self.lv_a_idx])
43 |
44 | # Level A weights
45 | self.lv_a_w = ((1 / len(lv_a_p)) / lv_a_p).to(device)
46 |
47 | # Level B weights for each node in level A
48 | self.lv_b_w_a = torch.concat([(1 / lv_b_w[idx]).sum() * lv_b_w[idx] / len(idx) for idx in self.lv_a_idx])
49 |
50 | def forward(self, logits: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
51 | """
52 | Compute the loss given the model's classification results.
53 |
54 | Args:
55 | logits (torch.Tensor): Classification model output
56 | truth (torch.Tensor): Truth classification (one-hot encoded) Shape: (N, num_lv_b)
57 |
58 | Returns:
59 | torch.Tensor: Computed loss for the model
60 | """
61 |
62 | # Compute the logarithm of the predicted probabilities
63 | log_y_pred = F.log_softmax(logits, dim=1)
64 |
65 | # Initialize the loss
66 | loss_lv_b = torch.Tensor((0,)).to(logits.device)
67 | loss_lv_a = torch.Tensor((0,)).to(logits.device)
68 |
69 | # Iterate over the categories
70 | for i in range(truth.shape[1]):
71 | # Compute the weight for the category
72 | w = self.lv_b_w_a[i]
73 |
74 | # Compute the weight for the upper layer
75 | w2 = self.lv_a_w[self.lv_b_a[i]]
76 |
77 | # Compute the cross entropy loss for the category
78 | cross_entropy = -torch.sum(truth[:, i] * log_y_pred[:, i])
79 |
80 | # Add the weighted cross entropy loss to the total loss
81 | loss_lv_b += w * w2 * cross_entropy
82 |
83 | loss_lv_b /= truth.shape[1]
84 |
85 | # Compute level A logits
86 | logits_lv_a = torch.stack([torch.sum(logits[:, e], 1) for e in self.lv_a_idx], -1)
87 | log_y_pred_lv_a = F.log_softmax(logits_lv_a, dim=1)
88 | truth_lv_a = torch.stack([torch.sum(truth[:, e], 1) for e in self.lv_a_idx], -1)
89 |
90 | # Iterate over the categories
91 | for i in range(truth_lv_a.shape[1]):
92 | # Compute the weight for the category
93 | w = self.lv_a_w[i]
94 |
95 | # Compute the cross entropy loss for the category
96 | cross_entropy = -torch.sum(truth_lv_a[:, i] * log_y_pred_lv_a[:, i])
97 |
98 | # Add the weighted cross entropy loss to the total loss
99 | loss_lv_a += w * cross_entropy
100 |
101 | loss_lv_a /= truth_lv_a.shape[1]
102 |
103 | return (loss_lv_b + loss_lv_a) / truth.shape[0]
--------------------------------------------------------------------------------
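
A standalone sketch of the loss using the 2x2 hierarchy from the docstring above (paved: asphalt/concrete; unpaved: dirt/gravel), run on CPU with uniform weights:

    import torch
    import torch.nn.functional as F

    from rsc.train.rsc_hxe_loss import RSCHXELoss

    lv_b_a = torch.IntTensor([0, 0, 1, 1])   # level-B -> level-A index map
    lv_b_w = torch.ones(4)                   # uniform level-B class weights
    loss_fn = RSCHXELoss(lv_b_a, lv_b_w)

    logits = torch.randn(8, 4)
    truth = F.one_hot(torch.randint(0, 4, (8,)), num_classes=4).float()
    loss = loss_fn(logits, truth)

--------------------------------------------------------------------------------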
/scripts/evaluate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # type: ignore
4 | """ Road Surface Classifier Inference Example
5 | This is a quick script to perform inference with our trained model, so we can see how well it works!
6 | """
7 |
8 | # %%
9 |
10 | import pathlib
11 | from tqdm import tqdm
12 |
13 | import torch
14 | from torch.utils.data import DataLoader
15 |
16 | from rsc.inference.mass_inference_dataset import MassInferenceDataset
17 | from rsc.train.preprocess import PreProcess
18 |
19 | # Input model and checkpoint paths (checkpoint contains the weights for inference)
20 | # Since these files are so large, they are not in source control.
21 | # Reach out to me if you'd like them.
22 | ckpt_path = pathlib.Path(
23 | '/data/road_surface_classifier/results/20230107_042006Z/model-0-epoch=10-val_loss=0.39906.ckpt'
24 | )
25 | assert ckpt_path.exists()
26 | results_name = ckpt_path.parent.name
27 | ds_path = pathlib.Path('/nfs/taranis/naip/BOULDER_COUNTY_NAIP_2019.sqlite3')
28 | assert ds_path.exists()
29 |
30 | # %% Setup processing chain for inference
31 | from rsc.train.plmcnn import PLMaskCNN
32 |
33 | # Load model and set to eval mode
34 | model = PLMaskCNN.load_from_checkpoint(ckpt_path)
35 | model.eval()
36 |
37 | # Get label array (it's built into model)
38 | labels = model.__dict__.get('labels')
39 |
40 | # Import dataset
41 | preprocess = PreProcess()
42 | val_ds = MassInferenceDataset(ds_path, transform=preprocess)
43 | batch_size = 64
44 | val_dl = DataLoader(val_ds,
45 | num_workers=16,
46 | batch_size=batch_size,
47 | shuffle=True)
48 |
49 | #%% Iterate over the dataloader, and fetch predictions
50 |
51 | # Output data array
52 | output = []
53 |
54 | for i, (osm_id, x) in tqdm(enumerate(iter(val_dl)),
55 | total=len(val_ds) // batch_size):
56 |
57 | # Get size of this batch
58 | sz = x.shape[0]
59 |
60 | # Predict with the model
61 | _, y_pred = model(x)
62 |
63 | # Compute argmax
64 | y_pred_am = torch.argmax(y_pred[:, 0:-1], dim=1)
65 | y_pred_am = y_pred_am.detach().numpy() # type: ignore
66 |
67 | # Compute softmax
68 | y_pred_sm = torch.softmax(y_pred, dim=1)
69 | y_pred_sm = y_pred_sm.detach().numpy() # type: ignore
70 |
71 | for j in range(sz):
72 | output.append((int(osm_id[j]), labels[y_pred_am[j]], *y_pred_sm[j, :]))
73 |
74 | #%% Save output as CSV
75 | import pandas as pd
76 |
77 | columns = ['osm_id', 'pred_label', *['pred_%s' % label for label in labels]]
78 |
79 | df = pd.DataFrame(output, columns=columns).set_index('osm_id')
80 | df.to_csv(
81 | f'/nfs/taranis/naip/BOULDER_COUNTY_NAIP_2019_results_{results_name}.csv')
82 |
83 | #%% Load in CSV output, and merge with original dataset SQLite file
84 | import sqlite3
85 | import pandas as pd
86 |
87 | df = pd.read_csv(
88 | f'/nfs/taranis/naip/BOULDER_COUNTY_NAIP_2019_results_{results_name}.csv'
89 | ).set_index('osm_id')
90 |
91 | with sqlite3.connect(
92 | 'file:/nfs/taranis/naip/BOULDER_COUNTY_NAIP_2019.sqlite3?mode=ro',
93 | uri=True) as con:
94 | df2 = pd.read_sql('SELECT * FROM features;', con).set_index('osm_id')
95 |
96 | df = df.join(df2)
97 | df
98 |
99 | #%% Save results into GPKG file for import to QGIS
100 | from osgeo import gdal, ogr, osr
101 |
102 | gdal.UseExceptions()
103 | ogr.UseExceptions()
104 |
105 | # Create SRS (EPSG:4326: WGS-84 decimal degrees)
106 | srs = osr.SpatialReference()
107 | srs.ImportFromEPSG(4326)
108 |
109 | driver: ogr.Driver = ogr.GetDriverByName('GPKG')
110 | ds: ogr.DataSource = driver.CreateDataSource(
111 | f'/data/road_surface_classifier/BOULDER_COUNTY_NAIP_2019_results_{results_name}.gpkg'
112 | )
113 | layer: ogr.Layer = ds.CreateLayer(
114 | 'data', srs=srs, geom_type=ogr.wkbLineString) # type: ignore
115 |
116 | osm_id_field = ogr.FieldDefn('osm_id', ogr.OFTInteger64)
117 | highway_field = ogr.FieldDefn('highway', ogr.OFTString)
118 | surface_true_field = ogr.FieldDefn('surface_t', ogr.OFTString)
119 | surface_pred_field = ogr.FieldDefn('surface_p', ogr.OFTString)
120 | paved_conf = ogr.FieldDefn('paved_c', ogr.OFTReal)
121 | unpaved_conf = ogr.FieldDefn('unpaved_c', ogr.OFTReal)
122 | obsc_conf = ogr.FieldDefn('obsc_c', ogr.OFTReal)
123 |
124 | layer.CreateField(osm_id_field)
125 | layer.CreateField(highway_field)
126 | layer.CreateField(surface_true_field)
127 | layer.CreateField(surface_pred_field)
128 | layer.CreateField(paved_conf)
129 | layer.CreateField(unpaved_conf)
130 | layer.CreateField(obsc_conf)
131 |
132 | feature_defn = layer.GetLayerDefn()
133 |
134 | for osm_id, row in df.iterrows():
135 |
136 | poly = ogr.CreateGeometryFromWkt(row['wkt'])
137 |
138 | feat = ogr.Feature(feature_defn)
139 |
140 | feat.SetGeometry(poly)
141 | feat.SetField('osm_id', row.name)
142 | feat.SetField('highway', row['highway_tag'])
143 | feat.SetField('surface_t', row['surface_tag'])
144 | feat.SetField('surface_p', row['pred_label'])
145 | feat.SetField('paved_c', row['pred_paved'])
146 | feat.SetField('unpaved_c', row['pred_unpaved'])
147 | feat.SetField('obsc_c', row['pred_Obscured'])
148 |
149 | layer.CreateFeature(feat)
150 | poly = None
151 | feat = None
152 |
153 | layer = None # type: ignore
154 | ds = None # type: ignore
155 |
156 | #%% Convert the CSV file into another QGIS GPKG, but pruned to ways with
157 | # known surface types, such that we can evaluate the model's accuracy easier
158 | import sqlite3
159 | import pandas as pd
160 | from osgeo import gdal, ogr, osr
161 |
162 | gdal.UseExceptions()
163 | ogr.UseExceptions()
164 |
165 | # Read CSV, merge with dataset, extract the labels we want
166 | df = pd.read_csv(
167 | f'/nfs/taranis/naip/BOULDER_COUNTY_NAIP_2019_results_{results_name}.csv'
168 | ).set_index('osm_id')
169 | with sqlite3.connect(
170 | 'file:/nfs/taranis/naip/BOULDER_COUNTY_NAIP_2019.sqlite3?mode=ro',
171 | uri=True) as con:
172 | df2 = pd.read_sql('SELECT * FROM features;', con).set_index('osm_id')
173 | df = df.join(df2)
174 | df = df[['wkt', 'pred_label', 'surface_tag', 'highway_tag', 'pred_Obscured']]
175 |
176 | # Mapping of OSM surface types to simple paved / unpaved
177 | class_map = {
178 | 'asphalt': 'paved',
179 | 'bricks': 'paved',
180 | 'compacted': 'unpaved',
181 | 'concrete': 'paved',
182 | 'concrete:plates': 'paved',
183 | 'dirt': 'unpaved',
184 | 'gravel': 'unpaved',
185 | 'ground': 'unpaved',
186 | 'paved': 'paved',
187 | 'paving_stones': 'paved',
188 | 'unpaved': 'unpaved',
189 | }
190 |
191 | # Trim dataset + determine "correctness"
192 | df = df[df['surface_tag'] != '']
193 | df['surface_tag'] = df['surface_tag'].apply(class_map.get)
194 | df['correct'] = df['surface_tag'] == df['pred_label']
195 |
196 | # Create SRS (EPSG:4326: WGS-84 decimal degrees)
197 | srs = osr.SpatialReference()
198 | srs.ImportFromEPSG(4326)
199 |
200 | # Put into GPKG format for QGIS
201 | driver: ogr.Driver = ogr.GetDriverByName('GPKG')
202 | ds: ogr.DataSource = driver.CreateDataSource(
203 | f'/data/road_surface_classifier/BOULDER_COUNTY_NAIP_2019_results_eval_{results_name}.gpkg'
204 | )
205 | layer: ogr.Layer = ds.CreateLayer('data', srs=srs, geom_type=ogr.wkbLineString)
206 |
207 | osm_id_field = ogr.FieldDefn('osm_id', ogr.OFTInteger64)
208 | highway_field = ogr.FieldDefn('highway', ogr.OFTString)
209 | surface_true_field = ogr.FieldDefn('surface_t', ogr.OFTString)
210 | surface_pred_field = ogr.FieldDefn('surface_p', ogr.OFTString)
211 | correct_field = ogr.FieldDefn('correct', ogr.OFTString)
212 | obsc_field = ogr.FieldDefn('obsc', ogr.OFTReal)
213 |
214 | layer.CreateField(osm_id_field)
215 | layer.CreateField(highway_field)
216 | layer.CreateField(surface_true_field)
217 | layer.CreateField(surface_pred_field)
218 | layer.CreateField(correct_field)
219 | layer.CreateField(obsc_field)
220 |
221 | feature_defn = layer.GetLayerDefn()
222 |
223 | for osm_id, row in df.iterrows():
224 |
225 | poly = ogr.CreateGeometryFromWkt(row['wkt'])
226 |
227 | feat = ogr.Feature(feature_defn)
228 |
229 | feat.SetGeometry(poly)
230 | feat.SetField('osm_id', row.name)
231 | feat.SetField('highway', row['highway_tag'])
232 | feat.SetField('surface_t', row['surface_tag'])
233 | feat.SetField('surface_p', row['pred_label'])
234 | feat.SetField('correct', str(row['correct']))
235 | feat.SetField('obsc', row['pred_Obscured'])
236 |
237 | layer.CreateFeature(feat)
238 | poly = None
239 | feat = None
240 |
241 | layer = None # type: ignore
242 | ds = None # type: ignore
--------------------------------------------------------------------------------
/scripts/parse_features.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """ Helper script used to inspect all available surface types within our dataset """
4 | import pickle
5 | import pathlib
6 |
7 | from tqdm import tqdm
8 | from osgeo import gdal
9 |
10 | if __name__ == '__main__':
11 |
12 | feature_data = []
13 | feature_data_pickle_path = pathlib.Path(
14 | '/data/road_surface_classifier/feature_data.pkl')
15 |
16 | if not feature_data_pickle_path.exists():
17 | # Load road surface data from file
18 | ds = gdal.OpenEx(
19 | '/data/gis/us_road_surface/us_w_road_surface_filtered.gpkg')
20 | layer = ds.GetLayer()
21 | feature_count = layer.GetFeatureCount()
22 |
23 | # Extract features we care about
24 | print('Extracting features...')
25 | for idx in tqdm(range(feature_count)):
26 | feat = layer.GetNextFeature()
27 | wkt_str = feat.GetGeometryRef().ExportToWkt()
28 | osm_id = feat.GetField(0)
29 | highway = feat.GetField(1)
30 | surface = feat.GetField(2)
31 | feature_data.append((osm_id, wkt_str, highway, surface))
32 | layer = None
33 | ds = None
34 |
35 | # Pickle the data so we don't have to process again
36 | with open(feature_data_pickle_path, 'wb') as f:
37 | pickle.dump(feature_data, f)
38 |
39 | else:
40 | # Load the data from file
41 | with open(feature_data_pickle_path, 'rb') as f:
42 | feature_data = pickle.load(f)
43 |
44 | # Some data cleanup
45 | # Keep all surface types with at least 1000 labels
46 | surface_types = list(set([e[3].lower() for e in feature_data]))
47 | count_dict = {k: 0 for k in surface_types}
48 | for _, _, _, surface in feature_data:
49 | count_dict[surface.lower()] += 1
50 | del_keys = [k for k, v in count_dict.items() if v < 1000]
51 | for k in del_keys:
52 |     count_dict.pop(k)
53 | 
54 | # Print what survives so the surface types can actually be inspected
55 | print(sorted(count_dict.items(), key=lambda kv: kv[1], reverse=True))
--------------------------------------------------------------------------------
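The tag-counting cleanup above can be written more compactly with collections.Counter; a sketch on toy tuples, with a threshold of 2 standing in for the real 1000:

from collections import Counter

# Toy stand-in for the (osm_id, wkt, highway, surface) tuples loaded above
feature_data = [(1, 'LINESTRING (...)', 'residential', 'Asphalt'),
                (2, 'LINESTRING (...)', 'track', 'gravel'),
                (3, 'LINESTRING (...)', 'primary', 'asphalt')]

counts = Counter(surface.lower() for *_, surface in feature_data)
common = {k: v for k, v in counts.items() if v >= 2}
print(common)  # {'asphalt': 2}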
/scripts/perform_query.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """ OSM Overpass query to find all drivable roads with a labeled 'surface' tag """
4 | from __future__ import annotations
5 | import pathlib
6 |
7 | from rsc.osm.overpass_api import OSMOverpassQuery, OSMOverpassResult
8 |
9 |
10 | class OSMCustomOverpassQuery(OSMOverpassQuery):
11 | """ Custom OSM Overpass API query for (hopefully) drivable road networks """
12 |
13 | __slots__ = ['_highway_tags']
14 |
15 | DEFAULT_HIGHWAY_TAGS = [
16 | 'motorway', 'motorway_link', 'motorway_junction', 'trunk',
17 | 'trunk_link', 'primary', 'primary_link', 'secondary', 'secondary_link',
18 | 'tertiary', 'tertiary_link', 'unclassified', 'residential'
19 | ]
20 |
21 | def __init__(self, **kwargs):
22 | super().__init__(**kwargs)
23 | self._highway_tags = kwargs.get('highway_tags',
24 | self.DEFAULT_HIGHWAY_TAGS)
25 |
26 | def perform_query(self) -> OSMCustomOverpassResult:
27 | """ Perform an OSM Overpass API Request! """
28 | return OSMCustomOverpassResult(self._perform_query())
29 |
30 | @property
31 | def _query_str(self) -> str:
32 | # NOTE: OSMnx instead uses the following highway filter:
33 | # ["highway"!~"abandoned|bridleway|bus_guideway|construction|corridor|cycleway|elevator|escalator|footway|path|pedestrian|planned|platform|proposed|raceway|service|steps|track"]
34 |
35 | return f"""
36 | [out:{self._format}]
37 | [timeout:{self._timeout}]
38 | [maxsize:2147483648];
39 | (way["highway"]
40 | ["area"!~"yes"]
41 | ["access"!~"private"]
42 | ["highway"~"{'|'.join(self._highway_tags)}"]
43 | ["motor_vehicle"!~"no"]
44 | ["motorcar"!~"no"]
45 | ["surface"!~""]
46 | ["service"!~"alley|driveway|emergency_access|parking|parking_aisle|private"]
47 | (poly:'{self._poly_query_str}');
48 | >;
49 | );
50 | out;
51 | """
52 |
53 |
54 | class OSMCustomOverpassResult(OSMOverpassResult):
55 | """ Container class for OSM Overpass result data for (hopefully) drivable road networks """
56 |
57 | def to_file(self) -> None:
58 | """ Convert result data to an OSM network """
59 | tmp_file = pathlib.Path('/data/gis/result')
60 | this_format = self.format
61 | if 'json' in this_format:
62 | tmp_file = tmp_file.with_suffix('.json')
63 | elif 'xml' in this_format:
64 | tmp_file = tmp_file.with_suffix('.osm')
65 | else:
66 | raise ValueError(
67 | f'Unknown format: {this_format!r}! Cannot convert to OSM network.')
68 |
69 | # Save the response content to file
70 | tmp_file.write_bytes(self._result.content)
71 |
72 |
73 | if __name__ == '__main__':
74 |
75 | # Setup custom query to local interpreter
76 | q = OSMCustomOverpassQuery(format='xml', timeout=24 * 60 * 60)
77 | q.set_endpoint('http://localhost:12345/api/interpreter')
78 |
79 | # Use rough USA bounds for query
80 | with open('gis/us_wkt.txt', 'r') as f:
81 | us_wkt = f.read()
82 | q.set_poly_from_wkt(us_wkt)
83 |
84 | # Perform query and save! This will take a long time.
85 | print('Performing query...')
86 | result = q.perform_query()
87 | print('Saving to file...')
88 | result.to_file()
--------------------------------------------------------------------------------
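For reference, a pared-down version of the Overpass QL that _query_str renders, runnable as-is against the local interpreter. The polygon here is a made-up patch of Boulder County (Overpass poly strings are space-separated 'lat lon' pairs):

import requests

query = '''
[out:json][timeout:60];
(way["highway"~"residential|unclassified"]
    ["surface"]
    (poly:"40.0 -105.3 40.1 -105.3 40.1 -105.2 40.0 -105.2");
 >;
);
out;
'''
r = requests.post('http://localhost:12345/api/interpreter',
                  data={'data': query})
r.raise_for_status()
print(len(r.json()['elements']), 'elements returned')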
/scripts/pre_inference.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """ Pre-mass inference script. Loop through images and find chips of drivable ways. """
4 | #%% Imports and inputs
5 | import sqlite3
6 |
7 | import pathlib
8 | from osgeo import gdal, ogr, osr
9 | import pandas as pd
10 |
11 | from rsc.common.utils import imread_geometry, imread_dims, map_to_pix
12 | from rsc.osm.overpass_api.road_network import OSMRoadNetworkOverpassQuery
13 |
14 | gdal.UseExceptions()
15 | ogr.UseExceptions()
16 |
17 | OVERPASS_INTERPRETER_URL = "http://localhost:12345/api/interpreter"
18 | TEST_CASE_PATH = pathlib.Path("/nfs/taranis/naip/BOULDER_COUNTY_NAIP_2019")
19 | assert TEST_CASE_PATH.is_dir()
20 | OUTPUT_PATH = pathlib.Path(
21 | '/nfs/taranis/naip/BOULDER_COUNTY_NAIP_2019.sqlite3')
22 | assert not OUTPUT_PATH.exists()
23 |
24 | # Setup custom query to local interpreter
25 | q = OSMRoadNetworkOverpassQuery(format='xml', timeout=24 * 60 * 60)
26 | q.set_endpoint(OVERPASS_INTERPRETER_URL)
27 |
28 | # Get EPSG:4326 SRS
29 | srs = osr.SpatialReference()
30 | srs.ImportFromEPSG(4326)
31 | srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
32 |
33 | #%% Loop through imagery and find drivable ways
34 |
35 | # Rows to convert to Pandas dataframe
36 | df_rows = []
37 |
38 | for img_path in TEST_CASE_PATH.glob('*.mrf'):
39 | print('Processing %s...' % img_path.name)
40 |
41 | # Get geometry for image
42 | img_geom = imread_geometry(img_path, wgs84=True)
43 | h, w = imread_dims(img_path)
44 |
45 | # Perform overpass query for geometry
46 | q.set_poly_from_wkt(img_geom)
47 |
48 | # Perform the query
49 | result = q.perform_query()
50 |
51 | # Convert to road network
52 | network = result.to_network()
53 |
54 | # If no roads show up, skip for now
55 | if not network.num_ways:
56 | continue
57 |
58 | # Compute image-specific coordinate transformation
59 | ds: gdal.Dataset = gdal.Open(str(img_path), gdal.GA_ReadOnly)
60 | im_h, im_w = ds.RasterYSize, ds.RasterXSize
61 | g_xform = ds.GetGeoTransform()
62 | srs_ds: osr.SpatialReference = ds.GetSpatialRef()
63 | srs_ds.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
64 | ds = None # type: ignore
65 | c_xform = osr.CoordinateTransformation(srs, srs_ds)
66 |
67 | # Loop through all available ways
68 | for way in network.get_ways():
69 | osm_id = way.id
70 |
71 | # Get nodes
72 | nodes = [network.get_node_by_id(id)
73 | for id in way.nodes] # type: ignore
74 |
75 | # Get linestring geometry
76 | linestr = ogr.Geometry(ogr.wkbLineString)
77 | for node in nodes: linestr.AddPoint_2D(node.lon, node.lat)
78 | wkt = linestr.ExportToWkt()
79 |
80 | # Get midpoint lon, lat
81 | m_lon, m_lat = linestr.GetPoint_2D(linestr.GetPointCount() // 2)
82 |
83 | # Compute center point in image spatial coordinate system
84 | pt = ogr.Geometry(ogr.wkbPoint)
85 | pt.AddPoint_2D(m_lon, m_lat)
86 | pt.Transform(c_xform)
87 |
88 | # Compute upper-left corner for image chip
89 | x, y = [e[0].item() for e in map_to_pix(g_xform, pt.GetX(), pt.GetY())]
90 | if (x < 0 or x >= im_w) or (y < 0 or y >= im_h):
91 | # Center of way falls off the image; skipping.
92 | # NOTE: this OSM_ID may appear in another image
93 | continue
94 | x1, y1 = x - 128, y - 128
95 | x2, y2 = x1 + 256, y1 + 256
96 |
97 | # Handle boundaries
98 | if x1 < 0:
99 | x1, x2 = 0, 256
100 | if y1 < 0:
101 | y1, y2 = 0, 256
102 | if x2 >= w:
103 | x1, x2 = w - 256 - 1, w - 1
104 | if y2 >= h:
105 | y1, y2 = h - 256 - 1, h - 1
106 |
107 | # Fetch tags
108 | highway_tag = way.tags.get('highway', '')
109 | surface_tag = way.tags.get('surface', '')
110 |
111 | df_rows.append([
112 | osm_id, img_path.name, wkt, m_lon, m_lat, x1, y1, x2, y2,
113 | highway_tag, surface_tag
114 | ])
115 |
116 | #%% Export data to SQLite3 database
117 |
118 | df = pd.DataFrame(df_rows,
119 | columns=[
120 | 'osm_id', 'img', 'wkt', 'm_lon', 'm_lat', 'x1', 'y1',
121 | 'x2', 'y2', 'highway_tag', 'surface_tag'
122 | ]).set_index('osm_id')
123 | with sqlite3.connect(OUTPUT_PATH) as con:
124 | df.to_sql('features', con)
125 |
--------------------------------------------------------------------------------
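The chip-window arithmetic above is easy to get wrong at the edges, so here it is restated as a pure function. This is a sketch that mirrors the loop's boundary handling exactly, including its w - 1 / h - 1 right and bottom edges:

def clamp_window(x, y, w, h, size=256):
    # Center a size x size window on pixel (x, y)
    x1, y1 = x - size // 2, y - size // 2
    x2, y2 = x1 + size, y1 + size
    # Shift the window back inside the w x h image
    if x1 < 0:
        x1, x2 = 0, size
    if y1 < 0:
        y1, y2 = 0, size
    if x2 >= w:
        x1, x2 = w - size - 1, w - 1
    if y2 >= h:
        y1, y2 = h - size - 1, h - 1
    return x1, y1, x2, y2

# A center near the left edge clamps the window to start at column 0
assert clamp_window(10, 5000, 7000, 7000) == (0, 4872, 256, 5128)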
/scripts/run_overpass_api.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # https://hub.docker.com/r/wiktorn/overpass-api
4 | # http://localhost:12345/api/interpreter
5 | docker run \
6 | -e OVERPASS_META=yes \
7 | -e OVERPASS_MODE=init \
8 | -e OVERPASS_PLANET_URL=file:///data/gis/us-latest.osm.bz2 \
9 | -e OVERPASS_RULES_LOAD=10 \
10 | -e OVERPASS_SPACE=55000000000 \
11 | -e OVERPASS_MAX_TIMEOUT=86400 \
12 | -v /data/gis:/data/gis \
13 | -v /data/gis/overpass_db:/db \
14 | -p 12345:80 \
15 | -i --name overpass_usa wiktorn/overpass-api:latest
16 |
--------------------------------------------------------------------------------
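Once the container is serving, a quick sanity check of the interpreter (a sketch; node(1) just requests a single element by id, so any well-formed JSON reply means the endpoint is answering, even if that node isn't in the US extract):

import requests

r = requests.post('http://localhost:12345/api/interpreter',
                  data={'data': '[out:json][timeout:5];node(1);out;'})
print(r.status_code, r.json().get('elements', []))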
/scripts/sample_augmentations.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 |
7 | from torch.utils.data import DataLoader
8 |
9 |
10 | def sample_augmentations(ds, transform, grid=5):
11 |
12 | dl = DataLoader(ds, num_workers=1, batch_size=1, shuffle=True)
13 | for x, _ in dl:
14 | break
15 | x = x[0, ...] # type: ignore
16 | x_np = x.numpy()
17 | im_np = np.moveaxis(x_np[0:4, :, :] * 255., 0, -1).astype(np.uint8)
18 | mask_np = (x_np[4, :, :] * 255.).astype(np.uint8)  # 2-D so imshow accepts it
19 | pmask_np = (x_np[5, :, :] * 255.).astype(np.uint8)
20 |
21 | fig, ax = plt.subplots(grid,
22 | 4,
23 | sharex=True,
24 | sharey=True,
25 | figsize=(3 * 4, 3 * grid))
26 | ax = ax.flatten() # type: ignore
27 |
28 | ax[0].imshow(im_np[..., (0, 1, 2)]) # type: ignore
29 | ax[1].imshow(im_np[..., (3, 0, 1)]) # type: ignore
30 | ax[2].imshow(mask_np)
31 | ax[3].imshow(pmask_np)
32 |
33 | # Plot images
34 | for idx in range(4, len(ax), 4):
35 | # Do an augmentation
36 | im_aug, pm_aug = transform(x)
37 | m_aug = im_aug[:, 4:5, ...]
38 | im_aug = im_aug[:, 0:4, ...]
39 |
40 | im_aug_np = (np.moveaxis(im_aug[0, ...].numpy(), 0, -1) * 255.).astype(
41 | np.uint8)
42 | # Index [0, 0, ...] to get 2-D arrays, which imshow accepts
43 | m_aug_np = (m_aug[0, 0, ...].numpy() * 255.).astype(np.uint8)
44 | pm_aug_np = (pm_aug[0, 0, ...].numpy() * 255.).astype(np.uint8)
45 | 
46 |
47 | # Plot it
48 | ax[idx].imshow(im_aug_np[..., (0, 1, 2)]) # type: ignore
49 | ax[idx + 1].imshow(im_aug_np[..., (3, 0, 1)]) # type: ignore
50 | ax[idx + 2].imshow(m_aug_np)
51 | ax[idx + 3].imshow(pm_aug_np)
52 |
53 | for _ax in ax:
54 | _ax.get_xaxis().set_visible(False)
55 | _ax.get_yaxis().set_visible(False)
56 | fig.tight_layout()
57 |
58 | plt.show()
59 |
60 |
61 | if __name__ == '__main__':
62 |
63 | from rsc.train.dataset import RoadSurfaceDataset
64 | from rsc.train.preprocess import PreProcess
65 | from rsc.train.data_augmentation import DataAugmentation
66 |
67 | train_ds = RoadSurfaceDataset(
68 | '/data/road_surface_classifier/dataset/dataset_train.csv',
69 | transform=PreProcess(),
70 | limit=50)
71 |
72 | sample_augmentations(train_ds, DataAugmentation())
--------------------------------------------------------------------------------
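For clarity, the 6-channel tensor layout the slicing above assumes, as a self-checking snippet (the meaning of channel 5 is an assumption based on the 'pmask' naming):

import torch

# Assumed layout: 0-3 = NAIP R, G, B, NIR; 4 = road mask; 5 = 'pmask'
x = torch.rand(6, 224, 224)
rgb_nir, mask, pmask = x[0:4], x[4:5], x[5:6]
assert rgb_nir.shape == (4, 224, 224)
assert mask.shape == pmask.shape == (1, 224, 224)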
/scripts/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | import sys
6 | from typing import Optional
7 |
8 | sys.path.append(
9 | os.getcwd()) # TODO: this is silly, would be fixed with pip install
10 |
11 | # Set AWS profile (for use in MLFlow)
12 | os.environ["AWS_PROFILE"] = 'truenas'
13 | os.environ["MLFLOW_S3_ENDPOINT_URL"] = 'http://truenas.local:9807'
14 |
15 | import pathlib
16 | from datetime import datetime
17 |
18 | import pandas as pd
19 |
20 | import torch
21 | from torch.utils.data import DataLoader
22 | import pytorch_lightning as pl
23 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, StochasticWeightAveraging
24 | from pytorch_lightning.loggers import MLFlowLogger
25 |
26 | from rsc.train.plmcnn import PLMaskCNN
27 | from rsc.train.preprocess import PreProcess
28 | from rsc.train.dataset import RoadSurfaceDataset
29 | from rsc.artifacts.confusion_matrix_handler import ConfusionMatrixHandler
30 |
31 | # Recommended for CUDA
32 | torch.set_float32_matmul_precision('medium')
33 |
34 | # Whether or not to include the NIR channel
35 | # since the input dataset includes it
36 | INCLUDE_NIR = True
37 |
38 | # Quick test: train a single epoch on a smaller-than-normal
39 | # set of images to check that everything works
40 | QUICK_TEST = False
41 |
42 | if __name__ == '__main__':
43 |
44 | # Create directory for results
45 | results_dir = pathlib.Path(
46 | '/data/road_surface_classifier/results').resolve()
47 | assert results_dir.is_dir()
48 | now = datetime.utcnow().strftime(
49 | '%Y%m%d_%H%M%SZ') # timestamp string used for traceability
50 | save_dir = results_dir / now
51 | save_dir.mkdir(parents=False, exist_ok=False)
52 |
53 | # Log epoch + validation loss to CSV
54 | #logger = CSVLogger(str(save_dir))
55 |
56 | # Labels and weights
57 | weights_df = pd.read_csv(
58 | '/data/road_surface_classifier/dataset_multiclass/class_weights.csv')
59 | labels = list(weights_df['class_name'])
60 | top_level_map = list(weights_df['top_level'])
61 | class_weights = list(weights_df['weight'])
62 |
63 | # Get dataset
64 | chip_size = 224
65 | preprocess = PreProcess()
66 | train_ds = RoadSurfaceDataset(
67 | '/data/road_surface_classifier/dataset_multiclass/dataset_train.csv',
68 | transform=preprocess,
69 | chip_size=chip_size,
70 | limit=-1 if not QUICK_TEST else 500,
71 | n_channels=4 if INCLUDE_NIR else 3)
72 | val_ds = RoadSurfaceDataset(
73 | '/data/road_surface_classifier/dataset_multiclass/dataset_val.csv',
74 | transform=preprocess,
75 | chip_size=chip_size,
76 | limit=-1 if not QUICK_TEST else 500,
77 | n_channels=4 if INCLUDE_NIR else 3)
78 |
79 | # Create data loaders.
80 | batch_size = 64
81 | train_dl = DataLoader(train_ds,
82 | num_workers=16,
83 | batch_size=batch_size,
84 | shuffle=True)
85 | val_dl = DataLoader(val_ds, num_workers=16, batch_size=batch_size)
86 |
87 | # Model
88 | learning_rate = 0.00040984645874638675
89 | model = PLMaskCNN(weights=class_weights,
90 | labels=labels,
91 | nc=4 if INCLUDE_NIR else 3,
92 | top_level_map=top_level_map,
93 | learning_rate=learning_rate,
94 | seg_k=0.7,
95 | ob_k=0.9)
96 |
97 |
98 | # Save model to results directory
99 | torch.save(model, save_dir / 'model.pth')
100 |
101 | # Attempt deserialization (b/c I've had problems with it before)
102 | try:
103 | torch.load(save_dir / 'model.pth')
104 | except Exception:
105 | import traceback
106 | traceback.print_exc()
107 | raise AssertionError('Torch model failed to deserialize!')
108 |
109 | # Train model in stages
110 | best_model_path: Optional[str] = None
111 | mlflow_logger: Optional[MLFlowLogger] = None
112 | stage = 0
113 | model.set_stage(stage, learning_rate)
114 |
115 |
116 | # Logger
117 | mlflow_logger = MLFlowLogger(experiment_name='road_surface_classifier',
118 | run_name='run_%s_%d' % (now, stage),
119 | tracking_uri='http://truenas.local:9809')
120 |
121 | # Upload base model
122 | mlflow_logger.experiment.log_artifact(mlflow_logger.run_id,
123 | str(save_dir / 'model.pth'))
124 |
125 | # Save checkpoints (model states for later)
126 | checkpoint_callback = ModelCheckpoint(
127 | dirpath=str(save_dir),
128 | monitor='val_loss',
129 | save_top_k=3,
130 | filename='model-%d-{epoch:02d}-{val_loss:.5f}' % stage)
131 |
132 | # Setup early stopping based on validation loss
133 | early_stopping_callback = EarlyStopping(monitor='val_loss',
134 | mode='min',
135 | patience=10)
136 |
137 | # Stochastic Weight Averaging
138 | swa_callback = StochasticWeightAveraging(swa_lrs=0.6863423660749621)
139 |
140 | # Trainer
141 | trainer = pl.Trainer(accelerator='gpu',
142 | devices=1,
143 | max_epochs=1000 if not QUICK_TEST else 1,
144 | callbacks=[
145 | checkpoint_callback, early_stopping_callback,
146 | swa_callback
147 | ],
148 | logger=mlflow_logger)
149 |
150 | # Do the thing!
151 | trainer.fit(model, train_dataloaders=train_dl,
152 | val_dataloaders=val_dl) # type: ignore
153 |
154 | # Get best model path
155 | best_model_path = checkpoint_callback.best_model_path
156 |
157 | # Upload best model
158 | mlflow_logger.experiment.log_artifact(mlflow_logger.run_id,
159 | best_model_path)
160 |
161 | # Load model at best checkpoint for next stage
162 | del model # just to be safe
163 | model = PLMaskCNN.load_from_checkpoint(best_model_path)
164 |
165 | assert best_model_path is not None
166 | assert mlflow_logger is not None
167 |
168 | # Generate artifacts
169 | print('Generating artifacts...')
170 | model.eval()
171 |
172 | # Generate artifacts from model
173 | from rsc.artifacts import ArtifactGenerator
174 | from rsc.artifacts.confusion_matrix_handler import ConfusionMatrixHandler
175 | from rsc.artifacts.accuracy_obsc_handler import AccuracyObscHandler
176 | from rsc.artifacts.obsc_compare_handler import ObscCompareHandler
177 | from rsc.artifacts.samples_handler import SamplesHandler
178 | artifacts_dir = save_dir / 'artifacts'
179 | artifacts_dir.mkdir(parents=False, exist_ok=True)
180 | generator = ArtifactGenerator(artifacts_dir, model, val_dl)
181 | generator.add_handler(ConfusionMatrixHandler(simple=False))
182 | generator.add_handler(ConfusionMatrixHandler(simple=True))
183 | generator.add_handler(AccuracyObscHandler())
184 | generator.add_handler(ObscCompareHandler())
185 | generator.add_handler(SamplesHandler())
186 | generator.run(raise_on_error=False)
187 | mlflow_logger.experiment.log_artifacts(mlflow_logger.run_id,
188 | str(artifacts_dir))
189 |
--------------------------------------------------------------------------------
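To reuse a finished run, the best checkpoint can be reloaded the same way the script does between stages. A sketch with a placeholder path; Lightning renders the ModelCheckpoint template above to names like model-0-epoch=07-val_loss=0.12345.ckpt:

from rsc.train.plmcnn import PLMaskCNN

ckpt = ('/data/road_surface_classifier/results/'
        '20240101_000000Z/model-0-epoch=07-val_loss=0.12345.ckpt')
model = PLMaskCNN.load_from_checkpoint(ckpt)
model.eval()  # inference mode, as in the artifact-generation step above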
/scripts/train_optuna.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | import sys
6 | import traceback
7 | from typing import Optional
8 |
9 | sys.path.append(
10 | os.getcwd()) # TODO: this is silly, would be fixed with pip install
11 |
12 | # Set AWS profile (for use in MLFlow)
13 | os.environ["AWS_PROFILE"] = 'truenas'
14 | os.environ["MLFLOW_S3_ENDPOINT_URL"] = 'http://truenas:9807'
15 |
16 | import pathlib
17 | from datetime import datetime
18 |
19 | import pandas as pd
20 |
21 | import torch
22 | from torch.utils.data import DataLoader
23 | import pytorch_lightning as pl
24 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, StochasticWeightAveraging
25 | from pytorch_lightning.loggers import MLFlowLogger
26 |
27 | import optuna
28 |
29 | from rsc.train.plmcnn import PLMaskCNN
30 | from rsc.train.preprocess import PreProcess
31 | from rsc.train.dataset import RoadSurfaceDataset
32 |
33 | # Recommended for CUDA
34 | torch.set_float32_matmul_precision('medium')
35 |
36 | # Whether or not to include the NIR channel
37 | # since the input dataset includes it
38 | INCLUDE_NIR = True
39 |
40 | # Quick test: train on a smaller-than-normal
41 | # set of images to check that everything works
42 | QUICK_TEST = False
43 |
44 | if __name__ == '__main__':
45 |
46 | # Optuna config
47 | num_trials = 50
48 | now = datetime.utcnow().strftime(
49 | '%Y%m%d_%H%M%SZ') # timestamp string used for traceability
50 | study_name = 'study_%s' % now
51 |
52 | # Labels and weights
53 | weights_df = pd.read_csv(
54 | '/data/road_surface_classifier/dataset_multiclass/class_weights.csv')
55 | labels = list(weights_df['class_name'])
56 | top_level_map = list(weights_df['top_level'])
57 | class_weights = list(weights_df['weight'])
58 |
59 | # Get dataset
60 | preprocess = PreProcess()
61 | train_ds = RoadSurfaceDataset(
62 | '/data/road_surface_classifier/dataset_multiclass/dataset_train.csv',
63 | transform=preprocess,
64 | chip_size=224,
65 | limit=-1 if not QUICK_TEST else 1500,
66 | n_channels=4 if INCLUDE_NIR else 3)
67 | val_ds = RoadSurfaceDataset(
68 | '/data/road_surface_classifier/dataset_multiclass/dataset_val.csv',
69 | transform=preprocess,
70 | chip_size=224,
71 | limit=-1 if not QUICK_TEST else 500,
72 | n_channels=4 if INCLUDE_NIR else 3)
73 |
74 | def objective(trial: optuna.trial.Trial):
75 | global train_ds, val_ds
76 |
77 | # Create data loaders.
78 | batch_size = 64
79 | train_dl = DataLoader(train_ds,
80 | num_workers=16,
81 | batch_size=batch_size,
82 | shuffle=True)
83 | val_dl = DataLoader(val_ds, num_workers=16, batch_size=batch_size)
84 |
85 | # Hyperparameters
86 | chip_size = 224
87 | learning_rate = trial.suggest_float("learning_rate", 1e-6, 1e-3, log=True)
88 | swa_lrs = trial.suggest_float("swa_lrs", 1e-4, 1e0, log=True)
89 | seg_k = trial.suggest_float("seg_k", 0.1, 1, step=0.1)
90 | ob_k = trial.suggest_float("ob_k", 0.1, 1, step=0.1)
91 |
92 | # Set dataset parameters
93 | train_ds.set_chip_size(chip_size)
94 | val_ds.set_chip_size(chip_size)
95 |
96 | # Set model parameters
97 | model = PLMaskCNN(weights=class_weights,
98 | labels=labels,
99 | nc=4 if INCLUDE_NIR else 3,
100 | top_level_map=top_level_map,
101 | learning_rate=learning_rate,
102 | seg_k=seg_k,
103 | ob_k=ob_k)
104 |
105 | # Create directory for results
106 | results_dir = pathlib.Path(
107 | '/data/road_surface_classifier/results').resolve()
108 | assert results_dir.is_dir()
109 | now = datetime.utcnow().strftime(
110 | '%Y%m%d_%H%M%SZ') # timestamp string used for traceability
111 | save_dir = results_dir / now
112 | save_dir.mkdir(parents=False, exist_ok=False)
113 |
114 | # Save model to results directory
115 | torch.save(model, save_dir / 'model.pth')
116 |
117 | # Attempt deserialization (b/c I've had problems with it before)
118 | try:
119 | torch.load(save_dir / 'model.pth')
120 | except Exception:
121 | traceback.print_exc()
122 | raise AssertionError('Torch model failed to deserialize!')
123 |
124 | # Train model in stages
125 | best_model_path: Optional[str] = None
126 | mlflow_logger: Optional[MLFlowLogger] = None
127 | stage = 0
128 |
129 | # Logger
130 | mlflow_logger = MLFlowLogger(
131 | experiment_name='road_surface_classifier',
132 | run_name='run_%s_%d_trial_%d' % (now, stage, trial.number),
133 | tracking_uri='http://truenas:9809')
134 | mlflow_logger.log_hyperparams(dict(chip_size=chip_size, swa_lrs=swa_lrs))
135 |
136 | # Upload base model
137 | mlflow_logger.experiment.log_artifact(mlflow_logger.run_id,
138 | str(save_dir / 'model.pth'))
139 |
140 | # Save checkpoints (model states for later)
141 | checkpoint_callback = ModelCheckpoint(
142 | dirpath=str(save_dir),
143 | monitor='val_loss',
144 | save_top_k=1,
145 | filename='model-%d-{epoch:02d}-{val_loss:.5f}' % stage)
146 |
147 | # Setup early stopping based on validation loss
148 | early_stopping_callback = EarlyStopping(monitor='val_loss',
149 | mode='min',
150 | patience=10)
151 |
152 | # Stochastic Weight Averaging
153 | swa_callback = StochasticWeightAveraging(swa_lrs=swa_lrs)
154 |
155 | # Trainer
156 | trainer = pl.Trainer(
157 | accelerator='gpu',
158 | devices=1,
159 | max_epochs=300,
160 | callbacks=[checkpoint_callback, early_stopping_callback, swa_callback],
161 | logger=mlflow_logger)
162 |
163 | # Do the thing!
164 | trainer.fit(model, train_dataloaders=train_dl,
165 | val_dataloaders=val_dl) # type: ignore
166 |
167 | # Get best model path
168 | best_model_path = checkpoint_callback.best_model_path
169 |
170 | # Upload best model
171 | mlflow_logger.experiment.log_artifact(mlflow_logger.run_id,
172 | best_model_path)
173 |
174 | # Objective return values
175 | ret = (model.min_val_loss_cl, model.min_val_loss_ob)
176 |
177 | # Load model at best checkpoint to generate artifacts
178 | del model # just to be safe
179 | model = PLMaskCNN.load_from_checkpoint(best_model_path)
180 |
181 | assert best_model_path is not None
182 | assert mlflow_logger is not None
183 |
184 | # Generate artifacts
185 | print('Generating artifacts...')
186 | model.eval()
187 |
188 | # Generate artifacts from model
189 | from rsc.artifacts import ArtifactGenerator
190 | from rsc.artifacts.confusion_matrix_handler import ConfusionMatrixHandler
191 | from rsc.artifacts.accuracy_obsc_handler import AccuracyObscHandler
192 | from rsc.artifacts.obsc_compare_handler import ObscCompareHandler
193 | artifacts_dir = save_dir / 'artifacts'
194 | artifacts_dir.mkdir(parents=False, exist_ok=True)
195 | generator = ArtifactGenerator(artifacts_dir, model, val_dl)
196 | generator.add_handler(ConfusionMatrixHandler(simple=False))
197 | generator.add_handler(ConfusionMatrixHandler(simple=True))
198 | generator.add_handler(AccuracyObscHandler())
199 | generator.add_handler(ObscCompareHandler())
200 | generator.run(raise_on_error=False)
201 | mlflow_logger.experiment.log_artifacts(mlflow_logger.run_id,
202 | str(artifacts_dir))
203 |
204 | return ret
205 |
206 | # Do the work!
207 | for _ in range(3):
208 | try:
209 | study = optuna.create_study(
210 | directions=['minimize', 'minimize'],
211 | storage='sqlite:///./optuna_rsc.sqlite3',
212 | study_name=study_name,
213 | load_if_exists=True)
214 | study.optimize(objective, n_trials=num_trials)
215 | except Exception:
216 | traceback.print_exc()
217 | print('Restarting study...')
218 | else:
219 | break
220 |
221 | print('Done!')
--------------------------------------------------------------------------------
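Once a study has run, its results can be pulled back out of the SQLite storage. With two 'minimize' directions there is no single best trial, only a Pareto front; a sketch (substitute the study name actually used):

import optuna

study = optuna.load_study(study_name='study_20240108_000000Z',
                          storage='sqlite:///./optuna_rsc.sqlite3')
for t in study.best_trials:  # Pareto-optimal (val_loss_cl, val_loss_ob) trials
    print(t.number, t.values, t.params)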